├── src ├── metrics │ ├── __init__.py │ └── eval_datasets.py ├── dataset_readers │ ├── __init__.py │ ├── dataset_wrappers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── squad.py │ │ ├── boolq.py │ │ ├── commonsense_qa.py │ │ ├── qnli.py │ │ ├── sst5.py │ │ ├── ag_news.py │ │ ├── trec.py │ │ ├── sst2.py │ │ ├── rte.py │ │ ├── snli.py │ │ └── mnli.py │ ├── prerank_dsr.py │ ├── base_dsr.py │ ├── ppl_inference_cls_dsr.py │ ├── retriever_dsr.py │ └── inference_dsr.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── app.py │ ├── cache_util.py │ ├── model_util.py │ ├── collators.py │ ├── calculate.py │ ├── misc.py │ └── dpp_map.py ├── datasets │ ├── __init__.py │ ├── labels.py │ └── instructions.py └── models │ └── model.py ├── requirements.txt ├── configs ├── ppl_inferencer.yaml ├── inferencer.yaml ├── prerank.yaml └── retriever.yaml ├── scripts ├── sst2 │ ├── run_prompting.sh │ ├── run_random.sh │ ├── run_topk.sh │ ├── run_local_e.sh │ └── run_mdl.sh └── run_mdl.sh ├── .gitignore ├── README.md ├── inferencer.py ├── env.yaml ├── ppl_inferencer.py ├── retriever.py └── prerank.py /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dataset_readers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | hydra-core 3 | datasets 4 | accelerate 5 | nltk 6 | dppy 7 | faiss-gpu 8 | spacy 9 | sentence_transformers -------------------------------------------------------------------------------- /src/utils/app.py: -------------------------------------------------------------------------------- 1 | class App: 2 | def __init__(self): 3 | self.functions = {} 4 | 5 | def add(self, key): 6 | def adder(func): 7 | self.functions[key] = func 8 | return func 9 | 10 | return adder 11 | 12 | def __getitem__(self, __name: str): 13 | return self.functions[__name] 14 | -------------------------------------------------------------------------------- /configs/ppl_inferencer.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | chdir: false 4 | batch_size: 48 5 | model_name: "gpt2-xl" 6 | rand_seed: 1 7 | output_file: ??? 8 | task_name: sst2 9 | window: 10 10 | span: false 11 | n_tokens: 700 12 | instruction_template: 1 13 | overwrite: true 14 | calibrate: false 15 | reverse_label: false 16 | prior_no: 1 17 | dataset_reader: 18 | _target_: src.dataset_readers.ppl_inference_cls_dsr.PPLCLSInferenceDatasetReader 19 | dataset_path: ??? 
20 | task_name: ${task_name} 21 | model_name: ${model_name} 22 | index_split: "train" 23 | n_tokens: 700 24 | index_data_path: null 25 | model: 26 | _target_: transformers.AutoModelForCausalLM.from_pretrained 27 | pretrained_model_name_or_path: ${model_name} 28 | -------------------------------------------------------------------------------- /configs/inferencer.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | chdir: false 4 | model_name: "gpt2-large" 5 | task_name: ??? 6 | output_file: ??? 7 | batch_size: 48 8 | dataset_reader: 9 | _target_: src.dataset_readers.inference_dsr.InferenceDatasetReader 10 | dataset_path: ??? 11 | task_name: ${task_name} 12 | model_name: ${model_name} 13 | n_tokens: 700 14 | index_reader: ${index_reader} 15 | index_reader: 16 | _target_: src.dataset_readers.index_dsr.IndexDatasetReader 17 | task_name: ${task_name} 18 | model_name: ${model_name} 19 | field: ALL 20 | dataset_split: train 21 | dataset_path: null 22 | model: 23 | _target_: transformers.AutoModelForCausalLM.from_pretrained 24 | pretrained_model_name_or_path: ${model_name} 25 | 26 | -------------------------------------------------------------------------------- /configs/prerank.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | chdir: false 4 | output_file: ??? 5 | rand_seed: 1 6 | num_candidates: 1 7 | num_ice: 1 8 | dpp_sampling: false 9 | dpp_topk: 100 10 | scale_factor: null 11 | rerank: false 12 | batch_size: 64 13 | cuda_device: cuda:0 14 | overwrite: true 15 | method: 'topk' 16 | vote_k_k: 1 17 | model_name: 'gpt2-xl' 18 | retriever_model: 'all-mpnet-base-v2' 19 | task_name: sst2 20 | index_file: ??? 21 | emb_field: X 22 | dataset_reader: 23 | task_name: sst2 24 | model_name: ${model_name} 25 | field: ${emb_field} 26 | dataset_split: validation 27 | dataset_path: null 28 | 29 | index_reader: 30 | task_name: sst2 31 | model_name: ${model_name} 32 | field: ${emb_field} 33 | dataset_split: train 34 | dataset_path: null 35 | model: 36 | _target_: transformers.AutoModelForCausalLM.from_pretrained 37 | pretrained_model_name_or_path: ${model_name} 38 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pathlib 3 | import types 4 | from os.path import dirname, isfile, join 5 | 6 | modules = {} 7 | modules_list = glob.glob(join(dirname(__file__), "*.py")) 8 | for path in modules_list: 9 | if isfile(path) and not path.endswith('__init__.py') and not path.endswith('__main__.py'): 10 | mod_name = pathlib.Path(path).name[:-3] 11 | module = types.ModuleType(mod_name) 12 | with open(path,encoding='UTF-8') as f: 13 | module_str = f.read() 14 | exec(module_str, module.__dict__) 15 | modules[mod_name] = module 16 | 17 | dataset_dict = {} 18 | for module_name, module in modules.items(): 19 | for el in dir(module): 20 | if el.endswith("Dataset"): 21 | obj = module.__dict__[el] 22 | dataset_dict[module_name] = obj 23 | 24 | 25 | def get_dataset(name): 26 | return dataset_dict[name]() 27 | -------------------------------------------------------------------------------- /configs/retriever.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | chdir: false 4 | batch_size: 48 5 | rand_seed: 1 6 | model_name: "gpt2-xl" 7 | dpp_sampling: false 8 | scale_factor: null 9 | rerank: 
false 10 | window: 30 11 | num_candidates: 8 12 | num_ice: 100 13 | output_file: ??? 14 | task_name: sst2 15 | cuda_device: cuda:0 16 | method: 'mdl' 17 | force_topk: true 18 | instruction_template: 1 19 | span: true 20 | n_tokens: 700 21 | sort: false 22 | use_rand_pool: false 23 | calibrate: false 24 | prior_no: 1 25 | overwrite: true 26 | all_permutation: false 27 | dataset_reader: 28 | _target_: src.dataset_readers.retriever_dsr.RetrieverDatasetReader 29 | dataset_path: ??? 30 | task_name: ${task_name} 31 | model_name: ${model_name} 32 | index_split: "train" 33 | n_tokens: 700 34 | index_data_path: null 35 | model: 36 | _target_: transformers.AutoModelForCausalLM.from_pretrained 37 | pretrained_model_name_or_path: ${model_name} 38 | -------------------------------------------------------------------------------- /src/dataset_readers/dataset_wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pathlib 3 | import types 4 | from os.path import dirname, isfile, join 5 | 6 | modules = {} 7 | modules_list = glob.glob(join(dirname(__file__), "*.py")) 8 | for path in modules_list: 9 | if isfile(path) and not path.endswith('__init__.py') and not path.endswith('task_.py'): 10 | mod_name = pathlib.Path(path).name[:-3] 11 | module = types.ModuleType(mod_name) 12 | with open(path, encoding='utf-8') as f: 13 | module_str = f.read() 14 | exec(module_str, module.__dict__) 15 | modules[mod_name] = module 16 | 17 | task_list = {} 18 | for module_name, module in modules.items(): 19 | for el in dir(module): 20 | if el.endswith("DatasetWrapper"): 21 | obj = module.__dict__[el] 22 | task_list[obj.name] = obj 23 | 24 | 25 | def get_dataset_wrapper(name): 26 | return task_list[name] 27 | -------------------------------------------------------------------------------- /src/dataset_readers/dataset_wrappers/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | import random 4 | 5 | 6 | class DatasetWrapper: 7 | name = "base" 8 | 9 | def __init__(self): 10 | self.dataset = None 11 | self.field_getter = None 12 | 13 | def __getitem__(self, idx): 14 | return self.dataset[idx] 15 | 16 | def __len__(self): 17 | return len(self.dataset) 18 | 19 | def get_field(self, entry, field): 20 | return self.field_getter.functions[field](entry) 21 | 22 | def get_corpus(self, field): 23 | return [self.get_field(entry, field) for entry in self.dataset] 24 | 25 | 26 | def load_partial_dataset(dataset, size=1): 27 | if size == 1 or size >= len(dataset): 28 | return dataset 29 | 30 | total_size = len(dataset) 31 | size = int(size * total_size) if size < 1 else size 32 | 33 | rand = random.Random(x=size) 34 | index_list = list(range(total_size)) 35 | rand.shuffle(index_list) 36 | dataset = dataset.select(index_list[:size]) 37 | return dataset -------------------------------------------------------------------------------- /src/dataset_readers/prerank_dsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.dataset_readers.dataset_wrappers import get_dataset_wrapper 3 | 4 | 5 | class PrerankDatasetReader(torch.utils.data.Dataset): 6 | 7 | def __init__(self, task_name, field, dataset_path=None, dataset_split=None, 8 | ds_size=None, tokenizer=None) -> None: 9 | self.tokenizer = tokenizer 10 | self.dataset_wrapper = get_dataset_wrapper(task_name)(dataset_path=dataset_path, 11 | dataset_split=dataset_split, 12 | 
ds_size=ds_size)
13 | 
14 |         self.field = field
15 | 
16 |     def __getitem__(self, index):
17 |         entry = self.dataset_wrapper[index]
18 |         enc_text = self.dataset_wrapper.get_field(entry, self.field)
19 | 
20 |         tokenized_inputs = self.tokenizer.encode_plus(enc_text, truncation=True, return_tensors='pt', padding='longest')
21 | 
22 |         return {
23 | 
24 |             'input_ids': tokenized_inputs.input_ids.squeeze(),
25 |             'attention_mask': tokenized_inputs.attention_mask.squeeze(),
26 |             "metadata": {"id": index}
27 |         }
28 | 
29 |     def __len__(self):
30 |         return len(self.dataset_wrapper)
--------------------------------------------------------------------------------
/src/dataset_readers/base_dsr.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from transformers import AutoTokenizer
 3 | from src.dataset_readers.dataset_wrappers import get_dataset_wrapper
 4 | 
 5 | 
 6 | class BaseDatasetReader(torch.utils.data.Dataset):
 7 | 
 8 |     def __init__(self, task_name, model_name, field, dataset_path=None, dataset_split=None,
 9 |                  ds_size=None, tokenizer=None) -> None:
10 |         self.tokenizer = tokenizer
11 |         self.dataset_wrapper = get_dataset_wrapper(task_name)(dataset_path=dataset_path,
12 |                                                               dataset_split=dataset_split,
13 |                                                               ds_size=ds_size)
14 | 
15 |         self.field = field
16 | 
17 |     def __getitem__(self, index):
18 |         entry = self.dataset_wrapper[index]
19 |         enc_text = self.dataset_wrapper.get_field(entry, self.field)
20 | 
21 |         tokenized_inputs = self.tokenizer.encode_plus(enc_text, truncation=True, return_tensors='pt')
22 | 
23 |         return {
24 |             'input_ids': tokenized_inputs.input_ids.squeeze(),
25 |             'attention_mask': tokenized_inputs.attention_mask.squeeze(),
26 |             "metadata": {"id": index}
27 |         }
28 | 
29 |     def __len__(self):
30 |         return len(self.dataset_wrapper)
--------------------------------------------------------------------------------
/src/utils/cache_util.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | import pathlib
 3 | 
 4 | 
 5 | class BufferedJsonWriter(object):
 6 |     def __init__(self, file_name, buffer_size=50):
 7 |         self.file_path = file_name
 8 |         self.buffer = []
 9 |         self.buffer_size = buffer_size
10 | 
11 |     def __enter__(self):
12 |         return self
13 | 
14 |     def __exit__(self, type, value, traceback):
15 |         self.write_buffer()
16 | 
17 |     def write(self, obj=None):
18 |         if obj is not None:
19 |             self.buffer.append(obj)
20 |         if len(self.buffer) >= self.buffer_size:
21 |             self.write_buffer()
22 | 
23 |     def write_buffer(self):
24 |         with open(self.file_path, "a") as data_file:
25 |             data_file.write(json.dumps(self.buffer))
26 |             data_file.write("\n")
27 |             self.buffer = []
28 | 
29 | 
30 | class BufferedJsonReader(object):
31 |     def __init__(self, file_name):
32 |         self.file_path = file_name
33 | 
34 |     def __enter__(self):
35 |         return self
36 | 
37 |     def __exit__(self, type, value, traceback):
38 |         pass
39 | 
40 |     def __iter__(self):
41 |         with open(self.file_path, "r") as data_file:
42 |             for line in data_file:
43 |                 yield from json.loads(line)
44 | 
45 |     def read(self):
46 |         return list(self.__iter__())
47 | 
48 | 
49 | def get_cache_path(dataset):
50 |     cache_files = dataset.cache_files
51 |     if isinstance(cache_files, dict):
52 |         cache_files = next(iter(cache_files.values()))
53 |     return pathlib.Path(cache_files[0]['filename']).parent
--------------------------------------------------------------------------------
/scripts/sst2/run_prompting.sh:
--------------------------------------------------------------------------------
 1 | export 
WANDB_PROJECT=ICL # change if needed 2 | export WANDB_ENTITY=xx # change to your wandb account 3 | export WANDB_API_KEY=xx # change to your api-key 4 | export WANDB_START_METHOD=thread 5 | export TOKENIZERS_PARALLELISM=false 6 | 7 | export HYDRA_FULL_ERROR=1 8 | 9 | port=1277 10 | num_ice=0 11 | model_name=gpt2-xl 12 | n_tokens=700 13 | inf_batch_size=12 14 | 15 | instruction_template=1 16 | span=true 17 | dataset_split="test" 18 | rand_seed=1 19 | root=/mnt/cache/wangyaoxiang/codes/adaptive 20 | emb_field=X # or ALL 21 | n_gpu=1 22 | 23 | for task_name in sst2 24 | do 25 | run_dir=${root}/output/${task_name}/${model_name}/${rand_seed}/${dataset_split} 26 | retrieve_file=${run_dir}/retrieved_0example.json 27 | pred_file=${run_dir}/pred6.json 28 | mkdir -p ${run_dir} 29 | 30 | python prerank.py output_file=${retrieve_file} \ 31 | num_ice=${num_ice} \ 32 | dataset_reader.task_name=${task_name} \ 33 | rand_seed=${rand_seed} \ 34 | dataset_reader.dataset_split=${dataset_split} \ 35 | index_reader.task_name=${task_name} \ 36 | index_file=${run_dir}/index \ 37 | scale_factor=0.1 38 | 39 | accelerate launch --num_processes ${n_gpu} --main_process_port ${port} ppl_inferencer.py \ 40 | dataset_reader.task_name=${task_name} \ 41 | rand_seed=${rand_seed} \ 42 | dataset_reader.dataset_path=${retrieve_file} \ 43 | instruction_template=${instruction_template} \ 44 | span=${span} \ 45 | dataset_reader.n_tokens=${n_tokens} \ 46 | output_file=${pred_file} \ 47 | model_name=${model_name} \ 48 | batch_size=${inf_batch_size} 49 | done 50 | -------------------------------------------------------------------------------- /scripts/sst2/run_random.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT=ICL # change if needed 2 | export WANDB_ENTITY=xx # change to your wandb account 3 | export WANDB_API_KEY=xx # change to your api-key 4 | export WANDB_START_METHOD=thread 5 | export TOKENIZERS_PARALLELISM=false 6 | 7 | export HYDRA_FULL_ERROR=1 8 | 9 | port=1277 10 | # 11 | num_ice=8 12 | num_candidates=30 13 | 14 | #model_name=EleutherAI/gpt-neo-2.7B 15 | model_name=gpt2-xl 16 | n_tokens=700 17 | inf_batch_size=12 18 | prerank_method="random" 19 | instruction_template=1 20 | span=true 21 | dataset_split="test" 22 | rand_seed=1 23 | root=/mnt/cache/wangyaoxiang/codes/adaptive 24 | emb_field=X # or ALL 25 | n_gpu=1 26 | 27 | for task_name in sst2 28 | do 29 | run_dir=${root}/output/${task_name}/${model_name}/${rand_seed}/${dataset_split} 30 | retrieve_file=${run_dir}/retrieved_rand.json 31 | pred_file=${run_dir}/pred9.json 32 | mkdir -p ${run_dir} 33 | 34 | python prerank.py output_file=${retrieve_file} \ 35 | emb_field=${emb_field} \ 36 | num_ice=${num_ice} \ 37 | num_candidates=${num_candidates} \ 38 | dataset_reader.task_name=${task_name} \ 39 | rand_seed=${rand_seed} \ 40 | method=${prerank_method} \ 41 | dataset_reader.dataset_split=${dataset_split} \ 42 | index_reader.task_name=${task_name} \ 43 | index_file=${run_dir}/index \ 44 | scale_factor=0.1 45 | 46 | accelerate launch --num_processes ${n_gpu} --main_process_port ${port} ppl_inferencer.py \ 47 | dataset_reader.task_name=${task_name} \ 48 | rand_seed=${rand_seed} \ 49 | dataset_reader.dataset_path=${retrieve_file} \ 50 | instruction_template=${instruction_template} \ 51 | span=${span}\ 52 | dataset_reader.n_tokens=${n_tokens} \ 53 | output_file=${pred_file} \ 54 | model_name=${model_name} \ 55 | batch_size=${inf_batch_size} 56 | done 57 | 
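Note: every dataset wrapper that follows registers its field extractors on an App instance (src/utils/app.py), and DatasetWrapper.get_field dispatches on the field name. A minimal, self-contained sketch of that registry pattern; the toy entry and the extractor body are illustrative, not taken from any wrapper, and it assumes the repo root is on PYTHONPATH:

from src.utils.app import App

field_getter = App()

@field_getter.add("X")
def get_X(entry):
    # toy extractor; the real wrappers fall back to entry['X'] when it exists
    return entry["text"]

entry = {"text": "a gripping movie", "label": 1}
print(field_getter["X"](entry))            # "a gripping movie", via App.__getitem__
print(field_getter.functions["X"](entry))  # the same lookup DatasetWrapper.get_field performs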
--------------------------------------------------------------------------------
/src/dataset_readers/dataset_wrappers/squad.py:
--------------------------------------------------------------------------------
 1 | from datasets import load_dataset, Dataset, DatasetDict
 2 | import pandas as pd
 3 | from src.utils.app import App
 4 | from src.dataset_readers.dataset_wrappers.base import *
 5 | 
 6 | field_getter = App()
 7 | 
 8 | 
 9 | @field_getter.add("X")
10 | def get_X(entry):
11 |     return entry['question'] if 'X' not in entry.keys() else entry['X']
12 | 
13 | 
14 | @field_getter.add("C")  # for datasets with sentence1, sentence2 and label, C denotes sentence1
15 | def get_C(entry):
16 |     return entry['context'] if 'C' not in entry.keys() else entry['C']
17 | 
18 | 
19 | @field_getter.add("Y")  # used to pick the template
20 | def get_Y(entry):
21 |     return 0
22 | 
23 | 
24 | @field_getter.add("Y_TEXT")
25 | def get_Y_TEXT(entry):
26 |     return entry['answers']['text'][0] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']
27 | 
28 | 
29 | @field_getter.add("ALL")
30 | def get_ALL(entry):
31 |     return f"{entry['context']}\t{entry['question']}\t{entry['answers']['text']}" if 'ALL' not in entry.keys() else \
32 |         entry['ALL']
33 | 
34 | 
35 | class SquadDatasetWrapper(DatasetWrapper):
36 |     name = "squad"
37 |     question_field = "question"
38 |     answer_field = "answers"
39 | 
40 |     def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
41 |         super().__init__()
42 |         self.field_getter = field_getter
43 |         self.postfix = ""  # for inference
44 |         if dataset_path is None:
45 |             self.dataset = load_dataset("squad", split=dataset_split)
46 |         else:
47 |             self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path)))
48 |         if dataset_split is not None and isinstance(self.dataset, DatasetDict):
49 |             self.dataset = self.dataset[dataset_split]
50 | 
51 |         if ds_size is not None:
52 |             self.dataset = load_partial_dataset(self.dataset, size=ds_size)
--------------------------------------------------------------------------------
/scripts/sst2/run_topk.sh:
--------------------------------------------------------------------------------
 1 | export WANDB_PROJECT=ICL  # change if needed
 2 | export WANDB_ENTITY=xx  # change to your wandb account
 3 | export WANDB_API_KEY=xx  # change to your api-key
 4 | export WANDB_START_METHOD=thread
 5 | export TOKENIZERS_PARALLELISM=false
 6 | 
 7 | export HYDRA_FULL_ERROR=1
 8 | 
 9 | port=12714
10 | #
11 | num_ice=8
12 | num_candidates=30
13 | 
14 | #model_name=EleutherAI/gpt-neo-2.7B
15 | model_name=gpt2-xl
16 | n_tokens=700
17 | inf_batch_size=12
18 | instruction_template=1
19 | span=true
20 | overwrite=false  # if the output file already exists, skip the run
21 | dataset_split="test"
22 | rand_seed=1
23 | emb_field=ALL
24 | root=/mnt/cache/wangyaoxiang/codes/adaptive
25 | emb_field=X  # or ALL
26 | n_gpu=1
27 | 
28 | for task_name in sst2
29 | do
30 |     run_dir=${root}/output/${task_name}/${model_name}/${rand_seed}/${dataset_split}
31 |     retrieve_file=${run_dir}/retrieved.json
32 |     pred_file=${run_dir}/pred2.json
33 |     mkdir -p ${run_dir}
34 | 
35 |     python prerank.py output_file=${retrieve_file} \
36 |         emb_field=${emb_field} \
37 |         overwrite=${overwrite} \
38 |         num_ice=${num_ice} \
39 |         num_candidates=${num_candidates} \
40 |         dataset_reader.task_name=${task_name} \
41 |         rand_seed=${rand_seed} \
42 |         dataset_reader.dataset_split=${dataset_split} \
43 |         index_reader.task_name=${task_name} \
44 |         index_file=${run_dir}/index \
45 |         scale_factor=0.1
46 | 
47 |     accelerate launch --num_processes ${n_gpu} --main_process_port 
${port} ppl_inferencer.py \
48 |         overwrite=${overwrite} \
49 |         dataset_reader.task_name=${task_name} \
50 |         rand_seed=${rand_seed} \
51 |         dataset_reader.dataset_path=${retrieve_file} \
52 |         instruction_template=${instruction_template} \
53 |         span=${span} \
54 |         dataset_reader.n_tokens=${n_tokens} \
55 |         output_file=${pred_file} \
56 |         model_name=${model_name} \
57 |         batch_size=${inf_batch_size}
58 | done
59 | 
--------------------------------------------------------------------------------
/src/dataset_readers/ppl_inference_cls_dsr.py:
--------------------------------------------------------------------------------
 1 | from src.dataset_readers.retriever_dsr import RetrieverDatasetReader
 2 | 
 3 | 
 4 | def get_length(tokenizer, text):
 5 |     tokenized_example = tokenizer.encode_plus(text, truncation=False, return_tensors='pt')
 6 |     return int(tokenized_example.input_ids.shape[1])
 7 | 
 8 | 
 9 | def set_length(example, **kwargs):
10 |     tokenizer = kwargs['tokenizer']
11 |     set_field = kwargs['set_field']
12 |     field_getter = kwargs['field_getter']
13 | 
14 |     field_text = field_getter.functions[set_field](example)
15 |     example[f'{set_field}_len'] = get_length(tokenizer, field_text)
16 |     if set_field not in example:
17 |         example[set_field] = field_text
18 |     return example
19 | 
20 | 
21 | def set_field(example, **kwargs):
22 |     set_fields = ['C', 'X', 'Y', 'Y_TEXT', 'ALL']
23 |     field_getter = kwargs['field_getter']
24 |     for set_field in set_fields:
25 |         field_text = field_getter.functions[set_field](example)
26 |         if set_field not in example:
27 |             example[set_field] = field_text
28 |     return example
29 | 
30 | 
31 | class PPLCLSInferenceDatasetReader(RetrieverDatasetReader):
32 | 
33 |     def __init__(self, model_name, task_name, index_split, dataset_path, n_tokens=1600, tokenizer=None,
34 |                  index_data_path=None):
35 |         super().__init__(model_name, task_name, index_split, dataset_path, n_tokens, tokenizer, index_data_path)
36 | 
37 |     def get_ctxs_inputs(self, entry):
38 |         C = self.dataset_wrapper.get_field(entry, 'C')
39 |         X = self.dataset_wrapper.get_field(entry, 'X')
40 |         Y = self.dataset_wrapper.get_field(entry, 'Y')
41 |         Y_TEXT = self.dataset_wrapper.get_field(entry, 'Y_TEXT')
42 |         ctx = [self.index_dataset.dataset[i] for i in entry['ctxs']]
43 |         example_list = [{'C': i['C'], 'X': i['X'], 'Y': i['Y'], 'Y_TEXT': i['Y_TEXT']} for i in ctx]
44 | 
45 |         return C, X, Y, Y_TEXT, example_list
--------------------------------------------------------------------------------
/src/dataset_readers/dataset_wrappers/boolq.py:
--------------------------------------------------------------------------------
 1 | from datasets import load_dataset, Dataset, DatasetDict
 2 | import pandas as pd
 3 | from src.utils.app import App
 4 | from src.dataset_readers.dataset_wrappers.base import *
 5 | from src.datasets.labels import get_mapping_token
 6 | 
 7 | field_getter = App()
 8 | 
 9 | label2text = get_mapping_token("boolq")
10 | 
11 | 
12 | @field_getter.add("X")
13 | def get_X(entry):
14 |     return entry['question'] if 'X' not in entry.keys() else entry['X']
15 | 
16 | 
17 | @field_getter.add("C")  # for datasets with sentence1, sentence2 and label, C denotes sentence1
18 | def get_C(entry):
19 |     return entry['passage'] if 'C' not in entry.keys() else entry['C']
20 | 
21 | 
22 | @field_getter.add("Y")  # used to pick the template
23 | def get_Y(entry):
24 |     return entry['label'] if 'Y' not in entry.keys() else entry['Y']
25 | 
26 | 
27 | @field_getter.add("Y_TEXT")
28 | def get_Y_TEXT(entry):
29 |     return label2text[entry['label']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']
30 | 
31 | 
32 | @field_getter.add("ALL")
33 | def get_ALL(entry):
34 |     return f"{entry['question']}\n{entry['passage']}" if 'ALL' not in entry.keys() else \
35 |         entry['ALL']
36 | 
37 | 
38 | class BoolQDatasetWrapper(DatasetWrapper):
39 |     name = "boolq"
40 |     question_field = "question"
41 |     answer_field = "answers"
42 | 
43 |     def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
44 |         super().__init__()
45 |         self.field_getter = field_getter
46 |         self.postfix = ""  # for inference
47 |         if dataset_path is None:
48 |             self.dataset = load_dataset("super_glue", "boolq", split=dataset_split)
49 |         else:
50 |             self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path)))
51 |         if dataset_split is not None and isinstance(self.dataset, DatasetDict):
52 |             self.dataset = self.dataset[dataset_split]
53 | 
54 |         if ds_size is not None:
55 |             self.dataset = load_partial_dataset(self.dataset, size=ds_size)
--------------------------------------------------------------------------------
/src/dataset_readers/dataset_wrappers/commonsense_qa.py:
--------------------------------------------------------------------------------
 1 | from datasets import load_dataset, Dataset, DatasetDict
 2 | import pandas as pd
 3 | from src.utils.app import App
 4 | from src.dataset_readers.dataset_wrappers.base import *
 5 | 
 6 | field_getter = App()
 7 | letter2num = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}
 8 | 
 9 | 
10 | @field_getter.add("X")
11 | def get_X(entry):
12 |     return entry['question'] if 'X' not in entry.keys() else entry['X']
13 | 
14 | 
15 | @field_getter.add("C")  # for datasets with sentence1, sentence2 and label, C denotes sentence1
16 | def get_C(entry):
17 |     if 'C' in entry.keys():
18 |         return entry['C']
19 |     choices = entry['choices']['text']
20 |     return f'{choices[0]}, {choices[1]}, {choices[2]}, {choices[3]} or {choices[4]}?'
21 | 
22 | 
23 | @field_getter.add("Y")  # used to pick the template
24 | def get_Y(entry):
25 |     return letter2num[entry["answerKey"]] if 'Y' not in entry.keys() else entry['Y']
26 | 
27 | 
28 | @field_getter.add("Y_TEXT")
29 | def get_Y_TEXT(entry):
30 |     return entry['choices']['text'][entry['Y']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']
31 | 
32 | 
33 | @field_getter.add("ALL")
34 | def get_ALL(entry):
35 |     return f"" if 'ALL' not in entry.keys() else entry['ALL']
36 | 
37 | 
38 | class CommonsenseQADatasetWrapper(DatasetWrapper):
39 |     name = "commonsense_qa"
40 |     question_field = "question"
41 |     answer_field = "answerKey"
42 | 
43 |     def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
44 |         super().__init__()
45 |         self.field_getter = field_getter
46 |         self.postfix = ""  # for inference
47 |         if dataset_path is None:
48 |             self.dataset = load_dataset("commonsense_qa", split=dataset_split)
49 |         else:
50 |             self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path)))
51 |         if dataset_split is not None and isinstance(self.dataset, DatasetDict):
52 |             self.dataset = self.dataset[dataset_split]
53 | 
54 |         if ds_size is not None:
55 |             self.dataset = load_partial_dataset(self.dataset, size=ds_size)
--------------------------------------------------------------------------------
/src/dataset_readers/dataset_wrappers/qnli.py:
--------------------------------------------------------------------------------
 1 | from datasets import load_dataset, Dataset, DatasetDict
 2 | import pandas as pd
 3 | from src.utils.app import App
 4 | from src.dataset_readers.dataset_wrappers.base import *
 5 | from src.datasets.labels import get_mapping_token
 6 | 
 7 | field_getter = App()
 8 | 
 9 | # e.g. 
label2text = {0: "negative", 1: "positive"} 10 | label2text = get_mapping_token("qnli") 11 | 12 | 13 | @field_getter.add("X") 14 | def get_X(entry): 15 | return entry['question'] if 'X' not in entry.keys() else entry['X'] 16 | 17 | 18 | @field_getter.add("Y_TEXT") 19 | def get_Y_TEXT(entry): 20 | return label2text[entry['label']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT'] 21 | 22 | 23 | @field_getter.add("C") 24 | def get_C(entry): 25 | return entry['sentence'] if 'C' not in entry.keys() else entry['C'] 26 | 27 | 28 | @field_getter.add("Y") 29 | def get_Y(entry): 30 | return entry['label'] if 'Y' not in entry.keys() else entry['Y'] 31 | 32 | 33 | @field_getter.add("ALL") 34 | def get_ALL(entry): 35 | return f"{entry['question']}\t {entry['sentence']}" if 'ALL' not in entry.keys() else entry['ALL'] 36 | 37 | 38 | class QNLIDatasetWrapper(DatasetWrapper): 39 | name = "qnli" 40 | 41 | def __init__(self, dataset_path=None, dataset_split=None, ds_size=None): 42 | def _abs_label(ex): 43 | ex['label'] = abs(ex['label']) 44 | return ex 45 | super().__init__() 46 | self.task_name = "qnli" 47 | self.field_getter = field_getter 48 | self.postfix = "" # for inference 49 | if dataset_path is None: 50 | self.dataset = load_dataset("glue", "qnli", split=dataset_split) 51 | self.dataset = self.dataset.map(_abs_label, batched=False, load_from_cache_file=False) 52 | else: 53 | self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path))) 54 | 55 | if dataset_split is not None and isinstance(self.dataset, DatasetDict): 56 | self.dataset = self.dataset[dataset_split] 57 | 58 | if ds_size is not None: 59 | self.dataset = load_partial_dataset(self.dataset, size=ds_size) 60 | -------------------------------------------------------------------------------- /src/utils/model_util.py: -------------------------------------------------------------------------------- 1 | from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch 2 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 3 | import torch 4 | 5 | 6 | def load_tokenizer_model(checkpoint, pad_trunc_right=True, use_fast=True): 7 | if checkpoint == 'opt175b': 8 | model = None 9 | checkpoint = 'facebook/opt-30b' # all opt tokenizers are same 10 | use_fast = False # the fast tokenizer currently does not work correctly 11 | elif 'opt' in checkpoint: 12 | model = load_opt_model(checkpoint) 13 | use_fast = False # the fast tokenizer currently does not work correctly 14 | else: 15 | model = AutoModelForCausalLM.from_pretrained(checkpoint) 16 | if torch.cuda.is_available(): 17 | model.parallelize() 18 | 19 | if pad_trunc_right: 20 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=use_fast) 21 | else: 22 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side='left', 23 | truncation_side='left', use_fast=use_fast) 24 | 25 | tokenizer.pad_token = tokenizer.eos_token # original pad token id is None, not in embedding matrix 26 | if model is not None: 27 | model.config.pad_token_id = tokenizer.eos_token_id 28 | return tokenizer, model 29 | 30 | 31 | def load_opt_model(checkpoint): 32 | config = AutoConfig.from_pretrained(checkpoint) 33 | 34 | # Initializes an empty shell with the model. This is instant and does not take any RAM. 35 | with init_empty_weights(): 36 | model = AutoModelForCausalLM.from_config(config) 37 | # Initialize the model under the previous context manager breaks the tied weights. 
38 |     model.tie_weights()
39 | 
40 |     # Infer device map automatically
41 |     device_map = infer_auto_device_map(model.model, no_split_module_classes=["OPTDecoderLayer"], dtype='float16')
42 |     print(device_map)
43 | 
44 |     load_checkpoint_and_dispatch(
45 |         model.model,
46 |         checkpoint,
47 |         device_map=device_map,
48 |         offload_folder=None,
49 |         dtype='float16',
50 |         offload_state_dict=True
51 |     )
52 |     model.tie_weights()
53 |     return model
--------------------------------------------------------------------------------
/src/utils/collators.py:
--------------------------------------------------------------------------------
 1 | from dataclasses import dataclass
 2 | from typing import Any, Dict, List, Optional, Union
 3 | 
 4 | import torch
 5 | from transformers import PreTrainedTokenizerBase, BatchEncoding
 6 | from transformers.file_utils import PaddingStrategy
 7 | 
 8 | 
 9 | class ListWrapper:
10 |     def __init__(self, data: List[Any]):
11 |         self.data = data
12 | 
13 |     def to(self, device):
14 |         return self.data
15 | 
16 | 
17 | def ignore_pad_dict(features):
18 |     res_dict = {}
19 |     if "metadata" in features[0]:
20 |         res_dict['metadata'] = ListWrapper([x.pop("metadata") for x in features])
21 | 
22 |     if "id2idx" in features[0]:
23 |         res_dict['id2idx'] = ListWrapper([x.pop("id2idx") for x in features])
24 |     return res_dict
25 | 
26 | 
27 | @dataclass
28 | class DataCollatorWithPaddingAndCuda:
29 |     tokenizer: PreTrainedTokenizerBase
30 |     device: object = None
31 |     padding: Union[bool, str, PaddingStrategy] = True
32 |     max_length: Optional[int] = 3000
33 |     pad_to_multiple_of: Optional[int] = None
34 | 
35 |     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> BatchEncoding:
36 |         res_dict = ignore_pad_dict(features)
37 |         # print(self.device)
38 |         has_labels = "labels" in features[0]
39 |         if has_labels:
40 |             labels = [{"input_ids": x.pop("labels")} for x in features]
41 |             labels = self.tokenizer.pad(
42 |                 labels,
43 |                 padding=True,
44 |                 max_length=self.max_length,
45 |                 pad_to_multiple_of=self.pad_to_multiple_of,
46 |                 return_attention_mask=True,
47 |                 return_tensors="pt",
48 |             )
49 | 
50 |         # print(features)
51 |         batch = self.tokenizer.pad(
52 |             features,
53 |             padding=True,
54 |             max_length=self.max_length,
55 |             pad_to_multiple_of=self.pad_to_multiple_of,
56 |             return_attention_mask=True,
57 |             return_tensors="pt",
58 |         )
59 | 
60 |         if has_labels:
61 |             batch['labels'] = labels.input_ids
62 |         batch.update(res_dict)
63 | 
64 |         if self.device:
65 |             batch = batch.to(self.device)
66 | 
67 |         return batch
68 | 
--------------------------------------------------------------------------------
/src/dataset_readers/dataset_wrappers/sst5.py:
--------------------------------------------------------------------------------
 1 | from datasets import load_dataset, Dataset, DatasetDict
 2 | import pandas as pd
 3 | from src.utils.app import App
 4 | from src.dataset_readers.dataset_wrappers.base import *
 5 | from src.datasets.labels import get_mapping_token
 6 | 
 7 | field_getter = App()
 8 | 
 9 | # e.g. label2text = {0: "negative", 1: "positive"}
10 | label2text = get_mapping_token("sst5")
11 | 
12 | 
13 | @field_getter.add("X")
14 | def get_X(entry):
15 |     return entry['text'] if 'X' not in entry.keys() else entry['X']
16 | 
17 | 
18 | @field_getter.add("Y_TEXT")  # get the text corresponding to the label
19 | def get_Y_TEXT(entry):
20 |     return label2text[entry['label']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']
21 | 
22 | 
23 | @field_getter.add("C")  # for datasets with sentence1, sentence2 and label, C denotes sentence1
24 | def get_C(entry):
25 |     return "" if 'C' not in entry.keys() else entry['C']
26 | 
27 | 
28 | @field_getter.add("Y")  # get the raw label as an int
29 | def get_Y(entry):
30 |     return entry['label'] if 'Y' not in entry.keys() else entry['Y']
31 | 
32 | 
33 | @field_getter.add("ALL")
34 | def get_ALL(entry):
35 |     return f"{entry['text']}\tIt is {label2text[entry['label']]}" if 'ALL' not in entry.keys() else entry['ALL']
36 | 
37 | 
38 | class SST5DatasetWrapper(DatasetWrapper):
39 |     name = "sst5"
40 |     sentence_field = "text"
41 |     label_field = "label"
42 | 
43 |     def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
44 |         def _abs_label(ex):
45 |             ex['label'] = abs(ex['label'])
46 |             return ex
47 | 
48 |         super().__init__()
49 |         self.task_name = "sst5"
50 |         self.field_getter = field_getter
51 |         self.postfix = "It is"  # for inference
52 |         if dataset_path is None:
53 |             self.dataset = load_dataset("SetFit/sst5", split=dataset_split)
54 |             self.dataset = self.dataset.map(_abs_label, batched=False, load_from_cache_file=False)
55 |         else:
56 |             self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path)))
57 | 
58 |         if dataset_split is not None and isinstance(self.dataset, DatasetDict):
59 |             self.dataset = self.dataset[dataset_split]
60 | 
61 |         if ds_size is not None:
62 |             self.dataset = load_partial_dataset(self.dataset, size=ds_size)
--------------------------------------------------------------------------------
/src/dataset_readers/dataset_wrappers/ag_news.py:
--------------------------------------------------------------------------------
 1 | from datasets import load_dataset, Dataset, DatasetDict
 2 | import pandas as pd
 3 | from src.utils.app import App
 4 | from src.dataset_readers.dataset_wrappers.base import *
 5 | from src.datasets.labels import get_mapping_token
 6 | 
 7 | field_getter = App()
 8 | 
 9 | # e.g. label2text = {0: "negative", 1: "positive"}
10 | label2text = get_mapping_token("ag_news")
11 | 
12 | 
13 | @field_getter.add("X")
14 | def get_X(entry):
15 |     return entry['text'] if 'X' not in entry.keys() else entry['X']
16 | 
17 | 
18 | @field_getter.add("Y_TEXT")  # get the text corresponding to the label
19 | def get_Y_TEXT(entry):
20 |     return label2text[entry['label']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']
21 | 
22 | 
23 | @field_getter.add("C")  # for datasets with sentence1, sentence2 and label, C denotes sentence1
24 | def get_C(entry):
25 |     return "" if 'C' not in entry.keys() else entry['C']
26 | 
27 | 
28 | @field_getter.add("Y")  # get the raw label as an int
29 | def get_Y(entry):
30 |     return entry['label'] if 'Y' not in entry.keys() else entry['Y']
31 | 
32 | 
33 | @field_getter.add("ALL")
34 | def get_ALL(entry):
35 |     return f"{entry['text']}\tIt is {label2text[entry['label']]}" if 'ALL' not in entry.keys() else entry['ALL']
36 | 
37 | 
38 | class AGNEWSDatasetWrapper(DatasetWrapper):
39 |     name = "ag_news"
40 |     sentence_field = "text"
41 |     label_field = "label"
42 | 
43 |     def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
44 |         def _abs_label(ex):
45 |             ex['label'] = abs(ex['label'])
46 |             return ex
47 | 
48 |         super().__init__()
49 |         self.task_name = "ag_news"
50 |         self.field_getter = field_getter
51 |         self.postfix = "It is"  # for inference
52 |         if dataset_path is None:
53 |             self.dataset = load_dataset("ag_news", split=dataset_split)
54 |             self.dataset = self.dataset.map(_abs_label, batched=False, load_from_cache_file=False)
55 |         else:
56 |             self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path)))
57 | 
58 |         if dataset_split is not None and isinstance(self.dataset, DatasetDict):
59 |             self.dataset = self.dataset[dataset_split]
60 | 
61 |         if ds_size is not None:
62 |             self.dataset = load_partial_dataset(self.dataset, size=ds_size)
--------------------------------------------------------------------------------
/src/dataset_readers/dataset_wrappers/trec.py:
--------------------------------------------------------------------------------
 1 | from datasets import load_dataset, Dataset, DatasetDict
 2 | import pandas as pd
 3 | from src.utils.app import App
 4 | from src.dataset_readers.dataset_wrappers.base import *
 5 | from src.datasets.labels import get_mapping_token
 6 | 
 7 | field_getter = App()
 8 | 
 9 | # e.g. label2text = {0: "negative", 1: "positive"}
10 | label2text = get_mapping_token("trec")
11 | 
12 | 
13 | @field_getter.add("X")
14 | def get_X(entry):
15 |     return entry['text'] if 'X' not in entry.keys() else entry['X']
16 | 
17 | 
18 | @field_getter.add("Y_TEXT")  # get the text corresponding to the label
19 | def get_Y_TEXT(entry):
20 |     return label2text[entry['label-coarse']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']
21 | 
22 | 
23 | @field_getter.add("C")  # for datasets with sentence1, sentence2 and label, C denotes sentence1
24 | def get_C(entry):
25 |     return "" if 'C' not in entry.keys() else entry['C']
26 | 
27 | 
28 | @field_getter.add("Y")  # get the raw label as an int
29 | def get_Y(entry):
30 |     return entry['label-coarse'] if 'Y' not in entry.keys() else entry['Y']
31 | 
32 | 
33 | @field_getter.add("ALL")
34 | def get_ALL(entry):
35 |     return f"{entry['text']}\tIt is {label2text[entry['label-coarse']]}" if 'ALL' not in entry.keys() else entry['ALL']
36 | 
37 | 
38 | class TRECDatasetWrapper(DatasetWrapper):
39 |     name = "trec"
40 |     sentence_field = "text"
41 |     label_field = "label-coarse"
42 | 
43 |     def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
44 |         def _abs_label(ex):
45 |             ex['label-coarse'] = abs(ex['label-coarse'])
46 |             return ex
47 | 
48 |         super().__init__()
49 |         self.task_name = "trec"
50 |         self.field_getter = field_getter
51 |         self.postfix = "It is"  # for inference
52 |         if dataset_path is None:
53 |             self.dataset = load_dataset("trec", split=dataset_split)
54 |             self.dataset = self.dataset.map(_abs_label, batched=False, load_from_cache_file=False)
55 |         else:
56 |             self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path)))
57 | 
58 |         if dataset_split is not None and isinstance(self.dataset, DatasetDict):
59 |             self.dataset = self.dataset[dataset_split]
60 | 
61 |         if ds_size is not None:
62 |             self.dataset = load_partial_dataset(self.dataset, size=ds_size)
--------------------------------------------------------------------------------
/src/dataset_readers/dataset_wrappers/sst2.py:
--------------------------------------------------------------------------------
 1 | from datasets import load_dataset, Dataset, DatasetDict
 2 | import pandas as pd
 3 | from src.utils.app import App
 4 | from src.dataset_readers.dataset_wrappers.base import *
 5 | from src.datasets.labels import get_mapping_token
 6 | 
 7 | field_getter = App()
 8 | 
 9 | # e.g. label2text = {0: "negative", 1: "positive"}
10 | label2text = get_mapping_token("sst2")
11 | 
12 | 
13 | @field_getter.add("X")
14 | def get_X(entry):
15 |     return entry['text'] if 'X' not in entry.keys() else entry['X']
16 | 
17 | 
18 | @field_getter.add("Y_TEXT")  # get the text corresponding to the label
19 | def get_Y_TEXT(entry):
20 |     return label2text[entry['label']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']
21 | 
22 | 
23 | @field_getter.add("C")  # for datasets with sentence1, sentence2 and label, C denotes sentence1
24 | def get_C(entry):
25 |     return "" if 'C' not in entry.keys() else entry['C']
26 | 
27 | 
28 | @field_getter.add("Y")  # get the raw label as an int
29 | def get_Y(entry):
30 |     return entry['label'] if 'Y' not in entry.keys() else entry['Y']
31 | 
32 | 
33 | @field_getter.add("ALL")
34 | def get_ALL(entry):
35 |     return f"{entry['text']}\tIt is {label2text[entry['label']]}" if 'ALL' not in entry.keys() else entry['ALL']
36 | 
37 | 
38 | class SST2DatasetWrapper(DatasetWrapper):
39 |     name = "sst2"
40 |     sentence_field = "text"
41 |     label_field = "label"
42 | 
43 |     def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
44 |         def _reverse_label(ex):  # remap to 0 = negative, 1 = positive; done only for sst2
45 |             ex['label'] = abs(ex['label'] - 1)
46 |             return ex
47 | 
48 |         super().__init__()
49 |         self.task_name = "sst2"
50 |         self.field_getter = field_getter
51 |         self.postfix = "It is"  # for inference
52 |         if dataset_path is None:
53 |             self.dataset = load_dataset("gpt3mix/sst2", split=dataset_split)
54 |             self.dataset = self.dataset.map(_reverse_label, batched=False, load_from_cache_file=False)
55 |         else:
56 |             self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path)))
57 |         if dataset_split is not None and isinstance(self.dataset, DatasetDict):
58 |             self.dataset = self.dataset[dataset_split]
59 | 
60 |         if ds_size is not None:
61 |             self.dataset = load_partial_dataset(self.dataset, size=ds_size)
--------------------------------------------------------------------------------
/src/dataset_readers/dataset_wrappers/rte.py:
--------------------------------------------------------------------------------
 1 | from datasets import load_dataset, Dataset, DatasetDict
 2 | import pandas as pd
 3 | from src.utils.app import App
 4 | from src.dataset_readers.dataset_wrappers.base import *
 5 | from src.datasets.labels import get_mapping_token
 6 | 
 7 | field_getter = App()
 8 | 
 9 | # e.g. label2text = {0: "negative", 1: "positive"}
10 | label2text = get_mapping_token("rte")
11 | 
12 | 
13 | @field_getter.add("X")
14 | def get_X(entry):
15 |     return entry['sentence1'] if 'X' not in entry.keys() else entry['X']
16 | 
17 | 
18 | @field_getter.add("Y_TEXT")  # get the text corresponding to the label
19 | def get_Y_TEXT(entry):
20 |     # print(entry)
21 |     return label2text[entry['label']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']
22 | 
23 | 
24 | @field_getter.add("C")  # for datasets with sentence1, sentence2 and label, C denotes sentence1
25 | def get_C(entry):
26 |     return entry['sentence2'] if 'C' not in entry.keys() else entry['C']
27 | 
28 | 
29 | @field_getter.add("Y")  # get the raw label as an int
30 | def get_Y(entry):
31 |     return entry['label'] if 'Y' not in entry.keys() else entry['Y']
32 | 
33 | 
34 | @field_getter.add("ALL")
35 | def get_ALL(entry):
36 |     return f"{entry['sentence1']}\t {entry['sentence2']}" if 'ALL' not in entry.keys() else entry['ALL']
37 | 
38 | 
39 | class RTEDatasetWrapper(DatasetWrapper):
40 |     name = "rte"
41 |     premise_field = "premise"
42 |     hypothesis_field = "hypothesis"
43 |     label_field = "label"
44 | 
45 |     def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
46 |         def _abs_label(ex):
47 |             ex['label'] = abs(ex['label'])
48 |             return ex
49 |         super().__init__()
50 |         self.task_name = "rte"
51 |         self.field_getter = field_getter
52 |         self.postfix = ""  # for inference
53 |         if dataset_path is None:
54 |             self.dataset = load_dataset("glue", "rte", split=dataset_split)
55 |             self.dataset = self.dataset.map(_abs_label, batched=False, load_from_cache_file=False)
56 |         else:
57 |             self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path)))
58 | 
59 |         if dataset_split is not None and isinstance(self.dataset, DatasetDict):
60 |             self.dataset = self.dataset[dataset_split]
61 | 
62 |         if ds_size is not None:
63 |             self.dataset = load_partial_dataset(self.dataset, size=ds_size)
--------------------------------------------------------------------------------
/src/dataset_readers/dataset_wrappers/snli.py:
--------------------------------------------------------------------------------
 1 | from datasets import load_dataset, Dataset, DatasetDict
 2 | import pandas as pd
 3 | from src.utils.app import App
 4 | from src.dataset_readers.dataset_wrappers.base import *
 5 | from src.datasets.labels import get_mapping_token
 6 | 
 7 | field_getter = App()
 8 | 
 9 | # e.g. label2text = {0: "negative", 1: "positive"}
10 | label2text = get_mapping_token("mnli")
11 | 
12 | 
13 | @field_getter.add("X")
14 | def get_X(entry):
15 |     return entry['premise'] if 'X' not in entry.keys() else entry['X']
16 | 
17 | 
18 | @field_getter.add("Y_TEXT")  # get the text corresponding to the label
19 | def get_Y_TEXT(entry):
20 |     # print(entry)
21 |     return label2text[entry['label']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']
22 | 
23 | 
24 | @field_getter.add("C")  # for datasets with sentence1, sentence2 and label, C denotes sentence1
25 | def get_C(entry):
26 |     return entry['hypothesis'] if 'C' not in entry.keys() else entry['C']
27 | 
28 | 
29 | @field_getter.add("Y")  # get the raw label as an int
30 | def get_Y(entry):
31 |     return entry['label'] if 'Y' not in entry.keys() else entry['Y']
32 | 
33 | 
34 | @field_getter.add("ALL")
35 | def get_ALL(entry):
36 |     return f"{entry['hypothesis']}\t {entry['premise']}" if 'ALL' not in entry.keys() else entry['ALL']
37 | 
38 | 
39 | class SNLIDatasetWrapper(DatasetWrapper):
40 |     name = "snli"
41 |     premise_field = "premise"
42 |     hypothesis_field = "hypothesis"
43 |     label_field = "label"
44 | 
45 |     def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
46 |         def _abs_label(ex):
47 |             ex['label'] = abs(ex['label'])
48 |             return ex
49 | 
50 |         super().__init__()
51 |         self.task_name = "snli"
52 |         self.field_getter = field_getter
53 |         self.postfix = ""  # for inference
54 |         if dataset_path is None:
55 |             self.dataset = load_dataset("snli", split=dataset_split)
56 |             self.dataset = self.dataset.map(_abs_label, batched=False, load_from_cache_file=False)
57 |         else:
58 |             self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path)))
59 |         if dataset_split is not None and isinstance(self.dataset, DatasetDict):
60 |             self.dataset = self.dataset[dataset_split]
61 | 
62 |         if ds_size is not None:
63 |             self.dataset = load_partial_dataset(self.dataset, size=ds_size)
--------------------------------------------------------------------------------
/scripts/sst2/run_local_e.sh:
--------------------------------------------------------------------------------
 1 | export WANDB_PROJECT=ICL  # change if needed
 2 | export WANDB_ENTITY=xx  # change to your wandb account
 3 | export WANDB_API_KEY=xx  # change to your api-key
 4 | export WANDB_START_METHOD=thread
 5 | export TOKENIZERS_PARALLELISM=false
 6 | 
 7 | export HYDRA_FULL_ERROR=1
 8 | 
 9 | port=12715
10 | #
11 | num_ice=8
12 | num_candidates=30
13 | 
14 | model_name=gpt2-xl
15 | n_tokens=700
16 | inf_batch_size=12
17 | score_method="entropy"
18 | force_topk=false
19 | instruction_template=1
20 | span=true
21 | window=10
22 | dataset_split="test"
23 | rand_seed=1
24 | root=/mnt/cache/wangyaoxiang/codes/adaptive
25 | emb_field=X  # or ALL
26 | n_gpu=1
27 | 
28 | for task_name in sst2
29 | do
30 |     run_dir=${root}/output/${task_name}/${model_name}/${rand_seed}/${dataset_split}
31 |     retrieve_file=${run_dir}/retrieved.json
32 |     retrieve_file2=${run_dir}/retrieved2_entropy.json
33 |     pred_file=${run_dir}/pred5.json
34 |     mkdir -p ${run_dir}
35 | 
36 |     python prerank.py output_file=${retrieve_file} \
37 |         emb_field=${emb_field} \
38 |         num_ice=${num_ice} \
39 |         num_candidates=${num_candidates} \
40 |         dataset_reader.task_name=${task_name} \
41 |         rand_seed=${rand_seed} \
42 |         dataset_reader.dataset_split=${dataset_split} \
43 |         index_reader.task_name=${task_name} \
44 |         index_file=${run_dir}/index \
45 |         scale_factor=0.1
46 | 
47 |     accelerate launch --num_processes ${n_gpu} --main_process_port ${port} retriever.py output_file=${retrieve_file2} \
48 |         num_ice=${num_ice} \
49
| window=${window} \ 50 | rand_seed=${rand_seed} \ 51 | force_topk=${force_topk} \ 52 | instruction_template=${instruction_template} \ 53 | span=${span}\ 54 | dataset_reader.task_name=${task_name} \ 55 | dataset_reader.dataset_path=${retrieve_file} \ 56 | batch_size=${inf_batch_size} \ 57 | method=${score_method} 58 | 59 | accelerate launch --num_processes ${n_gpu} --main_process_port ${port} ppl_inferencer.py \ 60 | dataset_reader.task_name=${task_name} \ 61 | rand_seed=${rand_seed} \ 62 | dataset_reader.dataset_path=${retrieve_file2} \ 63 | instruction_template=${instruction_template} \ 64 | span=${span} \ 65 | dataset_reader.n_tokens=${n_tokens} \ 66 | output_file=${pred_file} \ 67 | model_name=${model_name} \ 68 | batch_size=${inf_batch_size} 69 | done 70 | -------------------------------------------------------------------------------- /scripts/sst2/run_mdl.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT=ICL # change if needed 2 | export WANDB_ENTITY=xx # change to your wandb account 3 | export WANDB_API_KEY=xx # change to your api-key 4 | export WANDB_START_METHOD=thread 5 | export TOKENIZERS_PARALLELISM=false 6 | 7 | export HYDRA_FULL_ERROR=1 8 | 9 | root=/mnt/cache/wangyaoxiang/codes/adaptive 10 | num_ice=8 11 | num_candidates=30 12 | 13 | prerank_method='topk' 14 | score_method='mdl' 15 | 16 | model_name=gpt2-xl 17 | n_tokens=700 18 | inf_batch_size=12 19 | instruction_template=1 20 | span=true 21 | window=10 22 | dataset_split="test" 23 | rand_seed=1 24 | port=12715 25 | emb_field=ALL 26 | emb_field=X # or ALL 27 | n_gpu=1 28 | 29 | 30 | for task_name in sst2 31 | do 32 | run_dir=${root}/output/${task_name}/${model_name}/${rand_seed}/${dataset_split} 33 | retrieve_file=${run_dir}/retrieved.json 34 | retrieve_file2=${run_dir}/retrieved2.json 35 | pred_file=${run_dir}/pred4.json 36 | mkdir -p ${run_dir} 37 | 38 | python prerank.py output_file=${retrieve_file} \ 39 | emb_field=${emb_field} \ 40 | num_ice=${num_ice} \ 41 | method=${prerank_method} \ 42 | num_candidates=${num_candidates} \ 43 | dataset_reader.task_name=${task_name} \ 44 | rand_seed=${rand_seed} \ 45 | dataset_reader.dataset_split=${dataset_split} \ 46 | index_reader.task_name=${task_name} \ 47 | index_file=${run_dir}/index \ 48 | scale_factor=0.1 49 | 50 | accelerate launch --num_processes ${n_gpu} --main_process_port ${port} retriever.py output_file=${retrieve_file2} \ 51 | num_ice=${num_ice} \ 52 | window=${window} \ 53 | rand_seed=${rand_seed} \ 54 | instruction_template=${instruction_template} \ 55 | span=${span}\ 56 | dataset_reader.task_name=${task_name} \ 57 | dataset_reader.dataset_path=${retrieve_file} \ 58 | batch_size=${inf_batch_size} \ 59 | method=${score_method} 60 | 61 | accelerate launch --num_processes ${n_gpu} --main_process_port ${port} ppl_inferencer.py \ 62 | dataset_reader.task_name=${task_name} \ 63 | rand_seed=${rand_seed} \ 64 | dataset_reader.dataset_path=${retrieve_file2} \ 65 | instruction_template=${instruction_template} \ 66 | span=${span}\ 67 | dataset_reader.n_tokens=${n_tokens} \ 68 | output_file=${pred_file} \ 69 | model_name=${model_name} \ 70 | batch_size=${inf_batch_size} 71 | done 72 | -------------------------------------------------------------------------------- /src/dataset_readers/dataset_wrappers/mnli.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, Dataset, DatasetDict 2 | import pandas as pd 3 | from src.utils.app import App 4 | from 
src.dataset_readers.dataset_wrappers.base import *
 5 | from src.datasets.labels import get_mapping_token
 6 | 
 7 | field_getter = App()
 8 | 
 9 | # e.g. label2text = {0: "negative", 1: "positive"}
10 | label2text = get_mapping_token("mnli")
11 | 
12 | 
13 | @field_getter.add("X")
14 | def get_X(entry):
15 |     return entry['premise'] if 'X' not in entry.keys() else entry['X']
16 | 
17 | 
18 | @field_getter.add("Y_TEXT")  # get the text corresponding to the label
19 | def get_Y_TEXT(entry):
20 |     # if entry['label'] == -1:
21 |     #     print(entry)
22 |     return label2text[entry['label']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']
23 | 
24 | 
25 | @field_getter.add("C")  # for datasets with sentence1, sentence2 and label, C denotes sentence1
26 | def get_C(entry):
27 |     return entry['hypothesis'] if 'C' not in entry.keys() else entry['C']
28 | 
29 | 
30 | @field_getter.add("Y")  # get the raw label as an int
31 | def get_Y(entry):
32 |     return entry['label'] if 'Y' not in entry.keys() else entry['Y']
33 | 
34 | 
35 | @field_getter.add("ALL")
36 | def get_ALL(entry):
37 |     return f"{entry['hypothesis']}\t {entry['premise']}" if 'ALL' not in entry.keys() else entry['ALL']
38 | 
39 | 
40 | class MNLIDatasetWrapper(DatasetWrapper):
41 |     name = "mnli"
42 |     premise_field = "premise"
43 |     hypothesis_field = "hypothesis"
44 |     label_field = "label"
45 | 
46 |     def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
47 |         def _abs_label(ex):
48 |             ex['label'] = abs(ex['label'])
49 |             return ex
50 |         super().__init__()
51 |         self.task_name = "mnli"
52 |         self.field_getter = field_getter
53 |         self.postfix = ""  # for inference
54 |         if dataset_path is None:
55 |             self.dataset = load_dataset("LysandreJik/glue-mnli-train", split=dataset_split)
56 |             self.dataset = self.dataset.map(_abs_label, batched=False, load_from_cache_file=False)
57 |         else:
58 |             self.dataset = Dataset.from_pandas(pd.DataFrame(data=pd.read_json(dataset_path)))
59 |         if dataset_split is not None and isinstance(self.dataset, DatasetDict):
60 |             self.dataset = self.dataset[dataset_split]
61 | 
62 |         if ds_size is not None:
63 |             self.dataset = load_partial_dataset(self.dataset, size=ds_size)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # C extensions
 7 | *.so
 8 | 
 9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea/ -------------------------------------------------------------------------------- /src/datasets/labels.py: -------------------------------------------------------------------------------- 1 | 2 | def get_mapping_set(task_name): 3 | if task_name == 'sst2' or task_name == 'yelp_polarity' or task_name == "imdb": 4 | return sst2_map_set 5 | 6 | 7 | def get_mapping_token(task_name): 8 | r = get_mapping_token0(task_name) 9 | 10 | if r is None: 11 | return {i: str(i) for i in range(100)} # 100分类应该够了 12 | return r 13 | 14 | 15 | def get_mapping_token0(task_name): 16 | if task_name == 'sst2' or task_name == 'yelp_polarity' or task_name == "imdb": 17 | return sst2_map_token 18 | if task_name == 'mnli' or task_name == "snli": 19 | return mnli_map_token 20 | if task_name == 'rte' or task_name == 'qnli' or task_name == 'mrpc': 21 | return binary_nli_map_token 22 | if task_name == 'sst5': 23 | return sst5_map_token 24 | if task_name == 'ag_news': 25 | return ag_news_map_token 26 | if task_name == 'trec': 27 | return trec_map_token 28 | return None 29 | 30 | 31 | def get_mapping_idx(task_name): 32 | idx2token = get_mapping_token(task_name) 33 | return {value: key for key, value in idx2token.items()} 34 | 35 | 36 | data_path = {'sst2': ["gpt3mix/sst2", None], 37 | 'mrpc': ["glue", "mrpc"], 38 | 'snli': ['snli', None], 39 | 'mnli': ['LysandreJik/glue-mnli-train', None], 40 | 'iwslt17zh': ["iwslt2017", 'iwslt2017-zh-en'], 41 | "mtop": ["iohadrubin/mtop", 'mtop'], 42 | "qnli": ["glue", "qnli"], 43 | "rte": ["glue", "rte"], 44 | "squad": ["squad", None], 45 | "sst5": ["SetFit/sst5", None], 46 | "ag_news": ["ag_news", None], 47 | "trec": ["trec", None], 48 | "commonsense_qa": ["commonsense_qa",None], 49 | 
"copa":['super_glue', 'copa'], 50 | "boolq": ['super_glue', 'boolq'] 51 | } 52 | 53 | 54 | def get_datapath(task_name): 55 | return data_path[task_name] 56 | 57 | 58 | sst2_map_set = {0: [" negative", " bad", " terrible", "negative", "bad", "terrible"], 59 | 1: [" positive", " good", " great", "positive", "good", "great"]} 60 | sst2_map_token = {0: "terrible", 1: "great"} 61 | mnli_map_token = {0: "No", 1: "Maybe", 2: 'Yes'} 62 | binary_nli_map_token = {0: 'Yes', 1: "No"} 63 | sst5_map_token = {0: "terrible", 1: "bad", 2: "okay", 3: "good", 4: "great"} 64 | ag_news_map_token = {0: "world", 1: "sports", 2: "business", 3: "technology"} 65 | trec_map_token = {0: "abbreviation", 1: "entity", 2: "description and abstract concept", 3: "human being", 66 | 4: "location", 5: "numeric value"} -------------------------------------------------------------------------------- /scripts/run_mdl.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT=ICL # change if needed 2 | export WANDB_ENTITY=xx # change to your wandb account 3 | export WANDB_API_KEY=xx # change to your api-key 4 | export WANDB_START_METHOD=thread 5 | export TOKENIZERS_PARALLELISM=false 6 | 7 | export HYDRA_FULL_ERROR=1 8 | 9 | root=/mnt/cache/wangyaoxiang/codes/adaptive # change to your path 10 | num_ice=8 11 | num_candidates=30 12 | 13 | prerank_method='topk' # random votek dpp 14 | score_method='mdl' # entropy 15 | 16 | model_name=gpt2-xl 17 | n_tokens=700 18 | inf_batch_size=12 19 | instruction_template=1 20 | span=true 21 | window=10 22 | dataset_split="test" 23 | rand_seed=1 24 | port=12715 25 | # X for sst2,sst5,trec,agnew ALL for snli,mnli,qnli,commonsense_qa 26 | emb_field=X # or ALL 27 | n_gpu=1 28 | 29 | for task_name in sst2 30 | do 31 | 32 | if [ ${task_name} = 'commonsense_qa' ] || [ ${task_name} = 'mnli' ] || [ ${task_name} = 'qnli' ]; 33 | then 34 | dataset_split='validation' 35 | else 36 | dataset_split='test' 37 | fi 38 | 39 | if [ ${task_name} = 'snli' ] || [ ${task_name} = 'mnli' ] || [ ${task_name} = 'qnli' ]; 40 | then 41 | emb_field=ALL 42 | else 43 | emb_field=X 44 | fi 45 | 46 | run_dir=${root}/output/${task_name}/${model_name}/${rand_seed}/${dataset_split} 47 | retrieve_file=${run_dir}/retrieved.json 48 | retrieve_file2=${run_dir}/retrieved2.json 49 | pred_file=${run_dir}/pred4.json 50 | mkdir -p ${run_dir} 51 | 52 | python prerank.py output_file=${retrieve_file} \ 53 | emb_field=${emb_field} \ 54 | num_ice=${num_ice} \ 55 | method=${prerank_method} \ 56 | num_candidates=${num_candidates} \ 57 | dataset_reader.task_name=${task_name} \ 58 | rand_seed=${rand_seed} \ 59 | dataset_reader.dataset_split=${dataset_split} \ 60 | index_reader.task_name=${task_name} \ 61 | index_file=${run_dir}/index \ 62 | scale_factor=0.1 63 | 64 | accelerate launch --num_processes ${n_gpu} --main_process_port ${port} retriever.py output_file=${retrieve_file2} \ 65 | num_ice=${num_ice} \ 66 | window=${window} \ 67 | rand_seed=${rand_seed} \ 68 | instruction_template=${instruction_template} \ 69 | span=${span}\ 70 | dataset_reader.task_name=${task_name} \ 71 | dataset_reader.dataset_path=${retrieve_file} \ 72 | batch_size=${inf_batch_size} \ 73 | method=${score_method} 74 | 75 | accelerate launch --num_processes ${n_gpu} --main_process_port ${port} ppl_inferencer.py \ 76 | dataset_reader.task_name=${task_name} \ 77 | rand_seed=${rand_seed} \ 78 | dataset_reader.dataset_path=${retrieve_file2} \ 79 | instruction_template=${instruction_template} \ 80 | span=${span}\ 81 | 
dataset_reader.n_tokens=${n_tokens} \ 82 |     output_file=${pred_file} \ 83 |     model_name=${model_name} \ 84 |     batch_size=${inf_batch_size} 85 | done 86 | -------------------------------------------------------------------------------- /src/utils/calculate.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | 4 | import numpy as np 5 | 6 | 7 | def entropy(probs: np.array, label_dim: int = 0, mask=None): 8 |     if mask is None: 9 |         return - (probs * np.log(probs)).sum(label_dim) 10 |     return - (mask * probs * np.log(probs)).sum(label_dim) 11 | 12 | 13 | def mdl(probs: np.array, label_dim: int = 0, mask=None): 14 |     if mask is None: 15 |         return - (np.log(probs)).sum(label_dim) 16 |     return - (mask * np.log(probs)).sum(label_dim) 17 | 18 | 19 | # convert a dict of lists into a list of dicts 20 | def dict_list2list_dict(m): 21 |     return [{key: m[key][j] for key in m.keys()} 22 |             for j in range(len(m[list(m.keys())[0]]))] 23 | 24 | 25 | def reshape_examples(examples):  # 8*{X:12,Y:12...}->12*{X:8,Y:8...}->12*[8*{X,Y}] 26 |     len1 = len(examples)  # example num 27 |     if (len1 > 0): 28 |         len2 = len(examples[0]['X'])  # batchsize 29 |     else: 30 |         len2 = 0 31 |     _examples = [] 32 |     for j in range(len2): 33 |         example = {'C': [examples[k]['C'][j] for k in range(len1)], 34 |                    'X': [examples[k]['X'][j] for k in range(len1)], 35 |                    'Y': [examples[k]['Y'][j] for k in range(len1)], 36 |                    'Y_TEXT': [examples[k]['Y_TEXT'][j] for k in range(len1)]} 37 |         _examples.append(dict_list2list_dict(example)) 38 | 39 |     return _examples 40 | 41 | 42 | def transform(_metadata): 43 |     metadata_tmp = copy.deepcopy(_metadata) 44 |     metadata_tmp["examples"] = reshape_examples(_metadata["examples"]) 45 |     if (len(metadata_tmp["examples"]) == 0): 46 |         metadata_tmp["examples"] = [[] for i in range(len(metadata_tmp['X']))] 47 |     return dict_list2list_dict(metadata_tmp) 48 | 49 | 50 | def get_global_entropy(preds, label_num): 51 |     count = [1 for i in range(label_num)]  # start counts at 1 to avoid NaN from log(0) 52 |     for pred in preds: 53 |         count[pred] += 1 54 |     probs = np.array(count) / (len(preds) + label_num)  # Laplace-smoothed distribution 55 |     return entropy(probs) 56 | 57 | 58 | def get_local_entropy(probs): 59 |     return np.vectorize(lambda x: entropy(x))(np.array(probs)).sum() / len(probs) 60 | 61 | 62 | def get_mi2(probs): 63 |     conditional_e = get_local_entropy(probs) 64 |     preds = np.array(probs).argmax(axis=0) 65 |     e = get_global_entropy(preds, len(probs[0])) 66 |     mi = e - conditional_e 67 |     return mi 68 | 69 | 70 | # def get_mi(probs): 71 | #     return get_mi2(probs) 72 | 73 | def get_mi(probs):  # this variant is not suitable for the ppl inferencer and works poorly there 74 |     conditional_e = get_local_entropy(probs) 75 |     global_probs = np.array(probs).sum(axis=0) 76 |     global_probs = global_probs / global_probs.sum() 77 |     e = entropy(global_probs) 78 |     mi = e - conditional_e 79 |     return mi 80 | 81 | 82 | def get_permutations(num): 83 |     array = [i for i in range(num)] 84 |     permutation = list(itertools.permutations(array)) 85 |     permutation = [list(t) for t in permutation] 86 |     return permutation  # list of int lists 87 |
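# Worked example for the two scoring functions above (a sketch): for a
# normalized label distribution p = [0.7, 0.2, 0.1],
#   entropy(p) = -(0.7*ln 0.7 + 0.2*ln 0.2 + 0.1*ln 0.1) ~= 0.80 nats
#   mdl(p)     = -(ln 0.7 + ln 0.2 + ln 0.1)             ~= 4.27 nats
# A prompt whose candidate-label distribution has lower entropy is one the
# LM is more confident about, which is what the retriever's 'mdl' and
# 'entropy' scoring methods exploit.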
-------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | 4 | from multiprocessing import Pool, TimeoutError 5 | from tqdm import tqdm 6 | from functools import partial 7 | import json 8 | import logging 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class App: 15 |     def __init__(self, dict_funcs=None): 16 |         self.functions = {} 17 |         if dict_funcs is not None: 18 |             self.functions.update(dict_funcs) 19 | 20 |     def add(self, key): 21 |         def adder(func): 22 |             self.functions[key] = func 23 |             return func 24 | 25 |         return adder 26 | 27 |     def __getitem__(self, __name: str): 28 |         return self.functions[__name] 29 | 30 |     def merge(self, app): 31 |         new_app = App() 32 |         new_app.functions = {**self.functions, **app.functions} 33 |         return new_app 34 | 35 | 36 | def wrapper(idx_args, func): 37 |     idx, args = idx_args 38 |     res = func(args) 39 |     return idx, res 40 | 41 | 42 | def parallel_run(func, args_list, n_processes=8, initializer=None, **kwargs): 43 |     idx2res = {} 44 |     func = partial(func, **kwargs) 45 |     n = len(args_list) 46 |     with Pool(n_processes, initializer=initializer) as p: 47 |         for idx, response in tqdm(p.imap_unordered(partial(wrapper, func=func), 48 |                                                    enumerate(args_list)), 49 |                                   total=n): 50 |             idx2res[idx] = response 51 | 52 |     res = [idx2res[i] for i in range(n)] 53 |     return res 54 | 55 | 56 | def parallel_run_timeout(func, args_list, n_processes=8, timeout=5, **kwargs): 57 |     pool = Pool(n_processes) 58 |     jobs = {} 59 |     results = [] 60 |     restart = False 61 | 62 |     for i, args in enumerate(args_list): 63 |         jobs[i] = pool.apply_async(func, args=(args, ), kwds=kwargs) 64 | 65 |     total_num = len(args_list) 66 |     finished_num = 0 67 |     fail_num = 0 68 |     for i, r in tqdm(jobs.items()): 69 |         try: 70 |             finished_num += 1 71 |             results.append(r.get(timeout=timeout)) 72 |         except TimeoutError: 73 |             results.append(('exception', TimeoutError)) 74 |             logger.info("Timeout args: ") 75 |             logger.info(args_list[i]) 76 |             fail_num += 1 77 |             if fail_num == n_processes and total_num > finished_num: 78 |                 restart = True 79 |                 logger.info(f"All processes down, restarting, {total_num-finished_num}/{total_num} remaining") 80 |                 break 81 | 82 |     pool.close() 83 |     pool.terminate() 84 |     pool.join() 85 | 86 |     if restart: 87 |         results.extend(parallel_run_timeout(func, args_list[total_num-finished_num:], n_processes, timeout, **kwargs)) 88 |     return results 89 | 90 | 91 | def save_json(file, data_list): 92 |     logger.info(f"Saving to {file}") 93 |     with open(file, "w") as f: 94 |         json.dump(data_list, f) 95 | 96 | 97 | def load_json(file): 98 |     logger.info(f"Loading from {file}") 99 |     with open(file) as f: 100 |         data = json.load(f) 101 |     return data 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-adaptive In-context Learning 2 | This repository contains the source code for Self-adaptive In-context Learning, which is proposed in our paper [“Self-adaptive In-context Learning”](https://arxiv.org/abs/2212.10375). If you want to use our method easily, you can use [OpenICL](https://github.com/Shark-NLP/OpenICL), a toolkit for In-context learning. You can also quickly reproduce our experiments using our [script](https://github.com/Shark-NLP/OpenICL/blob/main/examples/research_projects/self-adaptive_in-context_learning.ipynb) there. 3 | 4 | ## Contents 5 | * [Setup](#setup) 6 |   * [Reproduce](#reproduce) 7 |   * [Usage](#usage) 8 | * [Modules](#modules) 9 | * [Add a New Task](#add-a-new-task) 10 | * [Citation](#citation) 11 | 12 | ## Setup 13 | All required packages can be found in ``requirements.txt``. 14 | You can install them in a new environment with 15 | ```shell 16 | conda create -n adaptive python=3.8 17 | conda activate adaptive 18 | 19 | # Replace the following line according to your CUDA version.
20 | pip install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html 21 | pip install -r requirements.txt 22 | 23 | accelerate config # ignore if you don't need multi-gpu 24 | ``` 25 | 26 | Set up WandB for tracking the training status in `scripts/run_xxx.sh`: 27 | ```shell 28 | export WANDB_API_KEY=YOUR_WANDB_API_KEY 29 | export WANDB_PROJECT=YOUR_PROJECT_NAME 30 | export WANDB_ENTITY=YOUR_TEAM_NAME 31 | 32 | root=YOUR_PROJECT_PATH 33 | ``` 34 | 35 | ### Reproduce 36 | ```shell 37 | bash ./scripts/run_mdl.sh 38 | ``` 39 | 40 | ### Usage 41 | Given an index dataset (by default the training set) and a test dataset (by default the test set), we include scripts to run five ICL methods under `scripts/`: 42 | - `run_mdl.sh`: selects in-context examples by minimum description length (MDL); 43 | - `run_topk.sh`: selects examples by sentence-transformer embedding similarity; 44 | - `run_random.sh`: randomly selected in-context examples; 45 | - `run_local_e.sh`: selects examples by local entropy; 46 | - `run_prompting.sh`: inference without in-context examples; 47 | 48 | The config files can be found in `configs/`. 49 | 50 | ## Modules 51 | 1. `prerank.py`: retrieves candidate examples from the training set (top-k, random, vote-k, or DPP). 52 | 2. `retriever.py`: selects and ranks in-context examples from the candidates produced by `prerank.py`. 53 | 3. `ppl_inferencer.py`: runs inference with the retrieved in-context examples. 54 | 55 | ## Add a New Task 56 | Change the task by modifying the `task_name` argument; the currently available tasks are `sst5, mrpc, qnli, mnli, cmsqa, swag, webqs, geoquery, nl2bash, mtop, break, smcalflow`. 57 | It's easy to add a new task with this repo. You can take the following steps (a minimal wrapper sketch is shown after the list): 58 | 1. Define a dataset wrapper under `src/dataset_readers/dataset_wrapper` to set the text fields. 59 | 2. Add a task template in `src/datasets/instructions.py` 60 | 3. Add a metric method in `src/metrics/eval_datasets.py`
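For step 1, a wrapper mainly registers field getters and loads the dataset. Below is a minimal, hypothetical sketch for a single-sentence classification task; the task name `my_task`, the dataset id `my_org/my_task`, its `sentence` column, and `MyTaskDatasetWrapper` are placeholders, and we assume the usual star-import from `base` provides `App`, `DatasetWrapper`, and `load_dataset` as in the shipped wrappers (see `src/dataset_readers/dataset_wrappers/sst2.py` and `mnli.py` for the real conventions):

```python
# src/dataset_readers/dataset_wrappers/my_task.py (hypothetical sketch)
from src.dataset_readers.dataset_wrappers.base import *
from src.datasets.labels import get_mapping_token

field_getter = App()
label2text = get_mapping_token("my_task")  # e.g. {0: "bad", 1: "good"}


@field_getter.add("X")
def get_X(entry):
    return entry['sentence'] if 'X' not in entry.keys() else entry['X']


@field_getter.add("Y")
def get_Y(entry):  # raw integer label
    return entry['label'] if 'Y' not in entry.keys() else entry['Y']


@field_getter.add("Y_TEXT")
def get_Y_TEXT(entry):  # verbalized label
    return label2text[entry['label']] if 'Y_TEXT' not in entry.keys() else entry['Y_TEXT']


@field_getter.add("C")
def get_C(entry):  # no paired sentence for single-sentence tasks
    return entry.get('C', "")


@field_getter.add("ALL")
def get_ALL(entry):
    return get_X(entry) if 'ALL' not in entry.keys() else entry['ALL']


class MyTaskDatasetWrapper(DatasetWrapper):
    name = "my_task"

    def __init__(self, dataset_path=None, dataset_split=None, ds_size=None):
        super().__init__()
        self.task_name = "my_task"
        self.field_getter = field_getter
        self.postfix = ""  # appended after the test input at inference time
        self.dataset = load_dataset("my_org/my_task", split=dataset_split)
```

Remember to also register the new task's label tokens and dataset path in `src/datasets/labels.py`.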
61 | 62 | ## Citation 63 | If you find our work helpful, please cite us: 64 | ``` 65 | @ARTICLE{2022arXiv221210375W, 66 |        author = {{Wu}, Zhiyong and {Wang}, Yaoxiang and {Ye}, Jiacheng and {Kong}, Lingpeng}, 67 |         title = "{Self-adaptive In-context Learning}", 68 |          year = 2022, 69 |        eprint = {2212.10375}, 70 |  primaryClass = {cs.CL}, 71 | archivePrefix = {arXiv}, 72 |        adsurl = {https://ui.adsabs.harvard.edu/abs/2022arXiv221210375W}, 73 |       adsnote = {Provided by the SAO/NASA Astrophysics Data System} 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /src/utils/dpp_map.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | 6 | def fast_map_dpp(kernel_matrix, max_length, epsilon=1E-10): 7 |     """ 8 |     fast implementation of the greedy algorithm 9 |     :param kernel_matrix: 2-d array 10 |     :param max_length: positive int 11 |     :param epsilon: small positive scalar 12 |     :return: list 13 |     reference: https://github.com/laming-chen/fast-map-dpp/blob/master/dpp_test.py 14 |     paper: Fast Greedy MAP Inference for Determinantal Point Process to Improve Recommendation Diversity 15 |     """ 16 |     item_size = kernel_matrix.shape[0] 17 |     cis = np.zeros((max_length, item_size)) 18 |     di2s = np.copy(np.diag(kernel_matrix)) 19 |     selected_items = list() 20 |     selected_item = np.argmax(di2s) 21 |     selected_items.append(selected_item) 22 |     while len(selected_items) < max_length: 23 |         k = len(selected_items) - 1 24 |         ci_optimal = cis[:k, selected_item] 25 |         di_optimal = math.sqrt(di2s[selected_item]) 26 |         elements = kernel_matrix[selected_item, :] 27 |         eis = (elements - np.dot(ci_optimal, cis[:k, :])) / di_optimal 28 |         cis[k, :] = eis 29 |         di2s -= np.square(eis) 30 |         selected_item = np.argmax(di2s) 31 |         if di2s[selected_item] < epsilon: 32 |             break 33 |         selected_items.append(selected_item) 34 |     return selected_items 35 | 36 | 37 | def greedy_map_dpp(kernel_matrix, max_length): 38 |     """ 39 |     greedy map 40 |     reference: http://jgillenw.com/dpp-map.html 41 |     paper: Near-Optimal MAP Inference for Determinantal Point Processes 42 |     """ 43 |     selected_items = [] 44 |     item_size = kernel_matrix.shape[0] 45 |     U = list(range(0, item_size)) 46 |     num_left = item_size 47 | 48 |     while len(U) > 0: 49 |         scores = np.diag(kernel_matrix) 50 |         # Select the max-scoring addition to the chosen set. 51 |         max_loc = np.argmax(scores) 52 |         max_score = scores[max_loc] 53 | 54 |         if max_score < 1 or len(selected_items) == max_length: 55 |             break 56 |         selected_items.append(U[max_loc]) 57 |         del U[max_loc] 58 | 59 |         # Compute the new kernel, conditioning on the current selection.
60 |         inc_ids = list(range(0, max_loc)) + list(range(max_loc + 1, num_left)) 61 | 62 |         kernel_matrix = np.linalg.inv( 63 |             kernel_matrix + np.diag([1] * (max_loc) + [0] + [1] * (num_left - max_loc - 1))) 64 |         num_left -= 1 65 |         kernel_matrix = np.linalg.inv(kernel_matrix[np.ix_(inc_ids, inc_ids)]) - np.eye(num_left) 66 | 67 |     return selected_items 68 | 69 | 70 | if __name__ == "__main__": 71 |     import time 72 | 73 |     item_size = 100 74 |     feature_dimension = 1000 75 |     max_length = 50 76 | 77 |     scores = np.exp(0.01 * np.random.randn(item_size) + 0.2) 78 |     feature_vectors = np.random.randn(item_size, feature_dimension) 79 | 80 |     feature_vectors /= np.linalg.norm(feature_vectors, axis=1, keepdims=True) 81 |     similarities = np.dot(feature_vectors, feature_vectors.T) 82 |     kernel_matrix = scores.reshape((item_size, 1)) * similarities * scores.reshape((1, item_size)) 83 | 84 |     t = time.time() 85 |     result = fast_map_dpp(kernel_matrix, max_length) 86 |     print(result) 87 |     print('fast dpp algorithm running time: ' + '\t' + "{0:.4e}".format(time.time() - t)) 88 | 89 |     t = time.time() 90 |     result = greedy_map_dpp(kernel_matrix, max_length) 91 |     print(result) 92 |     print('greedy dpp algorithm running time: ' + '\t' + "{0:.4e}".format(time.time() - t)) 93 | -------------------------------------------------------------------------------- /src/dataset_readers/retriever_dsr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from src.dataset_readers.dataset_wrappers import get_dataset_wrapper 5 | 6 | 7 | def get_length(tokenizer, text): 8 |     tokenized_example = tokenizer.encode_plus(text, truncation=False, return_tensors='pt') 9 |     return int(tokenized_example.input_ids.shape[1]) 10 | 11 | 12 | def set_length(example, **kwargs): 13 |     tokenizer = kwargs['tokenizer'] 14 |     set_field = kwargs['set_field'] 15 |     field_getter = kwargs['field_getter'] 16 | 17 |     field_text = field_getter.functions[set_field](example) 18 |     example[f'{set_field}_len'] = get_length(tokenizer, field_text) 19 |     if set_field not in example: 20 |         example[set_field] = field_text 21 |     return example 22 | 23 | 24 | def set_field(example, **kwargs): 25 |     set_fields = ['C', 'X', 'Y', 'Y_TEXT', 'ALL'] 26 |     field_getter = kwargs['field_getter'] 27 |     for set_field in set_fields: 28 |         field_text = field_getter.functions[set_field](example) 29 |         if set_field not in example: 30 |             example[set_field] = field_text 31 |     return example 32 | 33 | 34 | class RetrieverDatasetReader(torch.utils.data.Dataset): 35 | 36 |     def __init__(self, model_name, task_name, index_split, dataset_path, n_tokens=700, 37 |                  tokenizer=None, index_data_path=None): 38 |         self.dataset_wrapper = get_dataset_wrapper(task_name)(dataset_path=dataset_path) 39 | 40 |         self.index_dataset = get_dataset_wrapper(task_name)(dataset_split=index_split, dataset_path=index_data_path) 41 | 42 |         self.dataset_wrapper.dataset = self.dataset_wrapper.dataset.map( 43 |             set_field, 44 |             fn_kwargs={'field_getter': self.dataset_wrapper.field_getter} 45 |         ) 46 | 47 |         self.index_dataset.dataset = self.index_dataset.dataset.map( 48 |             set_field, 49 |             fn_kwargs={'field_getter': self.index_dataset.field_getter} 50 |         ) 51 | 52 |         self.n_tokens_in_prompt = n_tokens 53 |         self.num_processes = 1 54 |         self.process_index = 0 55 | 56 |     def __getitem__(self, index): 57 |         entry = self.dataset_wrapper[index] 58 |         C, X, Y, Y_TEXT, example_list = self.get_ctxs_inputs(entry) 59 |         return { 60 |             "metadata": {'id': self.num_processes * index + 
self.process_index, 61 | 'C': C, 62 | 'X': X, 63 | 'Y': Y, 64 | 'Y_TEXT': Y_TEXT, 65 | "examples": example_list} 66 | } 67 | 68 | def shard(self, accelerator): 69 | self.dataset_wrapper.dataset = self.dataset_wrapper.dataset.shard( 70 | num_shards=accelerator.num_processes, 71 | index=accelerator.process_index 72 | ) 73 | self.num_processes = accelerator.num_processes 74 | self.process_index = accelerator.process_index 75 | logging.info(f'accelerator.num_processes={accelerator.num_processes}') 76 | logging.info(f'accelerator.process_index={accelerator.process_index}') 77 | logging.info(f'after shard,len={len(self.dataset_wrapper)}') 78 | 79 | def __len__(self): 80 | return len(self.dataset_wrapper) 81 | 82 | def get_ctxs_inputs(self, entry): 83 | C = self.dataset_wrapper.get_field(entry, 'C') 84 | X = self.dataset_wrapper.get_field(entry, 'X') 85 | Y = self.dataset_wrapper.get_field(entry, 'Y') 86 | Y_TEXT = self.dataset_wrapper.get_field(entry, 'Y_TEXT') 87 | ctxs_candidates = [self.index_dataset.dataset[i[0]] for i in entry['ctxs_candidates']] 88 | example_list = [{'C': i['C'], 'X': i['X'], 'Y': i['Y'], 'Y_TEXT': i['Y_TEXT']} for i in ctxs_candidates] 89 | 90 | return C, X, Y, Y_TEXT, example_list 91 | -------------------------------------------------------------------------------- /src/dataset_readers/inference_dsr.py: -------------------------------------------------------------------------------- 1 | import more_itertools 2 | import numpy as np 3 | import torch 4 | from transformers import AutoTokenizer 5 | 6 | from src.dataset_readers.dataset_wrappers import get_dataset_wrapper 7 | 8 | 9 | def get_length(tokenizer, text): 10 | tokenized_example = tokenizer.encode_plus(text, truncation=False, return_tensors='pt') 11 | return int(tokenized_example.input_ids.shape[1]) 12 | 13 | 14 | def set_length(example, **kwargs): 15 | tokenizer = kwargs['tokenizer'] 16 | set_field = kwargs['set_field'] 17 | field_getter = kwargs['field_getter'] 18 | 19 | field_text = field_getter.functions[set_field](example) 20 | example[f'{set_field}_len'] = get_length(tokenizer, field_text) 21 | if set_field not in example: 22 | example[set_field] = field_text 23 | return example 24 | 25 | 26 | class InferenceDatasetReader(torch.utils.data.Dataset): 27 | 28 | def __init__(self, model_name, task_name, index_split, dataset_path, n_tokens=1600): 29 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 30 | self.tokenizer.pad_token = "<|endoftext|>" 31 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 32 | self.tokenizer.padding_side = "left" 33 | 34 | self.dataset_wrapper = get_dataset_wrapper(task_name)(dataset_path=dataset_path) 35 | self.index_dataset = get_dataset_wrapper(task_name)(dataset_split=index_split) 36 | 37 | self.dataset_wrapper.dataset = self.dataset_wrapper.dataset.map( 38 | set_length, 39 | fn_kwargs={'tokenizer': self.tokenizer, 40 | 'set_field': 'sentence', 41 | 'field_getter': self.dataset_wrapper.field_getter} 42 | ) 43 | self.index_dataset.dataset = self.index_dataset.dataset.map( 44 | set_length, 45 | fn_kwargs={'tokenizer': self.tokenizer, 46 | 'set_field': 'sentence_label', 47 | 'field_getter': self.index_dataset.field_getter} 48 | ) 49 | 50 | self.n_tokens_in_prompt = n_tokens 51 | self.num_processes = 1 52 | self.process_index = 0 53 | 54 | def __getitem__(self, index): 55 | entry = self.dataset_wrapper[index] 56 | question, answer, lengths_list, prompts_list = self.get_ctxs_inputs(entry) 57 | 58 | trunc_prompts_list = self.truncate(question, lengths_list, prompts_list) 
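        # Prompt assembly (describing the lines below): the truncated examples
        # are joined with newlines and the test question is appended, giving
        #   <ex_1>\n<ex_2>\n...\n<ex_k>\n<question>\t<postfix>
        # truncate() keeps only as many examples as fit the n_tokens budget and
        # reverses them so the most similar example sits closest to the question.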
59 |         prompt_enc_text = "\n".join(trunc_prompts_list) 60 | 61 |         enc_text = f"{prompt_enc_text}\n{question}\t{self.dataset_wrapper.postfix}" 62 |         tokenized_example = self.tokenizer.encode_plus(enc_text, truncation=False, return_tensors='pt', 63 |                                                        add_special_tokens=False) 64 | 65 |         entry['id'] = self.num_processes * index + self.process_index  # global id under round-robin sharding 66 |         entry['prompt_list'] = trunc_prompts_list 67 |         entry['enc_text'] = enc_text 68 | 69 |         return { 70 |             'input_ids': tokenized_example.input_ids.squeeze(), 71 |             'attention_mask': tokenized_example.attention_mask.squeeze(), 72 |             "metadata": entry 73 |         } 74 | 75 |     def __len__(self): 76 |         return len(self.dataset_wrapper) 77 | 78 |     def get_ctxs_inputs(self, entry): 79 |         question = self.dataset_wrapper.get_field(entry, 'q') 80 |         answer = self.dataset_wrapper.get_field(entry, 'a') 81 |         ctx = [self.index_dataset.dataset[i] for i in entry['ctxs']] 82 |         # ctx = [self.index_dataset.dataset[i['id']] for i in entry['ctxs']] 83 |         prompts_list = [i['qa'] for i in ctx] 84 |         lengths_list = [i['qa_len'] for i in ctx] 85 |         return question, answer, lengths_list, prompts_list 86 | 87 |     def shard(self, accelerator): 88 |         self.num_processes = accelerator.num_processes 89 |         self.process_index = accelerator.process_index 90 |         self.dataset_wrapper.dataset = list( 91 |             more_itertools.distribute(accelerator.num_processes, self.dataset_wrapper.dataset)[ 92 |                 accelerator.process_index]) 93 | 94 |     def truncate(self, question, lengths_list, prompts_list): 95 |         q_length = get_length(self.tokenizer, question) 96 |         max_prompts = np.searchsorted(np.cumsum(lengths_list), self.n_tokens_in_prompt - q_length) 97 |         # logger.info(self.n_tokens_in_prompt, max_prompts) 98 |         trunc_prompts_list = prompts_list[:max_prompts][::-1]  # most similar examples sit closest to the question 99 |         return trunc_prompts_list 100 | -------------------------------------------------------------------------------- /src/metrics/eval_datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | from datasets import load_metric 5 | from src.utils.app import App 6 | from src.datasets.labels import get_mapping_token 7 | 8 | 9 | def renorm(text): 10 |     text = text.split("\n")[0] 11 |     text = re.sub(r"[\d]+\#\) ", ";", text) 12 |     return text 13 | 14 | 15 | app = App() 16 | 17 | 18 | def add_squad_acc2(file_name, id_list=None): 19 |     with open(file_name) as f: 20 |         pred_data = json.load(f) 21 |         predictions = [{'prediction_text': d['generated'].split('\n')[0], 'id': str(i)} for i, d in 22 |                        enumerate(pred_data)] 23 |         references = [{'answers': {'answer_start': [0], 'text': [d['Y_TEXT']]}, 'id': str(i)} for i, d in 24 |                       enumerate(pred_data)] 25 |         metric = load_metric('squad') 26 |         score = metric.compute(predictions=predictions, references=references) 27 |         logging.info(f'exact_match={score["exact_match"]}') 28 |         logging.info(f'f1={score["f1"]}') 29 |         return pred_data, score["f1"] 30 | 31 | 32 | @app.add("sst2") 33 | def add_sst2_acc(file_name, id_list=None, task_name="sst2"): 34 |     cor = 0.0 35 |     with open(file_name) as f: 36 |         label2text = get_mapping_token(task_name) 37 |         data = json.load(f) 38 |         for line in data: 39 |             if id_list is not None and line['idx'] not in id_list: 40 |                 continue 41 |             label = label2text[line['Y']] 42 |             pred = line['generated'].split(" ")[-1].strip() 43 |             if label == pred: 44 |                 cor += 1 45 |             else: 46 |                 continue 47 |     return data, cor / len(data) 48 | 49 | 50 | @app.add("web_questions")  # TODO: maybe match with startswith instead?
51 | def add_wq_acc(file_name, id_list=None): 52 |     def include(pred, gold): 53 |         pred = pred.lower().strip() 54 |         gold = gold.lower().strip() 55 |         if gold in pred or pred in gold: 56 |             return 1 57 |         else: 58 |             return 0 59 | 60 |     with open(file_name) as f: 61 |         data = json.load(f) 62 |         cor = 0 63 |         for line in data: 64 |             if id_list is not None and line['id'] not in id_list: 65 |                 continue 66 |             line['acc'] = include(line['generated'], line['Y_TEXT']) 67 |             cor += line['acc'] 68 |     lenn = len(id_list) if id_list is not None else len(data) 69 |     return data, cor / lenn 70 | 71 | 72 | @app.add("rte") 73 | def add_rte_acc(file_name, id_list=None): 74 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="rte") 75 | 76 | 77 | @app.add("qnli") 78 | def add_qnli_acc(file_name, id_list=None): 79 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="qnli") 80 | 81 | 82 | @app.add("boolq") 83 | def add_boolq_acc(file_name, id_list=None): 84 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="boolq") 85 | 86 | 87 | @app.add("ag_news") 88 | def add_ag_news_acc(file_name, id_list=None): 89 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="ag_news") 90 | 91 | 92 | @app.add("trec") 93 | def add_trec_acc(file_name, id_list=None): 94 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="trec") 95 | 96 | 97 | @app.add("commonsense_qa") 98 | def add_commonsense_qa_acc(file_name, id_list=None): 99 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="commonsense_qa") 100 | 101 | 102 | @app.add("copa") 103 | def add_copa_acc(file_name, id_list=None): 104 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="copa") 105 | 106 | 107 | @app.add("piqa") 108 | def add_piqa_acc(file_name, id_list=None): 109 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="piqa") 110 | 111 | 112 | @app.add("mrpc") 113 | def add_mrpc_acc(file_name, id_list=None): 114 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="mrpc") 115 | 116 | 117 | @app.add("yelp_polarity") 118 | def add_yelp_polarity_acc(file_name, id_list=None): 119 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="yelp_polarity") 120 | 121 | 122 | @app.add("imdb") 123 | def add_imdb_acc(file_name, id_list=None): 124 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="imdb") 125 | 126 | 127 | @app.add("sst5") 128 | def add_sst5_acc(file_name, id_list=None): 129 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="sst5") 130 | 131 | 132 | @app.add("mnli") 133 | def add_mnli_acc(file_name, id_list=None): 134 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="mnli") 135 | 136 | 137 | @app.add("snli") 138 | def add_snli_acc(file_name, id_list=None): 139 |     return add_sst2_acc(file_name=file_name, id_list=id_list, task_name="snli") 140 | -------------------------------------------------------------------------------- /inferencer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import warnings 5 | import logging 6 | 7 | import hydra 8 | import hydra.utils as hu 9 | import torch 10 | import tqdm 11 | from accelerate import Accelerator 12 | from torch.utils.data import DataLoader 13 | from transformers import AutoTokenizer, GPT2Tokenizer, AutoModelForSeq2SeqLM 14 | 15 | from src.metrics import eval_datasets 16 | from src.utils.cache_util import BufferedJsonWriter, BufferedJsonReader 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 |
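# Execution flow (a short summary of the class below): each accelerate process
# shards the dataset, generates greedily with the LM, and streams its outputs
# to a per-device tmp_*.bin buffer; the main process then merges the buffers,
# runs the task metric from src/metrics/eval_datasets.py, and writes the final
# JSON plus a *_metric file.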
21 | class Inferencer: 22 |     def __init__(self, cfg, accelerator) -> None: 23 |         self.task_name = cfg.dataset_reader.task_name 24 |         self.dataset_reader = hu.instantiate(cfg.dataset_reader) 25 |         self.output_file = cfg.output_file 26 |         self.accelerator = accelerator 27 |         self.model_name = cfg.model_name 28 | 29 |         if cfg.model_name == 'opt-175b': 30 |             self.tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b", use_fast=False) 31 |         else: 32 |             self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name) 33 |         self.tokenizer.pad_token = "<|endoftext|>" 34 |         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 35 | 36 |         self.model, self.dataloader = self.init_model_dataloader(cfg) 37 | 38 |     def init_model_dataloader(self, cfg): 39 |         self.dataset_reader.shard(self.accelerator) 40 |         dataloader = DataLoader(self.dataset_reader, batch_size=cfg.batch_size) 41 |         if cfg.model_name == 'opt-175b': 42 |             model = None 43 |         elif 't5' in cfg.model_name: 44 |             model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")  # NOTE: checkpoint is hardcoded for T5-family runs 45 |             model = self.accelerator.prepare(model) 46 |             if hasattr(model, "module"): 47 |                 model = model.module 48 |         else: 49 |             model = hu.instantiate(cfg.model).eval() 50 |             model = self.accelerator.prepare(model) 51 |             if hasattr(model, "module"): 52 |                 model = model.module 53 | 54 |         return model, dataloader 55 | 56 |     def forward(self): 57 |         if self.accelerator.is_main_process: 58 |             dataloader = tqdm.tqdm(self.dataloader) 59 |         else: 60 |             dataloader = self.dataloader 61 |         avg_ice_num = 0 62 |         total_num = 0 63 |         with BufferedJsonWriter(f"{self.output_file}tmp_{self.accelerator.device}.bin") as buffer: 64 |             for i, entry in enumerate(dataloader): 65 |                 metadata = entry.pop("metadata") 66 |                 with torch.no_grad(): 67 |                     res = self.model.generate(input_ids=entry.input_ids, 68 |                                               attention_mask=entry.attention_mask, 69 |                                               eos_token_id=self.dataset_reader.tokenizer.encode("\n")[0], 70 |                                               pad_token_id=self.dataset_reader.tokenizer.pad_token_id, 71 |                                               max_new_tokens=100, 72 |                                               do_sample=False) 73 |                     a = int(entry.attention_mask.shape[1])  # padded prompt length
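                    # Assuming left padding (the dataset reader sets
                    # padding_side='left'), every prompt in the batch occupies
                    # the first `a` positions, so res[:, a:] holds only the
                    # newly generated tokens.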
74 | for mdata, res_el in zip(metadata, res.tolist()): 75 | mdata['generated'] = self.dataset_reader.tokenizer.decode(res_el[a:], 76 | skip_special_tokens=True) 77 | buffer.write(mdata) 78 | avg_ice_num += len(mdata['prompt_list']) 79 | total_num += 1 80 | 81 | logging.info(f"Average number of in-context examples after truncating is {avg_ice_num / total_num}") 82 | 83 | def write_results(self): 84 | data = [] 85 | for path in glob.glob(f"{self.output_file}tmp_*.bin"): 86 | with BufferedJsonReader(path) as f: 87 | data.extend(f.read()) 88 | for path in glob.glob(f"{self.output_file}tmp_*.bin"): 89 | os.remove(path) 90 | 91 | with open(self.output_file, "w") as f: 92 | json.dump(data, f) 93 | 94 | data, metric = eval_datasets.app[self.task_name](self.output_file) 95 | logger.info(f"metric: {str(metric)}") 96 | with open(self.output_file + '_metric', "w") as f: 97 | logger.info(f'{self.output_file}:{metric}') 98 | json.dump({'metric': metric}, f) 99 | with open(self.output_file, "w") as f: 100 | json.dump(data, f) 101 | 102 | return data 103 | 104 | 105 | @hydra.main(config_path="configs", config_name="inferencer") 106 | def main(cfg): 107 | logger.info(cfg) 108 | accelerator = Accelerator() 109 | inferencer = Inferencer(cfg, accelerator) 110 | 111 | with warnings.catch_warnings(): 112 | warnings.simplefilter("ignore") 113 | inferencer.forward() 114 | accelerator.wait_for_everyone() 115 | if accelerator.is_main_process: 116 | inferencer.write_results() 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: icl 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - ca-certificates=2022.07.19=h06a4308_0 10 | - certifi=2022.9.24=py37h06a4308_0 11 | - cudatoolkit=11.3.1=h2bc3f7f_2 12 | - faiss-gpu=1.7.2=py3.7_h28a55e0_0_cuda11.3 13 | - intel-openmp=2021.4.0=h06a4308_3561 14 | - ld_impl_linux-64=2.38=h1181459_1 15 | - libfaiss=1.7.2=hfc2d529_0_cuda11.3 16 | - libffi=3.3=he6710b0_2 17 | - libgcc-ng=11.2.0=h1234567_1 18 | - libgomp=11.2.0=h1234567_1 19 | - libstdcxx-ng=11.2.0=h1234567_1 20 | - mkl=2021.4.0=h06a4308_640 21 | - mkl-service=2.4.0=py37h7f8727e_0 22 | - mkl_fft=1.3.1=py37hd3c417c_0 23 | - mkl_random=1.2.2=py37h51133e4_0 24 | - ncurses=6.3=h5eee18b_3 25 | - numpy-base=1.21.5=py37ha15fc14_3 26 | - openssl=1.1.1q=h7f8727e_0 27 | - pip=22.1.2=py37h06a4308_0 28 | - python=3.7.13=h12debd9_0 29 | - readline=8.1.2=h7f8727e_1 30 | - setuptools=63.4.1=py37h06a4308_0 31 | - six=1.16.0=pyhd3eb1b0_1 32 | - sqlite=3.39.2=h5082296_0 33 | - tk=8.6.12=h1ccaba5_0 34 | - wheel=0.37.1=pyhd3eb1b0_0 35 | - xz=5.2.6=h5eee18b_0 36 | - zlib=1.2.12=h5eee18b_3 37 | - pip: 38 | - accelerate==0.12.0 39 | - aiohttp==3.8.3 40 | - aiosignal==1.2.0 41 | - aniso8601==9.0.1 42 | - antlr4-python3-runtime==4.9.3 43 | - async-timeout==4.0.2 44 | - asynctest==0.13.0 45 | - attrs==22.1.0 46 | - awesome-slugify==1.6.5 47 | - bitarray==2.6.0 48 | - bitstring==3.1.9 49 | - blessed==1.19.1 50 | - blis==0.7.8 51 | - boto3==1.24.91 52 | - botocore==1.27.91 53 | - cached-property==1.5.2 54 | - cachetools==5.2.0 55 | - catalogue==2.0.8 56 | - chardet==5.0.0 57 | - charset-normalizer==2.1.1 58 | - click==8.1.3 59 | - confection==0.0.1 60 | - configparser==5.3.0 61 | - contexttimer==0.3.3 62 | - cycler==0.11.0 63 | - cymem==2.0.6 64 | - dataclasses==0.6 65 | - 
dataflows==0.3.16 66 | - datapackage==1.15.2 67 | - datasets==2.3.2 68 | - dill==0.3.5.1 69 | - docker-pycreds==0.4.0 70 | - dppy==0.3.2 71 | - edit-distance==1.0.4 72 | - et-xmlfile==1.1.0 73 | - faiss==1.5.3 74 | - filelock==3.8.0 75 | - flask==2.2.2 76 | - flask-restful==0.3.9 77 | - fonttools==4.37.3 78 | - frozenlist==1.3.1 79 | - fsspec==2022.8.2 80 | - future==0.18.2 81 | - gitdb==4.0.9 82 | - gitpython==3.1.27 83 | - greenlet==1.1.3.post0 84 | - huggingface-hub==0.9.1 85 | - hydra-core==1.2.0 86 | - idna==3.4 87 | - ijson==3.1.4 88 | - importlib-metadata==4.12.0 89 | - importlib-resources==5.10.0 90 | - inflect==6.0.0 91 | - inquirer==2.10.0 92 | - isodate==0.6.1 93 | - itsdangerous==2.1.2 94 | - jinja2==3.1.2 95 | - jmespath==1.0.1 96 | - joblib==1.2.0 97 | - jsonlines==3.1.0 98 | - jsonpointer==2.3 99 | - jsonschema==4.16.0 100 | - kiwisolver==1.4.4 101 | - kvfile==0.0.13 102 | - langcodes==3.3.0 103 | - linear-tsv==1.1.0 104 | - markupsafe==2.1.1 105 | - matplotlib==3.5.3 106 | - more-itertools==8.14.0 107 | - multidict==6.0.2 108 | - multiprocess==0.70.13 109 | - murmurhash==1.0.8 110 | - networkx==2.6.3 111 | - nltk==3.7 112 | - numpy==1.21.6 113 | - omegaconf==2.2.3 114 | - openpyxl==3.0.10 115 | - overrides==7.0.0 116 | - packaging==21.3 117 | - pandas==1.3.5 118 | - pathtools==0.1.2 119 | - pathy==0.6.2 120 | - pillow==9.2.0 121 | - pkgutil-resolve-name==1.3.10 122 | - preshed==3.0.7 123 | - progressbar==2.5 124 | - promise==2.3 125 | - protobuf==3.19.0 126 | - psutil==5.9.2 127 | - py-dataflow==0.0.6 128 | - pyarrow==9.0.0 129 | - pybloom-live==4.0.0 130 | - pydantic==1.9.2 131 | - pyparsing==3.0.9 132 | - pyrsistent==0.18.1 133 | - python-dateutil==2.8.2 134 | - python-editor==1.0.4 135 | - pytz==2022.2.1 136 | - pyyaml==6.0 137 | - rank-bm25==0.2.2 138 | - readchar==4.0.3 139 | - regex==2022.9.13 140 | - requests==2.28.1 141 | - responses==0.18.0 142 | - rfc3986==2.0.0 143 | - s3transfer==0.6.0 144 | - sacremoses==0.0.53 145 | - scikit-learn==1.0.2 146 | - scipy==1.7.3 147 | - sentence-transformers==2.2.2 148 | - sentencepiece==0.1.97 149 | - sentry-sdk==1.9.9 150 | - shortuuid==1.0.9 151 | - sklearn==0.0 152 | - smart-open==5.2.1 153 | - smmap==5.0.0 154 | - spacy==3.4.1 155 | - spacy-legacy==3.0.10 156 | - spacy-loggers==1.0.3 157 | - sqlalchemy==1.4.42 158 | - srsly==2.4.4 159 | - subprocess32==3.5.4 160 | - tableschema==1.20.2 161 | - tableschema-sql==1.3.2 162 | - tabulate==0.9.0 163 | - tabulator==1.53.5 164 | - termcolor-whl==1.1.2 165 | - thinc==8.1.2 166 | - threadpoolctl==3.1.0 167 | - tokenizers==0.12.1 168 | - tqdm==4.64.1 169 | - transformers==4.21.1 170 | - typer==0.4.2 171 | - typing-extensions==4.1.1 172 | - unicodecsv==0.14.1 173 | - unidecode==0.04.21 174 | - urllib3==1.26.12 175 | - wandb==0.12.7 176 | - wasabi==0.10.1 177 | - wcwidth==0.2.5 178 | - werkzeug==2.2.2 179 | - xlrd==2.0.1 180 | - xmljson==0.2.1 181 | - xxhash==3.0.0 182 | - yarl==1.8.1 183 | - yaspin==2.2.0 184 | - zipp==3.8.1 185 | prefix: /mnt/cache/wangyaoxiang/anaconda3/envs/icl 186 | -------------------------------------------------------------------------------- /src/models/model.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from transformers import AutoModelForCausalLM 6 | from src.utils.calculate import entropy 7 | 8 | 9 | def no_init(loading_code): 10 | ''' 11 | no_init_weights is used in from_pretrained to speed up loading large models. 
12 |     However, torch built-in modules like torch.nn.Linear are heavily used in transformers models, 13 |     and their weight initialization cannot be disabled by no_init_weights. 14 |     ''' 15 | 16 |     def dummy(self): 17 |         return 18 | 19 |     modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm] 20 |     original = {} 21 |     for mod in modules: 22 |         original[mod] = mod.reset_parameters 23 |         mod.reset_parameters = dummy 24 | 25 |     result = loading_code() 26 |     for mod in modules: 27 |         mod.reset_parameters = original[mod] 28 | 29 |     return result 30 | 31 | 32 | def generate(self, batch, span=False): 33 |     if span: 34 |         # pull the character-level span out of the input and strip that part off 35 |         for i, input_text in enumerate(batch): 36 |             id_begin = input_text.rfind('.') 37 |             batch[i] = input_text[:id_begin] 38 | 39 |     tokenized_inputs = \ 40 |         self.tokenizer.batch_encode_plus(batch, truncation=True, max_length=self.n_tokens, return_tensors='pt', 41 |                                          add_special_tokens=False, padding='longest').to( 42 |             self.model.device)  # truncation 43 |     res = self.model.generate(input_ids=tokenized_inputs.input_ids, 44 |                               attention_mask=tokenized_inputs.attention_mask, 45 |                               # eos_token_id=self.tokenizer.encode("\n")[0],  # in the prompt, "\n" marks the end 46 |                               pad_token_id=self.tokenizer.pad_token_id, 47 |                               max_new_tokens=100, 48 |                               do_sample=False) 49 |     a = tokenized_inputs.attention_mask.shape[1] 50 |     generated = self.tokenizer.batch_decode(res[:, a:], skip_special_tokens=True) 51 |     return generated 52 | 53 | 54 | def evaluate(self, input_texts: List[str], span=False) -> Tuple: 55 | 56 |     def span_char2span_token(encoded, x_span_char0): 57 |         x_span_token0 = [] 58 |         for ii in range(len(x_span_char0)): 59 |             x_span_token_begin = -1 60 |             x_span_token_end = -1 61 |             for token_index in range(len(encoded.tokens(ii))): 62 |                 this_token = encoded.word_ids(ii)[token_index] 63 |                 if this_token is not None: 64 |                     # print('###########################') 65 |                     # print(token_index) 66 |                     cur_span = encoded.token_to_chars(ii, token_index) 67 |                     # print(encoded.token_to_chars(token_index)) 68 |                     if x_span_token_begin == -1 and cur_span.start >= x_span_char0[ii][ 69 |                         0] - 1: 70 |                         x_span_token_begin = token_index 71 |                     if cur_span.end <= x_span_char0[ii][1] + 1: 72 |                         x_span_token_end = token_index 73 | 74 |             x_span_token0.append([x_span_token_begin, x_span_token_end]) 75 |         return x_span_token0 76 | 77 |     def span_char2span_token2(encoded, x_span_char0): 78 |         decode_lens = (encoded["input_ids"] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy() 79 |         x_span_token0 = [] 80 |         for ii in range(len(x_span_char0)): 81 |             decode_len = decode_lens[ii] 82 |             decode_list = [self.tokenizer.decode(encoded['input_ids'][ii][jj]) for jj in range(decode_len)] 83 |             decode_list_len = [len(t) for t in decode_list] 84 |             span_len = x_span_char0[ii][1] - x_span_char0[ii][0] 85 |             sum_len = 0 86 |             x_span_token_begin = len(decode_list_len) - 1 87 |             x_span_token_end = len(decode_list_len) 88 |             for k in range(len(decode_list_len) - 1, -1, -1): 89 |                 sum_len += decode_list_len[k] 90 |                 if sum_len > span_len: 91 |                     x_span_token_begin = k + 1 92 |                     break 93 |             x_span_token0.append([x_span_token_begin, x_span_token_end]) 94 |         return x_span_token0 95 |
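    # Both helpers above map a character-level span of the input back to token
    # indices: span_char2span_token walks the fast tokenizer's offset mapping
    # (token_to_chars), while span_char2span_token2 is a fallback that also
    # works for slow/OPT tokenizers by re-decoding tokens from the right until
    # the span length is covered. The mask built from these token spans later
    # restricts the per-token loss to the marked span, which is what span=true
    # computes.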
96 |     if span: 97 |         span_char = [] 98 |         for i, input_text in enumerate(input_texts): 99 |             id_begin = input_text.rfind('.') 100 |             tmp = input_text[id_begin:].split(' ') 101 |             input_texts[i] = input_text[:id_begin] 102 |             span_char.append([int(tmp[1]), int(tmp[2])]) 103 | 104 |     inputs = self.tokenizer(input_texts, padding=True, return_tensors='pt', truncation=True) 105 | 106 |     if self.model_name == 'opt-175b': 107 |         import requests 108 |         import json 109 |         URL = "http://10.140.0.230:6010/completions" 110 |         headers = { 111 |             "Content-Type": "application/json; charset=UTF-8" 112 |         } 113 |         payload = {"prompt": input_texts, "max_tokens": 0, "echo": True} 114 |         response = json.loads( 115 |             requests.post(URL, data=json.dumps(payload), headers=headers, proxies={"https": "", "http": ""}).text) 116 | 117 |         lens = np.array([len(r['logprobs']['tokens']) for r in response['choices']]) 118 |         loss_lens = np.array([len(r['logprobs']['token_logprobs']) for r in response['choices']]) 119 | 120 |         loss = [r['logprobs']['token_logprobs'] for r in response['choices']] 121 | 122 |         max_len = loss_lens.max() 123 |         loss_pad = list(map(lambda l: l + [0] * (max_len - len(l)), loss)) 124 | 125 |         loss = -np.array(loss_pad) 126 | 127 |         loss = torch.tensor(loss) 128 |         if span: 129 |             span_token = span_char2span_token2(inputs, span_char) 130 |         if span: 131 |             mask = torch.zeros_like(loss)  # [batch, seqlen] 132 |             for i in range(len(mask)): 133 |                 for j in range(len(mask[i])): 134 |                     if span_token[i][0] <= j <= span_token[i][1]: 135 |                         mask[i][j] = 1 136 | 137 |             loss = loss * mask 138 |             lens = np.array([(x_span[1] - x_span[0]) for x_span in span_token]) 139 | 140 |         ce_loss = loss.sum(-1).cpu().detach().numpy()  # -log(p(y)) 141 |         return ce_loss, lens 142 | 143 |     if span: 144 |         if 'opt' in self.model_name: 145 |             span_token = span_char2span_token2(inputs, span_char) 146 |         else: 147 |             try: 148 |                 span_token = span_char2span_token(inputs, span_char) 149 |             except Exception: 150 |                 span_token = span_char2span_token2(inputs, span_char) 151 | 152 |     inputs = {k: v.to(self.model.device) for k, v in inputs.items()} 153 | 154 | 155 |     outputs = self.model(**inputs) 156 |     shift_logits = outputs.logits[..., :-1, :].contiguous() 157 |     # note: we assume right padding here; left padding tokens would affect the position ids in GPT-2 158 |     shift_labels = inputs["input_ids"][..., 1:].contiguous() 159 |     loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=self.tokenizer.pad_token_id) 160 |     loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view( 161 |         shift_labels.size()) 162 | 163 |     lens = (inputs["input_ids"] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy() 164 | 165 |     if span: 166 |         mask = torch.zeros_like(shift_labels)  # [batch, seqlen] 167 |         for i in range(len(mask)): 168 |             for j in range(len(mask[i])): 169 |                 if span_token[i][0] <= j <= span_token[i][1]: 170 |                     mask[i][j] = 1 171 |         loss = loss * mask 172 |         lens = np.array([(x_span[1] - x_span[0]) for x_span in span_token]) 173 |     ce_loss = loss.sum(-1).cpu().detach().numpy()  # -log(p(y)) 174 | 175 |     return ce_loss, lens 176 | 177 | 178 | def get_model(**kwargs): 179 |     return no_init(lambda: AutoModelForCausalLM.from_pretrained(**kwargs)) 180 | 181 | 182 | def get_score(self, batch_labels, method='mdl', span=False, prior_loss_list=None): 183 |     loss_list = []  # [labels, batch size] 184 |     for i, batch in enumerate(batch_labels): 185 |         with torch.no_grad(): 186 |             ce_loss, lens = evaluate(self, batch, span=span) 187 |             avg_loss = (ce_loss / lens).tolist() 188 |             loss_list.append(avg_loss) 189 |     if prior_loss_list is not None: 190 |         loss_list = np.array(loss_list) - np.array(prior_loss_list) 191 |     probs = np.exp(-np.array(loss_list))  # [labels, batch size] 192 | 193 |     normalized_probs = probs / probs.sum(0, keepdims=True) 194 | 195 |     if method == 'mdl': 196 |         neg_entropy = -entropy(normalized_probs, label_dim=0) 197 |         return neg_entropy
198 |     elif method == "entropy": 199 |         return entropy(normalized_probs, label_dim=0) 200 | -------------------------------------------------------------------------------- /ppl_inferencer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import logging 4 | import hydra 5 | import numpy as np 6 | import tqdm 7 | import random 8 | from accelerate import Accelerator 9 | 10 | from src.utils.cache_util import BufferedJsonWriter 11 | from src.datasets.labels import get_mapping_token 12 | 13 | from src.utils.calculate import dict_list2list_dict, transform 14 | from inferencer import Inferencer 15 | from src.datasets.instructions import * 16 | from src.models.model import evaluate, generate 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | # refactored to inherit from Inferencer 22 | class PPLInferencer(Inferencer): 23 | 24 |     def __init__(self, cfg, accelerator) -> None: 25 |         super(PPLInferencer, self).__init__(cfg, accelerator) 26 | 27 |         self.output_file = cfg.output_file 28 |         self.instruction_template = cfg.instruction_template 29 |         self.span = cfg.span 30 |         self.task_type = get_task_type(self.task_name) 31 |         self.n_tokens = cfg.n_tokens 32 |         self.printonce = 0 33 |         self.calibrate = cfg.calibrate 34 |         self.prior_no = cfg.prior_no 35 |         self.reverse_label = cfg.reverse_label 36 | 37 |         instructions = get_template(self.task_name, self.instruction_template) 38 |         if instructions is not None: 39 |             self.labels = [y for y in instructions.keys()]  # int 40 |             self.example_instruction = {label: instructions[label]['example_instruction'] for label in self.labels} 41 |             self.prompting_instruction = {label: instructions[label]['prompting_instruction'] for label in self.labels} 42 |         if self.task_type == "QA": 43 |             self.tokenizer.padding_side = "left" 44 | 45 | 46 |
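    # How PPL-based classification works here (a sketch with made-up numbers):
    # forward() builds one full prompt per label, scores each with the LM, and
    # predicts the label whose prompt has the smallest length-normalized loss:
    #   loss_list = [[0.9, 1.4],          # label 0, per example in the batch
    #                [1.2, 0.7]]          # label 1
    #   preds     = argmin over labels    # -> [label 0, label 1]
    #   probs     = exp(-loss), normalized over labels
    # With calibrate=true, the loss of a content-free "prior" prompt is
    # subtracted first, in the spirit of contextual calibration
    # (Zhao et al., 2021).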
47 |     def forward(self): 48 | 49 |         def build_batch(_metadata, y, generated=None, prior=False): 50 | 51 |             _metadata = transform(_metadata) 52 | 53 |             if self.reverse_label: 54 |                 for m in _metadata: 55 |                     for e in m['examples']: 56 |                         e['Y'] = (e['Y'] + 1) % len(self.labels) 57 | 58 |             choice = self.task_type == "CHOICE" 59 |             choice_sep = get_choice_sep(self.task_name) 60 |             if generated is None: 61 |                 return [build_instruction(x=m['X'], c=m['C'], 62 |                                           e=m['examples'], 63 |                                           y_text="", 64 |                                           tokenizer=self.tokenizer, 65 |                                           instruction=self.prompting_instruction[y], 66 |                                           e_instruction=self.example_instruction, 67 |                                           need_span_ids=self.span, 68 |                                           max_len=self.n_tokens, 69 |                                           prior=prior, prior_no=self.prior_no, 70 |                                           choice=choice, choice_sep=choice_sep)[0] for m in _metadata] 71 |             else: 72 |                 return [build_instruction(x=m['X'], c=m['C'], 73 |                                           e=m['examples'], 74 |                                           y_text=y_text, 75 |                                           tokenizer=self.tokenizer, 76 |                                           instruction=self.prompting_instruction[y], 77 |                                           e_instruction=self.example_instruction, need_span_ids=self.span, 78 |                                           max_len=self.n_tokens)[0] for m, y_text in zip(_metadata, generated)] 79 | 80 |         if self.accelerator.is_main_process: 81 |             dataloader = tqdm.tqdm(self.dataloader) 82 |         else: 83 |             dataloader = self.dataloader 84 | 85 |         mapping_token = get_mapping_token(self.task_name) 86 |         tmpfile = f"{self.output_file}tmp_{self.accelerator.device}.bin" 87 |         if os.path.exists(tmpfile): 88 |             os.remove(tmpfile) 89 |         with BufferedJsonWriter(f"{self.output_file}tmp_{self.accelerator.device}.bin") as buffer: 90 | 91 |             for ii, entry in enumerate(dataloader): 92 |                 metadata = entry.pop("metadata") 93 | 94 |                 batch_labels = [build_batch(metadata, label) for label in self.labels]  # label:int 95 |                 if self.calibrate: 96 |                     batch_labels_prior = [build_batch(metadata, label, prior=True) for label in self.labels] 97 |                     prior_loss_list = [] 98 |                     for batch in batch_labels_prior: 99 |                         with torch.no_grad(): 100 |                             prior_ce_loss, prior_lens = evaluate(self, batch, span=self.span) 101 |                             avg_prior_loss = (prior_ce_loss / prior_lens).tolist() 102 |                             prior_loss_list.append(avg_prior_loss) 103 | 104 |                 if self.printonce > 0: 105 |                     self.printonce -= 1 106 |                     logger.info(f'batch_labels={batch_labels}') 107 | 108 |                 if self.task_type != "QA": 109 |                     loss_list = []  # [labels, batch size] 110 |                     lens_list = [] 111 | 112 |                     for i, batch in enumerate(batch_labels): 113 |                         with torch.no_grad(): 114 |                             ce_loss, lens = evaluate(self, batch, span=self.span) 115 |                             avg_loss = (ce_loss / lens).tolist() 116 |                             loss_list.append(avg_loss) 117 |                             lens_list.append(lens) 118 |                     if self.calibrate: 119 |                         preds_prior = np.array(prior_loss_list).argmin(axis=0) 120 |                         preds_prior = [mapping_token[pred] for pred in preds_prior] 121 | 122 |                         prior_probs = np.exp(-np.array(prior_loss_list)) 123 |                         prior_normalized_probs = prior_probs / prior_probs.sum(0, keepdims=True) 124 |                         prior_probs = np.transpose(prior_normalized_probs).tolist() 125 | 126 |                         loss_list = np.array(loss_list) - np.array(prior_loss_list) 127 | 128 |                     preds = np.array(loss_list).argmin(axis=0)  # [batch size] 129 |                     preds = [mapping_token[pred] for pred in preds] 130 | 131 |                     probs = np.exp(-np.array(loss_list)) 132 |                     normalized_probs = probs / probs.sum(0, keepdims=True) 133 |                     probs = np.transpose(normalized_probs).tolist() 134 |                 else: 135 |                     batch = batch_labels[0] 136 |                     with torch.no_grad(): 137 |                         preds = generate(self, batch, span=self.span) 138 | 139 |                     batch_labels = [build_batch(metadata, label, preds) for label in self.labels] 140 |                     loss_list = [] 141 |                     for batch in batch_labels: 142 |                         with torch.no_grad(): 143 |                             ce_loss, lens = evaluate(self, batch, span=self.span) 144 |                             avg_loss = (ce_loss / lens).tolist() 145 |                             loss_list.append(avg_loss) 146 |                     probs = np.exp(-np.array(loss_list)) 147 |                     probs = np.transpose(probs).tolist() 148 | 149 |                 metadata.pop("examples") 150 | 151 |                 metadata_tmp = dict_list2list_dict(metadata) 152 | 153 |                 if self.calibrate and self.task_type != "QA": 154 |                     for mdata, pred_text, pred_text_prior, prob, prior_prob in zip(metadata_tmp, preds, preds_prior, 155 |                                                                                    probs, prior_probs): 156 | 157 |                         mdata['generated'] = pred_text 158 |                         mdata['prior'] = pred_text_prior 159 |                         mdata['probs'] = prob 160 |                         mdata['prior_prob'] = prior_prob 161 |                         for key in mdata.keys(): 162 |                             if torch.is_tensor(mdata[key]): 163 |                                 mdata[key] = mdata[key].tolist() 164 |                         buffer.write(mdata) 165 |                 else: 166 |                     for mdata, pred_text, prob in zip(metadata_tmp, preds, probs): 167 | 168 |                         mdata['generated'] = pred_text 169 |                         mdata['probs'] = prob 170 |                         for key in mdata.keys(): 171 |                             if torch.is_tensor(mdata[key]): 172 |                                 mdata[key] = mdata[key].tolist() 173 |                         buffer.write(mdata) 174 | 175 | 176 | @hydra.main(config_path="configs", config_name="ppl_inferencer") 177 | def main(cfg): 178 |     logger.info(cfg) 179 |     random.seed(cfg.rand_seed) 180 |     np.random.seed(cfg.rand_seed) 181 |     if not cfg.overwrite: 182 |         if os.path.exists(cfg.output_file): 183 |             logger.info(f'{cfg.output_file} already exists, skipping') 184 |             return 185 |     accelerator = Accelerator() 186 |     inferencer = PPLInferencer(cfg, accelerator) 187 | 188 |     with warnings.catch_warnings(): 189 |         warnings.simplefilter("ignore") 190 |         inferencer.forward() 191 |     accelerator.wait_for_everyone() 192 |     if accelerator.is_main_process: 
inferencer.write_results() 193 | 194 | 195 | if __name__ == "__main__": 196 | main() 197 | -------------------------------------------------------------------------------- /retriever.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import warnings 5 | import logging 6 | from typing import Optional, Dict, List 7 | import hydra 8 | import tqdm 9 | from accelerate import Accelerator 10 | 11 | from src.utils.cache_util import BufferedJsonWriter, BufferedJsonReader 12 | from inferencer import Inferencer 13 | from src.datasets.instructions import * 14 | from src.models.model import evaluate, get_score, generate 15 | from src.utils.calculate import transform, get_permutations 16 | 17 | import random 18 | import numpy as np 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class Retriever(Inferencer): 24 | 25 | def __init__(self, cfg, accelerator) -> None: 26 | super(Retriever, self).__init__(cfg, accelerator) 27 | self.window = cfg.window 28 | self.input_file = cfg.dataset_reader.dataset_path 29 | self.output_file = cfg.output_file 30 | self.method = cfg.method 31 | self.num_ice = cfg.num_ice 32 | self.instruction_template = cfg.instruction_template 33 | self.force_topk = cfg.force_topk 34 | self.span = cfg.span 35 | self.n_tokens = cfg.n_tokens 36 | self.printonce = 0 37 | self.all_permutation = cfg.all_permutation 38 | self.calibrate = cfg.calibrate 39 | self.sort = cfg.sort 40 | self.use_rand_pool = cfg.use_rand_pool 41 | self.rand_pool = None 42 | self.prior_no = cfg.prior_no 43 | 44 | instructions = get_template(self.task_name, self.instruction_template) 45 | if instructions is not None: 46 | self.labels = [y for y in instructions.keys()] # [int] QA:[0] 47 | self.example_instruction = {label: instructions[label]['example_instruction'] for label in self.labels} 48 | self.prompting_instruction = {label: instructions[label]['prompting_instruction'] for label in self.labels} 49 | 50 | self.task_type = get_task_type(self.task_name) 51 | if self.task_type == "QA": 52 | self.tokenizer.padding_side = "left" 53 | 54 | 55 | def forward(self): 56 | 57 | if self.accelerator.is_main_process: 58 | dataloader = tqdm.tqdm(self.dataloader) 59 | else: 60 | dataloader = self.dataloader 61 | tmpfile = f"{self.output_file}tmp_{self.accelerator.device}.bin" 62 | if os.path.exists(tmpfile): 63 | os.remove(tmpfile) 64 | with BufferedJsonWriter(f"{self.output_file}tmp_{self.accelerator.device}.bin") as buffer: 65 | 66 | for i, entry in enumerate(dataloader): 67 | metadata = entry.pop("metadata") 68 | if self.printonce > 0: 69 | self.printonce -= 1 70 | metadata = transform(metadata) 71 | if self.printonce > 0: 72 | self.printonce -= 1 73 | 74 | ctxs = [self.retrieve(pool=m["examples"], num=self.num_ice, query=m, 75 | method=self.method) 76 | for m in metadata] 77 | 78 | for mdata, selected_idx in zip(metadata, ctxs): 79 | mdata.pop("examples") 80 | mdata['selected_idxs'] = selected_idx 81 | for key in mdata.keys(): 82 | if torch.is_tensor(mdata[key]): 83 | mdata[key] = mdata[key].tolist() 84 | buffer.write(mdata) 85 | 86 | def retrieve(self, pool: List[Dict], num: int, query: Optional[Dict] = None, method: Optional[str] = None): 87 | selected_idxs = self.instance_level_lm_score(pool=pool, num=num, query=query, method=method) 88 | return selected_idxs 89 | 90 | def instance_level_lm_score(self, pool: List[Dict], num: int, query, method: str = 'mdl') \ 91 | -> List: 92 | window = self.window # number of candidates 93 | 94 
| if self.use_rand_pool: 95 | if self.rand_pool is None: 96 | self.rand_pool = [np.random.choice(list(range(len(pool))), size=num, replace=False).tolist() for _ in 97 | range(window)] 98 | all_candidate_idx = self.rand_pool 99 | 100 | elif self.all_permutation: 101 | all_candidate_idx = get_permutations(num) 102 | elif self.sort: 103 | all_candidate_idx = [sorted(np.random.choice(list(range(len(pool))), size=num, replace=False).tolist()) 104 | for _ in range(window)] 105 | else: 106 | all_candidate_idx = [np.random.choice(list(range(len(pool))), size=num, replace=False).tolist() for _ in 107 | range(window)] 108 | if self.force_topk: 109 | new = [i for i in range(num)] 110 | all_candidate_idx.pop(0) 111 | all_candidate_idx.append(new) 112 | if window == 1: 113 | return all_candidate_idx[0] 114 | 115 | in_context_examples = [[pool[i] for i in candidates_idx] for candidates_idx in all_candidate_idx] 116 | 117 | if self.task_type == "QA": 118 | batch = [build_instruction(x=query['X'], c=query['C'], e=e, y_text="", 119 | instruction=self.prompting_instruction[0], 120 | tokenizer=self.tokenizer, 121 | e_instruction=self.example_instruction, need_span_ids=self.span, 122 | max_len=self.n_tokens)[0] 123 | for e in in_context_examples] 124 | generated = generate(self, batch, span=self.span) 125 | choice = self.task_type == "CHOICE" 126 | choice_sep = get_choice_sep(self.task_name) 127 | 128 | batch_labels = [[build_instruction(x=query['X'], c=query['C'], e=e, 129 | y_text=None if self.task_type != "QA" else generated[i], 130 | instruction=self.prompting_instruction[label], 131 | tokenizer=self.tokenizer, 132 | e_instruction=self.example_instruction, need_span_ids=self.span, 133 | max_len=self.n_tokens,choice=choice,choice_sep=choice_sep)[0] 134 | for i, e in enumerate(in_context_examples)] 135 | for label in self.labels] # label:int 136 | if self.calibrate and self.task_type != 'QA': 137 | batch_labels_prior = [[build_instruction(x=query['X'], c=query['C'], e=e, 138 | y_text=None, 139 | instruction=self.prompting_instruction[label], 140 | tokenizer=self.tokenizer, 141 | e_instruction=self.example_instruction, need_span_ids=self.span, 142 | max_len=self.n_tokens, prior=True, prior_no=self.prior_no)[0] 143 | for i, e in enumerate(in_context_examples)] 144 | for label in self.labels] # label:int 145 | prior_loss_list = [] 146 | for batch in batch_labels_prior: 147 | with torch.no_grad(): 148 | prior_ce_loss, prior_lens = evaluate(self, batch, span=self.span) 149 | avg_prior_loss = (prior_ce_loss / prior_lens).tolist() 150 | prior_loss_list.append(avg_prior_loss) 151 | scores = get_score(self, batch_labels, method=method, span=self.span, prior_loss_list=prior_loss_list) 152 | else: 153 | scores = get_score(self, batch_labels, method=method, span=self.span) 154 | 155 | selected_idxs = all_candidate_idx[scores.argmax()] 156 | return selected_idxs # list int 157 | 158 | def write_results(self): 159 | data = [] 160 | for path in glob.glob(f"{self.output_file}tmp_*.bin"): 161 | print(path) 162 | with BufferedJsonReader(path) as f: 163 | data.extend(f.read()) 164 | for path in glob.glob(f"{self.output_file}tmp_*.bin"): 165 | os.remove(path) 166 | 167 | 168 | outputfile1 = self.input_file 169 | with open(outputfile1, 'r') as load_f: 170 | load_data = json.load(load_f) 171 | data.sort(key=lambda x: int(x['id'])) 172 | 173 | ctxs_list = [np.array(load_data[i]["ctxs_candidates"]).reshape(-1)[data[i]["selected_idxs"]] 174 | for i in range(len(load_data))] 175 | 176 | with open(self.output_file, "w") as f: 177 | 
for d, ctxs in zip(load_data, ctxs_list): 178 | d["ctxs"] = ctxs.tolist() 179 | json.dump(load_data, f) 180 | 181 | return data 182 | 183 | 184 | @hydra.main(config_path="configs", config_name="retriever") 185 | def main(cfg): 186 | logger.info(cfg) 187 | if not cfg.overwrite: 188 | if os.path.exists(cfg.output_file): 189 | logger.info(f'{cfg.output_file} already exists, skipping') 190 | return 191 | random.seed(cfg.rand_seed) 192 | np.random.seed(cfg.rand_seed) 193 | accelerator = Accelerator() 194 | retriever = Retriever(cfg, accelerator) 195 | 196 | with warnings.catch_warnings(): 197 | warnings.simplefilter("ignore") 198 | retriever.forward() 199 | accelerator.wait_for_everyone() 200 | if accelerator.is_main_process: 201 | retriever.write_results() 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- /prerank.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from collections import defaultdict 4 | 5 | import faiss 6 | import hydra 7 | 8 | import numpy as np 9 | import random 10 | import torch 11 | 12 | from sentence_transformers import SentenceTransformer 13 | import os 14 | import datetime 15 | 16 | 17 | 18 | from sklearn.metrics.pairwise import cosine_similarity 19 | from torch.utils.data import DataLoader 20 | from dppy.finite_dpps import FiniteDPP 21 | 22 | from src.utils.collators import DataCollatorWithPaddingAndCuda 23 | from src.dataset_readers.prerank_dsr import PrerankDatasetReader 24 | from src.utils.dpp_map import fast_map_dpp 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class PreRank: 30 | def __init__(self, cfg) -> None: 31 | self.cuda_device = cfg.cuda_device 32 | 33 | self.retriever_model = SentenceTransformer(cfg.retriever_model).to( 34 | self.cuda_device) if cfg.retriever_model != 'none' else None 35 | if self.retriever_model is not None: 36 | self.retriever_model.eval() 37 | 38 | self.dataset_reader = PrerankDatasetReader(task_name=cfg.dataset_reader.task_name, 39 | field=cfg.dataset_reader.field, 40 | dataset_path=cfg.dataset_reader.dataset_path, 41 | dataset_split=cfg.dataset_reader.dataset_split, 42 | tokenizer=self.retriever_model.tokenizer) 43 | 44 | 45 | co = DataCollatorWithPaddingAndCuda(tokenizer=self.dataset_reader.tokenizer, 46 | device=self.cuda_device) 47 | 48 | self.dataloader = DataLoader(self.dataset_reader, batch_size=cfg.batch_size, collate_fn=co) 49 | 50 | self.output_file = cfg.output_file 51 | self.num_candidates = cfg.num_candidates 52 | self.num_ice = cfg.num_ice 53 | self.is_train = cfg.dataset_reader.dataset_split == "train" 54 | self.dpp_sampling = cfg.dpp_sampling 55 | self.scale_factor = cfg.scale_factor 56 | self.dpp_topk = cfg.dpp_topk 57 | self.mode = "cand_selection" 58 | self.method = cfg.method 59 | self.vote_k_idxs = None 60 | self.vote_k_k = cfg.vote_k_k 61 | 62 | self.index_reader = PrerankDatasetReader(task_name=cfg.index_reader.task_name, 63 | field=cfg.index_reader.field, 64 | dataset_path=cfg.index_reader.dataset_path, 65 | dataset_split=cfg.index_reader.dataset_split, 66 | tokenizer=self.retriever_model.tokenizer) 67 | if self.method != "random": 68 | self.index = self.create_index(cfg) 69 | 70 | def create_index(self, cfg): 71 | logger.info("building index...") 72 | starttime = datetime.datetime.now() 73 | co = DataCollatorWithPaddingAndCuda(tokenizer=self.index_reader.tokenizer, device=self.cuda_device) 74 |
dataloader = DataLoader(self.index_reader, batch_size=cfg.batch_size, collate_fn=co) 75 | 76 | index = faiss.IndexIDMap(faiss.index_cpu_to_all_gpus(faiss.IndexFlatIP(768))) 77 | res_list = self.forward(dataloader) 78 | 79 | id_list = np.array([res['metadata']['id'] for res in res_list]) 80 | embed_list = np.stack([res['embed'] for res in res_list]) 81 | if self.method == 'votek': 82 | self.vote_k_idxs = self.vote_k_select(embeddings=embed_list, select_num=self.num_candidates, 83 | k=self.vote_k_k, overlap_threshold=1) 84 | index.add_with_ids(embed_list, id_list) 85 | cpu_index = faiss.index_gpu_to_cpu(index) 86 | faiss.write_index(cpu_index, cfg.index_file) 87 | endtime = datetime.datetime.now() 88 | logger.info(f"finished building index, size {len(self.index_reader)}, time: {(endtime-starttime).seconds} seconds") 89 | return index 90 | 91 | def forward(self, dataloader, **kwargs): 92 | res_list = [] 93 | logger.info(f"Total number of batches: {len(dataloader)}") 94 | for i, entry in enumerate(dataloader): 95 | with torch.no_grad(): 96 | if i % 500 == 0: 97 | logger.info(f"finished {i} batches") 98 | metadata = entry.pop("metadata") 99 | raw_text = self.retriever_model.tokenizer.batch_decode(entry['input_ids'], skip_special_tokens=True) 100 | res = self.retriever_model.encode(raw_text, show_progress_bar=False, **kwargs) 101 | res_list.extend([{"embed": r, "metadata": m} for r, m in zip(res, metadata)]) 102 | return res_list 103 | 104 | def knn_search(self, entry, num_candidates=1, num_ice=1): 105 | embed = np.expand_dims(entry['embed'], axis=0) 106 | near_ids = self.index.search(embed, max(num_candidates, num_ice) + 1)[1][0].tolist() 107 | near_ids = near_ids[1:] if self.is_train else near_ids # on the training split, the nearest hit is the query itself 108 | return near_ids[:num_ice], [[i] for i in near_ids[:num_candidates]] 109 | 110 | def random_search(self, num_candidates=1, num_ice=1): 111 | rand_ids = np.random.choice(list(range(len(self.index_reader))), size=num_candidates, replace=False).tolist() 112 | return rand_ids[:num_ice], [[i] for i in rand_ids[:num_candidates]] 113 | 114 | def get_kernel(self, embed, candidates): 115 | near_reps = np.stack([self.index.index.reconstruct(i) for i in candidates], axis=0) 116 | # normalize first 117 | embed = embed / np.linalg.norm(embed) 118 | near_reps = near_reps / np.linalg.norm(near_reps, keepdims=True, axis=1) 119 | 120 | rel_scores = np.matmul(embed, near_reps.T)[0] 121 | rel_scores = (rel_scores + 1) / 2 122 | # to balance relevance and diversity 123 | rel_scores = np.exp(rel_scores / (2 * self.scale_factor)) 124 | 125 | sim_matrix = np.matmul(near_reps, near_reps.T) 126 | sim_matrix = (sim_matrix + 1) / 2 127 | 128 | 129 | kernel_matrix = rel_scores[None] * sim_matrix * rel_scores[:, None] 130 | return near_reps, rel_scores, kernel_matrix 131 | 132 | def k_dpp_sampling(self, kernel_matrix, rel_scores, num_ice, num_candidates): 133 | ctxs_candidates_idx = [list(range(num_ice))] 134 | dpp_L = FiniteDPP('likelihood', **{'L': kernel_matrix}) 135 | i = 0 136 | while len(ctxs_candidates_idx) < num_candidates: 137 | try: 138 | samples_ids = np.array(dpp_L.sample_exact_k_dpp(size=num_ice, random_state=i)) 139 | except Exception as e: 140 | logger.info(e) 141 | i += 1 142 | if i > 9999999: 143 | raise RuntimeError('Endless loop') 144 | continue 145 | i += 1 146 | # ordered by relevance score 147 | samples_scores = np.array([rel_scores[i] for i in samples_ids]) 148 | samples_ids = samples_ids[(-samples_scores).argsort()].tolist() 149 | 150 |
if samples_ids not in ctxs_candidates_idx: 151 | assert len(samples_ids) == num_ice 152 | ctxs_candidates_idx.append(samples_ids) 153 | 154 | return ctxs_candidates_idx 155 | 156 | def dpp_search(self, entry, num_candidates=1, num_ice=1): 157 | candidates = self.knn_search(entry, num_ice=self.dpp_topk)[0] 158 | embed = np.expand_dims(entry['embed'], axis=0) 159 | near_reps, rel_scores, kernel_matrix = self.get_kernel(embed, candidates) 160 | 161 | if self.mode == "cand_selection": 162 | ctxs_candidates_idx = self.k_dpp_sampling(kernel_matrix=kernel_matrix, rel_scores=rel_scores, 163 | num_ice=num_ice, num_candidates=num_candidates) 164 | else: 165 | # MAP inference and create reordering candidates 166 | map_results = fast_map_dpp(kernel_matrix, num_ice) 167 | map_results = sorted(map_results) 168 | ctxs_candidates_idx = [map_results] 169 | while len(ctxs_candidates_idx) < num_candidates: 170 | # random reorderings of the MAP-selected set 171 | ctxs_idx = map_results.copy() 172 | np.random.shuffle(ctxs_idx) 173 | if ctxs_idx not in ctxs_candidates_idx: 174 | ctxs_candidates_idx.append(ctxs_idx) 175 | 176 | ctxs_candidates = [] 177 | for ctxs_idx in ctxs_candidates_idx[:num_candidates]: 178 | ctxs_candidates.append([candidates[i] for i in ctxs_idx]) 179 | assert len(ctxs_candidates) == num_candidates 180 | 181 | return ctxs_candidates[0], ctxs_candidates 182 | 183 | def vote_k_select(self, embeddings=None, select_num=None, k=None, overlap_threshold=None, vote_file=None): 184 | n = len(embeddings) 185 | if vote_file is not None and os.path.isfile(vote_file): 186 | with open(vote_file) as f: 187 | vote_stat = json.load(f) 188 | else: 189 | 190 | vote_stat = defaultdict(list) 191 | 192 | for i in range(n): 193 | cur_emb = embeddings[i].reshape(1, -1) 194 | cur_scores = np.sum(cosine_similarity(embeddings, cur_emb), axis=1) 195 | sorted_indices = np.argsort(cur_scores).tolist()[-k - 1:-1] 196 | for idx in sorted_indices: 197 | if idx != i: 198 | vote_stat[idx].append(i) 199 | 200 | if vote_file is not None: 201 | with open(vote_file, 'w') as f: 202 | json.dump(vote_stat, f) 203 | votes = sorted(vote_stat.items(), key=lambda x: len(x[1]), reverse=True) 204 | j = 0 205 | selected_indices = [] 206 | while len(selected_indices) < select_num and j < len(votes): 207 | candidate_set = set(votes[j][1]) 208 | flag = True 209 | for pre in range(j): 210 | cur_set = set(votes[pre][1]) 211 | if len(candidate_set.intersection(cur_set)) >= overlap_threshold * len(candidate_set): 212 | flag = False 213 | break 214 | if not flag: 215 | j += 1 216 | continue 217 | selected_indices.append(int(votes[j][0])) 218 | j += 1 219 | if len(selected_indices) < select_num: 220 | unselected_indices = [] 221 | cur_num = len(selected_indices) 222 | for i in range(n): 223 | if i not in selected_indices: 224 | unselected_indices.append(i) 225 | selected_indices += random.sample(unselected_indices, select_num - cur_num) 226 | return selected_indices 227 | 228 | def vote_k_search(self, num_candidates=100, num_ice=8): 229 | return self.vote_k_idxs[:num_ice], [[i] for i in self.vote_k_idxs[:num_candidates]] 230 | 231 | def search(self, entry): 232 | if self.method == "random": 233 | return self.random_search(num_candidates=self.num_candidates, num_ice=self.num_ice) 234 | elif self.method == "topk": 235 | return self.knn_search(entry, num_candidates=self.num_candidates, num_ice=self.num_ice) 236 | elif self.method == "dpp" or self.dpp_sampling: 237 | return self.dpp_search(entry,
num_candidates=self.num_candidates, num_ice=self.num_ice) 238 | elif self.method == "votek": 239 | return self.vote_k_search(num_candidates=self.num_candidates, num_ice=self.num_ice) 240 | 241 | def find(self): 242 | res_list = self.forward(self.dataloader) 243 | data_list = [] 244 | starttime = datetime.datetime.now() 245 | for entry in res_list: 246 | data = self.dataset_reader.dataset_wrapper[entry['metadata']['id']] 247 | ctxs, ctxs_candidates = self.search(entry) 248 | data['ctxs'] = ctxs 249 | data['ctxs_candidates'] = ctxs_candidates 250 | data_list.append(data) 251 | 252 | endtime = datetime.datetime.now() 253 | logger.info(f"retrieval time: {(endtime-starttime).seconds} seconds") 254 | with open(self.output_file, "w") as f: 255 | json.dump(data_list, f) 256 | 257 | 258 | @hydra.main(config_path="configs", config_name="prerank") 259 | def main(cfg): 260 | logger.info(cfg) 261 | if not cfg.overwrite: 262 | if os.path.exists(cfg.output_file): 263 | logger.info(f'{cfg.output_file} already exists, skipping') 264 | return 265 | random.seed(cfg.rand_seed) # seed before constructing PreRank, since index building may draw random numbers (e.g. vote-k's fallback sampling) 266 | np.random.seed(cfg.rand_seed) 267 | dense_retriever = PreRank(cfg) 268 | 269 | dense_retriever.find() 270 | 271 | 272 | if __name__ == "__main__": 273 | main() -------------------------------------------------------------------------------- /src/datasets/instructions.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | PLACEHOLDER_C = "" 6 | PLACEHOLDER_X = "" 7 | PLACEHOLDER_Y = "" 8 | PLACEHOLDER_EXAMPLE = "" 9 | 10 | SINGLE_SENT_TASKS = ['imdb', 'sst2', 'yelp', 'rotten', 'elec', 'sst5'] 11 | SENT_PAIR_TASKS = ['rte', 'qnli', 'snli', 'mnli'] 12 | QA_TASKS = ['squad', 'adversarial_qa', 'mtop', 'iwslt17zh', 'xsum', 'iwslt17de', 'gsm8k', 'web_questions'] 13 | CHOICE_TASKS = ['commonsense_qa', 'copa'] 14 | 15 | 16 | def get_task_type(task_name): 17 | if task_name in SINGLE_SENT_TASKS: 18 | return "SINGLE_CLS" 19 | elif task_name in SENT_PAIR_TASKS: 20 | return "PAIR_CLS" 21 | elif task_name in QA_TASKS: 22 | return "QA" 23 | elif task_name in CHOICE_TASKS: 24 | return "CHOICE" 25 | else: 26 | return "OTHER" 27 | 28 | 29 | def get_choice_sep(task_name): 30 | if task_name == 'commonsense_qa': 31 | return ', ' 32 | elif task_name == 'copa': 33 | return '||||' 34 | return None 35 | 36 | 37 | # select one of the prompt template variants (t) for the given task 38 | def get_template(task_name, t=1): 39 | if t == 1: 40 | if task_name == "sst2" or task_name == "imdb": 41 | return { 42 | 0: { 43 | "example_instruction": "Negative Movie Review: \"\"", 44 | "prompting_instruction": "Negative Movie Review: \"\"", 45 | }, 46 | 1: { 47 | "example_instruction": "Positive Movie Review: \"\"", 48 | "prompting_instruction": "Positive Movie Review: \"\"", 49 | } 50 | } 51 | elif task_name == "commonsense_qa": 52 | return { 53 | 0: { 54 | "example_instruction": "Answer the following question:\n \n Answer: ", 55 | "prompting_instruction": "Answer the following question:\n \n Answer: ", 56 | }, 57 | 1: { 58 | "example_instruction": "Answer the following question:\n \n Answer: ", 59 | "prompting_instruction": "Answer the following question:\n \n Answer: ", 60 | }, 61 | 2: { 62 | "example_instruction": "Answer the following question:\n \n Answer: ", 63 | "prompting_instruction": "Answer the following question:\n \n Answer: ", 64 | }, 65 | 3: { 66 | "example_instruction": "Answer the following question:\n \n Answer: ", 67 | "prompting_instruction": "Answer the following question:\n \n Answer: ", 68 | }, 69 | 4: { 70 |
"example_instruction": "Answer the following question:\n \n Answer: ", 71 | "prompting_instruction": "Answer the following question:\n \n Answer: ", 72 | } 73 | } 74 | elif task_name == 'boolq': 75 | return { 76 | 0: { # False 77 | "example_instruction": "\n Can we know based on context above? No.", 78 | "prompting_instruction": "\n Can we know based on context above? No.", 79 | }, 80 | 1: { # True 81 | "example_instruction": "\n Can we know based on context above? Yes.", 82 | "prompting_instruction": "\n Can we know based on context above? Yes.", 83 | } 84 | } 85 | elif task_name == "commonsense_qa": 86 | return { 87 | 0: { 88 | "example_instruction": " The answer is ", 89 | "prompting_instruction": " The answer is ", 90 | }, 91 | 1: { 92 | "example_instruction": " The answer is ", 93 | "prompting_instruction": " The answer is ", 94 | }, 95 | 2: { 96 | "example_instruction": " The answer is ", 97 | "prompting_instruction": " The answer is ", 98 | }, 99 | 3: { 100 | "example_instruction": " The answer is ", 101 | "prompting_instruction": " The answer is ", 102 | }, 103 | 4: { 104 | "example_instruction": " The answer is ", 105 | "prompting_instruction": " The answer is ", 106 | } 107 | } 108 | elif task_name == "copa" or task_name == "piqa": 109 | return { 110 | 0: { 111 | "example_instruction": " ", 112 | "prompting_instruction": " ", 113 | }, 114 | 1: { 115 | "example_instruction": " ", 116 | "prompting_instruction": " ", 117 | } 118 | } 119 | elif task_name == "sst5": 120 | return { 121 | 0: { 122 | "example_instruction": "\"\" It is terrible.", 123 | "prompting_instruction": "\"\" It is terrible.", 124 | }, 125 | 1: { 126 | "example_instruction": "\"\" It is bad.", 127 | "prompting_instruction": "\"\" It is bad.", 128 | }, 129 | 2: { 130 | "example_instruction": "\"\" It is okey.", 131 | "prompting_instruction": "\"\" It is okey.", 132 | }, 133 | 3: { 134 | "example_instruction": "\"\" It is good.", 135 | "prompting_instruction": "\"\" It is good.", 136 | }, 137 | 4: { 138 | "example_instruction": "\"\" It is great.", 139 | "prompting_instruction": "\"\" It is great.", 140 | } 141 | } 142 | elif task_name == "ag_news": 143 | return { 144 | 0: { 145 | "example_instruction": "\"\" It is about world.", 146 | "prompting_instruction": "\"\" It is about world.", 147 | }, 148 | 1: { 149 | "example_instruction": "\"\" It is about sports.", 150 | "prompting_instruction": "\"\" It is about sports.", 151 | }, 152 | 2: { 153 | "example_instruction": "\"\" It is about business.", 154 | "prompting_instruction": "\"\" It is about business.", 155 | }, 156 | 3: { 157 | "example_instruction": "\"\" It is about science and technology.", 158 | "prompting_instruction": "\"\" It is about science and technology.", 159 | } 160 | } 161 | elif task_name == "trec": 162 | return { 163 | 0: { 164 | "example_instruction": "\"\" It is about abbreviation.", 165 | "prompting_instruction": "\"\" It is about abbreviation.", 166 | }, 167 | 1: { 168 | "example_instruction": "\"\" It is about entity.", 169 | "prompting_instruction": "\"\" It is about entity.", 170 | }, 171 | 2: { 172 | "example_instruction": "\"\" It is about description and abstract concept.", 173 | "prompting_instruction": "\"\" It is about description and abstract concept.", 174 | }, 175 | 3: { 176 | "example_instruction": "\"\" It is about human being.", 177 | "prompting_instruction": "\"\" It is about human being.", 178 | }, 179 | 4: { 180 | "example_instruction": "\"\" It is about location.", 181 | "prompting_instruction": "\"\" It is about 
location.", 182 | }, 183 | 5: { 184 | "example_instruction": "\"\" It is about numeric value.", 185 | "prompting_instruction": "\"\" It is about numeric value.", 186 | } 187 | } 188 | elif task_name == "yelp_polarity": 189 | return { 190 | 0: { 191 | "example_instruction": "Negative Restaurant Review: \"\"", 192 | "prompting_instruction": "Negative Restaurant Review: \"\"", 193 | }, 194 | 1: { 195 | "example_instruction": "Positive Restaurant Review: \"\"", 196 | "prompting_instruction": "Positive Restaurant Review: \"\"", 197 | } 198 | } 199 | elif task_name == "mtop": 200 | return { 201 | 0: { 202 | "example_instruction": "\t", 203 | "prompting_instruction": "\t", 204 | } 205 | } 206 | elif task_name == "squad" or task_name == "web_questions": 207 | return { 208 | 0: { 209 | "example_instruction": "\t\t", 210 | "prompting_instruction": "\t\t" 211 | } 212 | } 213 | elif task_name == 'gsm8k': 214 | return { 215 | 0: { 216 | "example_instruction": "Solve the follow math problem: \n Answer: ", 217 | "prompting_instruction": "Solve the follow math problem: \n Answer: " 218 | } 219 | } 220 | elif task_name == 'iwslt17zh': 221 | return { 222 | 0: { 223 | "example_instruction": "What is the Chinese translation of : ", 224 | "prompting_instruction": "What is the Chinese translation of : " 225 | } 226 | } 227 | elif task_name == 'iwslt17de': 228 | return { 229 | 0: { 230 | "example_instruction": "What is the German translation of : ", 231 | "prompting_instruction": "What is the German translation of : " 232 | } 233 | } 234 | elif task_name == "mnli" or task_name == "snli": 235 | return { 236 | 0: { # entailment 237 | "example_instruction": "? Yes, ", 238 | "prompting_instruction": "? Yes, ", 239 | }, 240 | 1: { # neutral 241 | "example_instruction": "? Maybe, ", 242 | "prompting_instruction": "? Maybe, ", 243 | }, 244 | 2: { # contradiction 245 | "example_instruction": "? No, ", 246 | "prompting_instruction": "? No, ", 247 | } 248 | } 249 | elif task_name == 'qnli': 250 | return { 251 | 0: { # entailment 252 | "example_instruction": " Can we know ? Yes.", 253 | "prompting_instruction": " Can we know ? Yes.", 254 | }, 255 | 1: { # contradiction 256 | "example_instruction": " Can we know ? No.", 257 | "prompting_instruction": " Can we know ? No.", 258 | } 259 | } 260 | elif task_name == 'rte': 261 | return { 262 | 0: { # entailment 263 | "example_instruction": "? Yes, ", 264 | "prompting_instruction": "? Yes, ", 265 | }, 266 | 1: { # contradiction 267 | "example_instruction": "? No, ", 268 | "prompting_instruction": "? No, ", 269 | } 270 | } 271 | elif task_name == 'mrpc': 272 | return { 273 | 0: { # entailment 274 | "example_instruction": " Yes , ", 275 | "prompting_instruction": " Yes , ", 276 | }, 277 | 1: { # contradiction 278 | "example_instruction": " No , ", 279 | "prompting_instruction": " No , ", 280 | } 281 | } 282 | elif task_name == 'xsum': 283 | return { 284 | 0: { 285 | "example_instruction": "Document: Summary: ", 286 | "prompting_instruction": "Document: Summary: " 287 | } 288 | } 289 | 290 | elif t == 2: 291 | if task_name == "sst2" or task_name == "imdb": 292 | return { 293 | 0: { 294 | "example_instruction": "\"\" It is terrible.", 295 | "prompting_instruction": "\"\" It is terrible.", 296 | }, 297 | 1: { 298 | "example_instruction": "\"\" It is great.", 299 | "prompting_instruction": "\"\" It is great.", 300 | } 301 | } 302 | elif task_name == "piqa": 303 | return { 304 | 0: { 305 | "example_instruction": "\nWhich is the correct ending? 
\n ", 306 | "prompting_instruction": "\nWhich is the correct ending? \n ", 307 | }, 308 | 1: { 309 | "example_instruction": "\nWhich is the correct ending? \n ", 310 | "prompting_instruction": "\nWhich is the correct ending? \n ", 311 | } 312 | } 313 | elif task_name == "squad": 314 | return { 315 | 0: { 316 | "example_instruction": "The context is: \"\"\nThe answer to the question \"\" is: \"", 317 | "prompting_instruction": "The context is: \"\"\nThe answer to the question \"\" is: \"" 318 | } 319 | } 320 | elif task_name == "trec": 321 | return { 322 | 0: { 323 | "example_instruction": "\"\" This is about abbreviation.", 324 | "prompting_instruction": "\"\" It is about abbreviation.", 325 | }, 326 | 1: { 327 | "example_instruction": "\"\" It is about entity.", 328 | "prompting_instruction": "\"\" It is about entity.", 329 | }, 330 | 2: { 331 | "example_instruction": "\"\" It is about description and abstract concept.", 332 | "prompting_instruction": "\"\" It is about description and abstract concept.", 333 | }, 334 | 3: { 335 | "example_instruction": "\"\" It is about human being.", 336 | "prompting_instruction": "\"\" It is about human being.", 337 | }, 338 | 4: { 339 | "example_instruction": "\"\" It is about location.", 340 | "prompting_instruction": "\"\" It is about location.", 341 | }, 342 | 5: { 343 | "example_instruction": "\"\" It is about numeric value.", 344 | "prompting_instruction": "\"\" It is about numeric value.", 345 | } 346 | } 347 | elif task_name == "mnli" or task_name == "snli": 348 | return { 349 | 0: { # entailment 350 | "example_instruction": "Input: \"\" implies \"\"\n Answer: true", 351 | "prompting_instruction": " Input: \"\" implies \"\"\n Answer: true", 352 | }, 353 | 1: { # neutral 354 | "example_instruction": "Input: \"\" implies \"\"\n Answer: inconclusive", 355 | "prompting_instruction": " Input: \"\" implies \"\"\n Answer: inconclusive", 356 | }, 357 | 2: { # contradiction 358 | "example_instruction": "Input: \"\" implies \"\"\n Answer: false", 359 | "prompting_instruction": " Input: \"\" implies \"\"\n Answer: false", 360 | } 361 | } 362 | elif task_name == 'rte' or task_name == 'mrpc': 363 | return { 364 | 0: { # entailment 365 | "example_instruction": "Can implies ? Yes", 366 | "prompting_instruction": "Can implies ? Yes", 367 | }, 368 | 1: { # contradiction 369 | "example_instruction": "Can implies ? No", 370 | "prompting_instruction": "Can implies ? 
No", 371 | } 372 | } 373 | elif task_name == 'iwslt17zh': 374 | return { 375 | 0: { 376 | "example_instruction": "English: \n Chinese translation: ", 377 | "prompting_instruction": "English: \n Chinese translation: " 378 | } 379 | } 380 | elif task_name == 'iwslt17de': 381 | return { 382 | 0: { 383 | "example_instruction": "English: \n German translation: ", 384 | "prompting_instruction": "English: \n German translation: " 385 | } 386 | } 387 | elif task_name == "sst5": 388 | return { 389 | 0: { 390 | "example_instruction": "Review: \nSentiment: terrible", 391 | "prompting_instruction": "Review: \nSentiment: terrible", 392 | }, 393 | 1: { 394 | "example_instruction": "Review: \nSentiment: bad", 395 | "prompting_instruction": "Review: \nSentiment: bad", 396 | }, 397 | 2: { 398 | "example_instruction": "Review: \nSentiment: okay", 399 | "prompting_instruction": "Review: \nSentiment: okay", 400 | }, 401 | 3: { 402 | "example_instruction": "Review: \nSentiment: good", 403 | "prompting_instruction": "Review: \nSentiment: good", 404 | }, 405 | 4: { 406 | "example_instruction": "Review: \nSentiment: great", 407 | "prompting_instruction": "Review: \nSentiment: great", 408 | } 409 | } 410 | elif task_name == "trec": 411 | return { 412 | 0: { 413 | "example_instruction": "Input: \n Topic: abbreviation.", 414 | "prompting_instruction": "Input: \n Topic: abbreviation.", 415 | }, 416 | 1: { 417 | "example_instruction": "Input: \n Topic: entity.", 418 | "prompting_instruction": "Input: \n Topic: entity.", 419 | }, 420 | 2: { 421 | "example_instruction": "Input: \n Topic: description and abstract concept.", 422 | "prompting_instruction": "Input: \n Topic: description and abstract concept.", 423 | }, 424 | 3: { 425 | "example_instruction": "Input: \n Topic: human being.", 426 | "prompting_instruction": "Input: \n Topic: human being.", 427 | }, 428 | 4: { 429 | "example_instruction": "Input: \n Topic: location.", 430 | "prompting_instruction": "Input: \n Topic: location.", 431 | }, 432 | 5: { 433 | "example_instruction": "Input: \n Topic: numeric value.", 434 | "prompting_instruction": "Input: \n Topic: numeric value.", 435 | } 436 | } 437 | elif task_name == "commonsense_qa": 438 | return { 439 | 0: { 440 | "example_instruction": "Answer the following question:\n \n Answer: ", 441 | "prompting_instruction": "Answer the following question:\n \n Answer: ", 442 | }, 443 | 1: { 444 | "example_instruction": "Answer the following question:\n \n Answer: ", 445 | "prompting_instruction": "Answer the following question:\n \n Answer: ", 446 | }, 447 | 2: { 448 | "example_instruction": "Answer the following question:\n \n Answer: ", 449 | "prompting_instruction": "Answer the following question:\n \n Answer: ", 450 | }, 451 | 3: { 452 | "example_instruction": "Answer the following question:\n \n Answer: ", 453 | "prompting_instruction": "Answer the following question:\n \n Answer: ", 454 | }, 455 | 4: { 456 | "example_instruction": "Answer the following question:\n \n Answer: ", 457 | "prompting_instruction": "Answer the following question:\n \n Answer: ", 458 | } 459 | } 460 | elif t == 3: 461 | if task_name == "squad": 462 | return { 463 | 0: { 464 | "example_instruction": "The answer to the question \"\" is: \"", 465 | "prompting_instruction": "The context is: \"\"\nThe answer to the question \"\" is: \"" 466 | } 467 | } 468 | 469 | return None 470 | 471 | 472 | 473 | def build_instruction(instruction, c=None, x=None, y_text=None, e=None, e_instruction=None, tokenizer=None, max_len=700, 474 | C_KEY='C', X_KEY='X', 
Y_KEY='Y', Y_TEXT_KEY='Y_TEXT', reverse=True, need_span_ids=False, 475 | prior=False, prior_no=1, choice=False, choice_sep=', '): 476 | """ 477 | Build one prompt string from a template, a query, and (optionally) in-context examples. 478 | Args: 479 | choice_sep: separator between the answer options of a multiple-choice task 480 | choice: whether the task is multiple-choice (the options are spliced into the template) 481 | prior_no: which content-free prior to use (1: repeated filler tokens, 2: "N/A", 3: "[MASK]") 482 | prior: replace c and x with an equal number of content-free mask tokens (for calibration) 483 | Y_TEXT_KEY: dict key of an example's label text 484 | need_span_ids: if True, append to the instruction the character span whose probability should be computed, formatted as ". 12 24"; parse it with rfind('.') followed by split(' '). 485 | The span normally covers everything except the examples, i.e. X and Y (referred to as X for short) 486 | reverse: present the kept examples in reversed order 487 | Y_KEY: dict key of an example's label id 488 | X_KEY: dict key of an example's input sentence 489 | C_KEY: dict key of an example's context / first sentence 490 | instruction: prompting_instruction 491 | c: sentence1 492 | x: sentence2 493 | y_text: label text 494 | e: example list [{'C': str,'X':str,'Y':int},...] 495 | e_instruction: {label1:example_instruction1,...} 496 | tokenizer: used only to measure lengths in tokens 497 | max_len: token budget for the whole prompt 498 | Returns: 499 | the prompt string and, when e is not None, the list of examples that were kept 500 | """ 501 | output = instruction 502 | 503 | if choice: 504 | choices = c.replace(' or ', choice_sep).replace('?', '').split(choice_sep) 505 | for i in range(len(choices)): 506 | output = output.replace(f'', choices[i]) 507 | 508 | if prior: 509 | if prior_no == 1: 510 | if c is not None: 511 | c_len = len(tokenizer.tokenize(c)) 512 | c = ' '.join(['x' for i in range(c_len)]) 513 | 514 | if x is not None: 515 | x_len = len(tokenizer.tokenize(x)) 516 | x = ' '.join(['x' for i in range(x_len)]) 517 | 518 | if y_text is not None: # not used 519 | y_len = len(tokenizer.tokenize(y_text)) 520 | y_text = ' '.join(['x' for i in range(y_len)]) 521 | elif prior_no == 2: 522 | c = 'N/A' 523 | x = 'N/A' 524 | y_text = 'N/A' 525 | elif prior_no == 3: 526 | c = '[MASK]' 527 | x = '[MASK]' 528 | y_text = '[MASK]' 529 | 530 | if c is not None: 531 | output = output.replace(PLACEHOLDER_C, c) 532 | 533 | if x is not None: 534 | output = output.replace(PLACEHOLDER_X, x) 535 | 536 | if y_text is not None: # not used 537 | output = output.replace(PLACEHOLDER_Y, y_text) 538 | 539 | if e is not None: 540 | 541 | cur_len = len(tokenizer.tokenize(output)) 542 | if cur_len > max_len: 543 | logging.info(f'x is too long ({cur_len} tokens); truncating') 544 | t = ' '.join(tokenizer.tokenize(output)[-cur_len + 30:]) 545 | return f'{t}. 0 1', [] 546 | 547 | 548 | keep_exs = [] 549 | keep_ex_strs = [] 550 | if len(e) == 0: 551 | output = output.replace(PLACEHOLDER_EXAMPLE, "") 552 | if need_span_ids: 553 | span_begin = 0 554 | span_end = len(output) 555 | output += f'. {span_begin} {span_end}' 556 | 557 | else: 558 | total_len = len(tokenizer.tokenize(output)) 559 | for _ex in e: 560 | if torch.is_tensor(_ex[Y_KEY]): 561 | _ex_str = build_instruction(instruction=e_instruction[_ex[Y_KEY].item()], 562 | c=_ex[C_KEY] if C_KEY in _ex else None, 563 | x=_ex[X_KEY], 564 | y_text=_ex[Y_TEXT_KEY] if Y_TEXT_KEY in _ex else None) 565 | else: 566 | _ex_str = build_instruction(instruction=e_instruction[_ex[Y_KEY]], 567 | c=_ex[C_KEY] if C_KEY in _ex else None, 568 | x=_ex[X_KEY], 569 | y_text=_ex[Y_TEXT_KEY] if Y_TEXT_KEY in _ex else None) 570 | _len = len(tokenizer.tokenize(_ex_str)) 571 | if _len + total_len <= max_len: 572 | keep_exs.append(_ex) 573 | keep_ex_strs.append(_ex_str) 574 | total_len += _len 575 | else: 576 | break 577 | if reverse: 578 | keep_ex_strs.reverse() 579 | 580 | if need_span_ids: # record the span (begin/end positions) before the examples are spliced in, e.g. ". 12 24" 581 | span_begin = output.rfind(PLACEHOLDER_EXAMPLE) 582 | span_end = len(output) 583 | if span_begin == -1: 584 | span_begin = 0 585 | span_len = span_end - span_begin 586 | 587 | output = output.replace(PLACEHOLDER_EXAMPLE, '\n\n'.join(keep_ex_strs) + '\n\n') 588 | 589 | if need_span_ids: 590 | span_end = len(output) 591 | span_begin = span_end - span_len 592 | output += f'. {span_begin} {span_end}' # half-open interval: [span_begin, span_end)
593 | 594 | return output, keep_exs # return the in-context examples, possibly truncated to fit the token budget 595 | 596 | # strip any placeholders that were never filled 597 | output = output.replace(PLACEHOLDER_C, "").replace(PLACEHOLDER_X, ""). \ 598 | replace(PLACEHOLDER_Y, "").replace(PLACEHOLDER_EXAMPLE, "") 599 | return output 600 | --------------------------------------------------------------------------------
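
Usage sketch (illustrative; not part of the repository): the snippet below shows how get_template and build_instruction fit together for a single sst2 query, mirroring the calls made in Retriever.instance_level_lm_score. Note that the tag-style placeholder tokens referenced by the PLACEHOLDER_* constants (and originally embedded in the template strings) do not survive in this listing; the tokenizer choice, example pool, and query text below are assumptions made for the example.

from transformers import AutoTokenizer
from src.datasets.instructions import get_template, build_instruction

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed; any HF tokenizer works, it is only used to count tokens
instructions = get_template("sst2", t=1)
labels = list(instructions.keys())  # [0, 1]
example_instruction = {y: instructions[y]["example_instruction"] for y in labels}

# a tiny illustrative example pool; real pools come from the prerank stage
pool = [{"X": "a moving and insightful film", "Y": 1},
        {"X": "a tedious, overlong mess", "Y": 0}]

# build the prompt that scores the query under label 1 ("positive"); `kept` is the
# subset of examples that fit within the max_len token budget
prompt, kept = build_instruction(instruction=instructions[1]["prompting_instruction"],
                                 x="an utterly charming and original story",
                                 e=pool,
                                 e_instruction=example_instruction,
                                 tokenizer=tokenizer,
                                 max_len=700)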