├── .DS_Store ├── README.md ├── conformal_risk_computation ├── conformal_distribution_shift_generation_risk.py ├── conformal_distribution_shift_generation_risk_comparison.py ├── conformal_generation_risk.py ├── conformal_generation_risk_comparison.py ├── conformal_generation_risk_multi_dim_config.py ├── conformal_generation_risk_num_gen.py ├── conformal_generation_risk_similarity_threshold.py ├── conformal_generation_risk_valid_config.py └── utils.py ├── ds_config.json ├── figs └── README.md ├── requirements.txt ├── run └── conformal_parallel.sh ├── scripts_conformal_generation_risk ├── run_conformal_distribution_shift_generation_risk.sh ├── run_conformal_distribution_shift_generation_risk_comparisons.sh ├── run_conformal_generation_risk.sh ├── run_conformal_generation_risk_comparisons.sh ├── run_conformal_generation_risk_multi_dim_config.sh ├── run_conformal_generation_risk_num_gen.sh ├── run_conformal_generation_risk_similarity_threshold.sh └── run_conformal_generation_risk_valid_config.sh ├── scripts_raw_risk_scores ├── baai.sh ├── bm25.sh ├── llm-r.sh └── openai.sh └── src ├── .DS_Store ├── collators ├── __init__.py ├── biencoder_collator.py ├── cross_encoder_collator.py └── gpt2_collator.py ├── compute_controlled_risk.py ├── config.py ├── conformal_calibration_empirical_risk.py ├── conformal_calibration_guarantee.py ├── conformal_simulation_risks.py ├── data_utils.py ├── evaluation ├── __init__.py ├── baai_eval.py ├── base_eval.py ├── bm25_eval.py ├── dense_eval.py ├── metrics.py ├── openai_eval.py ├── qa_utils.py └── random_eval.py ├── inference ├── __init__.py ├── gen_llm_scores.py ├── gen_reward_scores.py ├── generate_few_shot_prompt.py ├── inference_utils.py └── search_topk.py ├── llm_calibrator.py ├── llm_evaluator.py ├── llm_simulator.py ├── llms ├── __init__.py ├── base_llm.py ├── gpt2.py ├── gpt_neo.py └── llama.py ├── loaders ├── __init__.py ├── biencoder_dataloader.py ├── cross_encoder_dataloader.py └── loader_utils.py ├── logger_config.py ├── main_eval.py ├── model_utils.py ├── models ├── __init__.py ├── biencoder_model.py ├── cross_encoder_model.py ├── simple_encoder.py └── simple_retriever.py ├── tasks ├── __init__.py ├── aeslc.py ├── agnews.py ├── arc.py ├── base_task.py ├── boolq.py ├── common_gen.py ├── copa.py ├── dart.py ├── e2e_nlg.py ├── gigaword.py ├── hellaswag.py ├── mnli.py ├── mrpc.py ├── multirc.py ├── nq.py ├── openbookqa.py ├── paws.py ├── piqa.py ├── qnli.py ├── qqp.py ├── rte.py ├── sentiment140.py ├── snli.py ├── squad_v1.py ├── sst2.py ├── winogrande.py ├── wsc.py ├── wsc273.py └── yelp.py ├── train_biencoder.py ├── train_cross_encoder.py ├── trainers ├── __init__.py ├── biencoder_trainer.py └── reward_trainer.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kangmintong/C-RAG/5bda4382bebf6453b3b8dccb8a52f4d015e95fee/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # C-RAG: Certified Generation Risks for Retrieval-Augmented Language Models [ICML 2024] 2 | 3 | We provide the implementation of [C-RAG](https://arxiv.org/abs/2402.03181) in this repository. 4 | 5 | C-RAG is the first framework to certify generation risks for RAG models. Specifically, C-RAG provides conformal risk analysis for RAG models and certifies an upper confidence bound on generation risks, which is referred to as the conformal generation risk. 
6 | C-RAG also provides theoretical guarantees on conformal generation risks for general bounded risk functions under test distribution shifts. 7 | C-RAG proves that RAG achieves a lower conformal generation risk than that of a single LLM when the quality of the retrieval model and transformer is non-trivial. 8 | Extensive empirical results demonstrate the soundness and tightness of the conformal generation risk guarantees across four widely used NLP datasets and four state-of-the-art retrieval models. 9 | 10 | ## Environment 11 | 12 | Install PyTorch for your environment and CUDA version by following the [PyTorch installation guide](https://pytorch.org/get-started/locally/). 13 | 14 | Run ``pip install -r requirements.txt`` to install the other necessary packages for the repo. 15 | 16 | ## Pretrained models 17 | For the supervised fine-tuned biencoder-based retrieval model, we follow the implementation in [LLM-R](https://arxiv.org/abs/2307.07164) and provide the model checkpoint at [trained_retrieval_models](https://drive.google.com/file/d/1xOeCz3vt2piHuY00a5q4YCNhDkyCs0VF/view?usp=sharing). 18 | 19 | Alternatively, download it with: 20 | ``` 21 | gdown https://drive.google.com/uc?id=1xOeCz3vt2piHuY00a5q4YCNhDkyCs0VF 22 | ``` 23 | 24 | Then, put the folder ``trained_retrieval_models/`` under ``C-RAG/``. 25 | 26 | ## Dataset preparation 27 | We evaluate C-RAG on four widely used NLP datasets: AESLC, CommonGen, DART, and E2E. We preprocess the data and provide it at [data](https://drive.google.com/file/d/1JJC192wdOmXYZy_hXcGVrXOtMK2LWsv7/view?usp=sharing). 28 | 29 | Alternatively, download it with: 30 | ``` 31 | gdown https://drive.google.com/uc?id=1JJC192wdOmXYZy_hXcGVrXOtMK2LWsv7 32 | ``` 33 | 34 | Then, put the folder ``data/`` under ``C-RAG/``. 35 | 36 | ## Evaluate conformal generation risks in C-RAG 37 | 38 | To compute the conformal generation risk, we need to (1) evaluate the raw risk scores for calibration instances following our constrained generation protocol, and (2) compute the conformal generation risks based on the empirical risk statistics. 39 | 40 | ### (1) Evaluate raw risk scores for calibration instances 41 | 42 | #### We package the process into four scripts, one per retrieval model 43 | 44 | Evaluate raw risk scores for the BM25 retrieval model: 45 | ``` 46 | sh scripts_raw_risk_scores/bm25.sh 47 | ``` 48 | 49 | Evaluate raw risk scores for the BAAI/bge retrieval model: 50 | ``` 51 | sh scripts_raw_risk_scores/baai.sh 52 | ``` 53 | 54 | Evaluate raw risk scores for the OpenAI/text-embedding-ada-002 retrieval model: 55 | ``` 56 | sh scripts_raw_risk_scores/openai.sh 57 | ``` 58 | 59 | Evaluate raw risk scores for the LLM-R fine-tuned biencoder-based retrieval model: 60 | ``` 61 | sh scripts_raw_risk_scores/llm-r.sh 62 | ``` 63 | 64 | #### Explanation: each script above runs the following two steps: 65 | 66 | 1. Prepare the prompts via ``src/inference/generate_few_shot_prompt.py``:
Retrieve relevant examples and store the prompts at ``outputs/{METHOD}/{METHOD}_test_k{N_RAG}.jsonl.gz`` 67 | 2. Evaluate the risks of the prompts on the calibration set via ``src/conformal_calibration_empirical_risk.py``:
Evaluate the prompts and store the results in ``outputs/{METHOD}/{LLM}_{METHOD}/`` 68 | 69 | 70 | ### (2) Compute conformal generation risks 71 | 72 | The conformal generation risk computation is based on the empirical risk statistics stored at ``outputs/{METHOD}/{LLM}_{METHOD}/`` in step (1). A minimal numerical sketch of this computation is included at the end of this README. 73 | 74 | #### Conformal generation risk without distribution shifts 75 | 1. Compute the conformal generation risks of a single retrieval model and compare them with the simulation results: 76 | ``` 77 | sh scripts_conformal_generation_risk/run_conformal_generation_risk.sh 78 | ``` 79 | 2. Compare the conformal generation risks of different retrieval models (after running step 1 for the corresponding methods): 80 | ``` 81 | sh scripts_conformal_generation_risk/run_conformal_generation_risk_comparisons.sh 82 | ``` 83 | 84 | #### Conformal generation risk with distribution shifts 85 | 1. Compute the conformal generation risks of a single retrieval model and compare them with the simulation results: 86 | ``` 87 | sh scripts_conformal_generation_risk/run_conformal_distribution_shift_generation_risk.sh 88 | ``` 89 | 2. Compare the conformal generation risks of different retrieval models (after running step 1 for the corresponding methods): 90 | ``` 91 | sh scripts_conformal_generation_risk/run_conformal_distribution_shift_generation_risk_comparisons.sh 92 | ``` 93 | 94 | #### Conformal generation risk with multi-dimensional RAG configurations 95 | ``` 96 | sh scripts_conformal_generation_risk/run_conformal_generation_risk_multi_dim_config.sh 97 | ``` 98 | 99 | #### Valid configurations given desired risk levels 100 | ``` 101 | sh scripts_conformal_generation_risk/run_conformal_generation_risk_valid_config.sh 102 | ``` 103 | 104 | #### Additional evaluations with varying RAG configurations 105 | 106 | Conformal generation risks with varying generation set sizes: 107 | ``` 108 | sh scripts_conformal_generation_risk/run_conformal_generation_risk_num_gen.sh 109 | ``` 110 | 111 | Conformal generation risks with varying generation similarity thresholds: 112 | ``` 113 | sh scripts_conformal_generation_risk/run_conformal_generation_risk_similarity_threshold.sh 114 | ``` 115 | 116 | 117 | ## Acknowledgement 118 | 119 | The inference part of the repo is built on the [LLM-R repo](https://github.com/microsoft/LMOps/tree/main/llm_retriever). 120 | 121 | For any related questions or discussion, please contact ``mintong2@illinois.edu``. 
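
As a quick sanity check outside of the plotting scripts, the bound computation of step (2) can be reproduced in a few lines. The snippet below is only a minimal sketch: the calibration file path is an illustrative example of the ``{dataset}_{N_rag}_{num_gen}_{gen_thresh}_calibration_result.json`` naming pattern, and the import assumes you run it from the ``C-RAG/`` root after completing step (1).

```
import json
import numpy as np

# helper defined in conformal_risk_computation/utils.py
from conformal_risk_computation.utils import empirical_risk2alpha_func

# raw quality scores for one calibration setting produced in step (1)
# (illustrative path: BM25 retrieval, aeslc, N_rag=15, num_gen=1, gen_thresh=0)
scores = json.load(open('outputs/bm25/llama-7b_BM25/aeslc_15_1_0_calibration_result.json', 'r', encoding='utf-8'))
scores = np.array(list(scores.values())[0][0])

r_hat = 1. - np.mean(scores)  # empirical risk on the calibration set
alpha = empirical_risk2alpha_func(r_hat, N_cal=len(scores), delta=0.1)  # conformal generation risk at confidence 1 - delta
print(f'empirical risk {r_hat:.3f} -> conformal generation risk {alpha:.3f}')
```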
122 | 123 | If you find our paper and repo useful for your research, please consider citing: 124 | ``` 125 | @article{kang2024c, 126 | title={C-RAG: Certified Generation Risks for Retrieval-Augmented Language Models}, 127 | author={Kang, Mintong and G{\"u}rel, Nezihe Merve and Yu, Ning and Song, Dawn and Li, Bo}, 128 | journal={arXiv preprint arXiv:2402.03181}, 129 | year={2024} 130 | } 131 | ``` 132 | -------------------------------------------------------------------------------- /conformal_risk_computation/conformal_distribution_shift_generation_risk_comparison.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import numpy as np 6 | from matplotlib.ticker import MaxNLocator 7 | from tqdm import tqdm 8 | from scipy.stats import binom 9 | from typing import List, Union, Optional, Tuple, Mapping, Dict 10 | import os 11 | from numpy import linalg 12 | import numpy as np 13 | import argparse 14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, cal_dist_shift_bound, get_color_dict, get_max_x_all 15 | 16 | if __name__=='__main__': 17 | parser = argparse.ArgumentParser(description='Compare the conformal generation risks of multiple retrieval models under distribution shifts.') 18 | 19 | parser.add_argument('--llm_retrieval_methods', '--names-list', nargs='+') 20 | parser.add_argument('--datasets', nargs='+') 21 | parser.add_argument('--n_rag', type=int, default=15) 22 | parser.add_argument('--num_gen', type=int, default=1) 23 | parser.add_argument('--gen_thresh', type=int, default=0) 24 | 25 | args = parser.parse_args() 26 | 27 | method2name = get_method2name() 28 | method2method = get_method2method() 29 | 30 | datasets = args.datasets 31 | retrieval_methods = args.llm_retrieval_methods 32 | 33 | max_x_all = get_max_x_all() 34 | 35 | n_rag = args.n_rag 36 | num_gen = args.num_gen 37 | gen_thresh = args.gen_thresh 38 | 39 | for dataset in datasets: 40 | max_x = max_x_all[dataset] 41 | hellinger_distances = np.array(list(range(0, max_x, 1))) 42 | hellinger_distances = hellinger_distances / 100.0 43 | 44 | alphas = {} 45 | min_risk, max_risk = 10e5, -10e5 46 | 47 | for method in retrieval_methods: 48 | 49 | alpha_list = [] 50 | evaluate_results = json.load(open(f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{gen_thresh}_calibration_result.json','r', encoding='utf-8')) 51 | 52 | 53 | plot_dist_shift_results = {} 54 | 55 | 56 | results = list(evaluate_results.values())[0][0] 57 | results = np.array(results) 58 | r_hat = np.mean(results) 59 | r_hat = 1. 
- r_hat 60 | 61 | for dist in tqdm(hellinger_distances): 62 | 63 | temp_res = cal_dist_shift_bound(r_hat, dist, len(results), np.var(results), 0.1) 64 | plot_dist_shift_results[dist] = temp_res 65 | alpha_list.append(temp_res) 66 | 67 | if temp_res < min_risk: 68 | min_risk = temp_res 69 | if temp_res > max_risk: 70 | max_risk = temp_res 71 | 72 | alphas[method] = alpha_list 73 | 74 | data = { 75 | 'dist': hellinger_distances, 76 | } 77 | for method in retrieval_methods: 78 | data[method] = alphas[method] 79 | 80 | df = pd.DataFrame(data) 81 | 82 | plt.style.use('seaborn-v0_8-darkgrid') 83 | plt.figure(figsize=(10, 8)) 84 | 85 | color_dict = get_color_dict(retrieval_methods) 86 | 87 | for metric, color in color_dict.items(): 88 | plt.plot(df['dist'], df[metric], marker='*', color=color, label=method2name[metric], markersize=20) 89 | 90 | fontsize = 32 91 | plt.ylim([min_risk - 0.07, min(max_risk + 0.02, 1.02)]) 92 | plt.xticks(fontsize=fontsize) 93 | plt.yticks(fontsize=fontsize) 94 | 95 | ax = plt.gca() 96 | ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 97 | ax.xaxis.set_major_formatter('{x:3<0.2f}') 98 | ax.yaxis.set_major_formatter('{x:3<0.2f}') 99 | 100 | # Show the plot 101 | plt.tight_layout() 102 | plt.savefig(f'./figs/{dataset}_distribution_shift_conformal_risk.jpg', dpi=800) -------------------------------------------------------------------------------- /conformal_risk_computation/conformal_generation_risk.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import numpy as np 6 | from matplotlib.ticker import MaxNLocator 7 | from tqdm import tqdm 8 | from scipy.stats import binom 9 | from typing import List, Union, Optional, Tuple, Mapping, Dict 10 | import os 11 | from numpy import linalg 12 | import numpy as np 13 | import argparse 14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func 15 | 16 | if __name__=='__main__': 17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.') 18 | 19 | parser.add_argument('--llm_retrievalmodel', type=str, choices=['llama-7b_BM25', 'llama-7b_BAAI', 'llama-7b_OpenAI', 'llama-7b_triever-base']) 20 | parser.add_argument('--datasets', nargs='+') 21 | parser.add_argument('--num_gen', type=int, default=1) 22 | parser.add_argument('--gen_thresh', type=int, default=0) 23 | parser.add_argument('--n_rag_list', nargs='+') 24 | 25 | args = parser.parse_args() 26 | 27 | method2name = get_method2name() 28 | method2method = get_method2method() 29 | 30 | datasets = args.datasets 31 | method = args.llm_retrievalmodel 32 | 33 | num_point_dict, sample_size_dict = get_num_points() 34 | 35 | 36 | for dataset in datasets: 37 | num_simulation_point = num_point_dict[dataset] 38 | sample_size_per_point = sample_size_dict[dataset] 39 | 40 | num_gen = args.num_gen 41 | gen_thresh = args.gen_thresh 42 | n_rag_list = args.n_rag_list 43 | 44 | alphas = {} 45 | min_risk, max_risk = 10e5, -10e5 46 | alpha_list = [] 47 | 48 | for n_rag in n_rag_list: 49 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{gen_thresh}_calibration_result.json' 50 | results = json.load(open(path, 'r', encoding='utf-8')) 51 | results = list(results.values())[0][0] 52 | results = np.array(results) 53 | alpha_list.append(empirical_risk2alpha_func(1.-np.mean(results), N_cal=len(results), delta=0.1)) 54 | if 
alpha_list[-1] < min_risk: 55 | min_risk = alpha_list[-1] 56 | if alpha_list[-1] > max_risk: 57 | max_risk = alpha_list[-1] 58 | alphas[method] = alpha_list 59 | 60 | data = { 61 | 'N_rag': n_rag_list, 62 | method: alphas[method], 63 | } 64 | 65 | df = pd.DataFrame(data) 66 | 67 | # Set the style 68 | plt.style.use('seaborn-v0_8-darkgrid') 69 | plt.figure(figsize=(10,8)) 70 | 71 | color_dict = {method: 'black'} 72 | 73 | for metric, color in color_dict.items(): 74 | plt.plot(df['N_rag'], df[metric], marker='^', color=color, label=r'Certified Conformal Risk $\alpha_{\text{rag}}$', markersize=20) 75 | 76 | 77 | simulation_points = [] 78 | for n_rag in n_rag_list: 79 | results = json.load(open(f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{gen_thresh}_calibration_result.json', 'r', encoding='utf-8')) 80 | results = list(results.values())[0][0] 81 | results = np.array(results) 82 | 83 | simulation_points_temp = [] 84 | for idx_point in range(num_simulation_point): 85 | test_list = random.sample(list(range(len(results))),sample_size_per_point) 86 | new_res = 1. - np.mean(results[test_list]) 87 | simulation_points_temp.append(new_res) 88 | simulation_points_temp = np.array(simulation_points_temp) 89 | for x in simulation_points_temp: 90 | if x < min_risk: 91 | min_risk = x 92 | plt.scatter(x=[n_rag]*len(simulation_points_temp),y=simulation_points_temp,color='gray',alpha=0.7,s=100) 93 | 94 | fontsize=45 95 | 96 | plt.ylim([min_risk-0.001, max_risk+0.02]) 97 | plt.xticks(fontsize=fontsize) 98 | plt.yticks(fontsize=fontsize) 99 | 100 | ax = plt.gca() 101 | ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 102 | ax.yaxis.set_major_formatter('{x:3<0.2f}') 103 | 104 | plt.tight_layout() 105 | print('save figure at {}'.format({f'./figs/{dataset}_{method2name[method]}_conformal_generation_risk.jpg'})) 106 | plt.savefig(f'./figs/{dataset}_{method2name[method]}_conformal_generation_risk.jpg',dpi=800) -------------------------------------------------------------------------------- /conformal_risk_computation/conformal_generation_risk_comparison.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import numpy as np 6 | from matplotlib.ticker import MaxNLocator 7 | from tqdm import tqdm 8 | from scipy.stats import binom 9 | from typing import List, Union, Optional, Tuple, Mapping, Dict 10 | import os 11 | from numpy import linalg 12 | import numpy as np 13 | import argparse 14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func, get_color_dict 15 | 16 | if __name__=='__main__': 17 | parser = argparse.ArgumentParser(description='Compare the conformal generation risks of multiple retrieval models.') 18 | 19 | parser.add_argument('--llm_retrieval_methods', nargs='+') 20 | parser.add_argument('--datasets', nargs='+') 21 | parser.add_argument('--num_gen', type=int, default=1) 22 | parser.add_argument('--gen_thresh', type=int, default=0) 23 | parser.add_argument('--n_rag_list', nargs='+') 24 | 25 | args = parser.parse_args() 26 | 27 | method2name = get_method2name() 28 | method2method = get_method2method() 29 | 30 | datasets = args.datasets 31 | retrieval_methods = args.llm_retrieval_methods 32 | 33 | num_point_dict, sample_size_dict = get_num_points() 34 | 35 | for dataset in datasets: 36 | num_gen = args.num_gen 37 | gen_thresh = args.gen_thresh 38 | n_rag_list = args.n_rag_list 39 | n_rag_list = [int(item) for item in n_rag_list] 40 | 41 | alphas = {} 42 |
min_risk, max_risk = 10e5, -10e5 43 | 44 | for method in retrieval_methods: 45 | alpha_list = [] 46 | for n_rag in n_rag_list: 47 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{gen_thresh}_calibration_result.json' 48 | results = json.load(open(path, 'r', encoding='utf-8')) 49 | results = list(results.values())[0][0] 50 | results = np.array(results) 51 | 52 | alpha_list.append(empirical_risk2alpha_func(1. - np.mean(results), N_cal=len(results), delta=0.1)) 53 | if alpha_list[-1] < min_risk: 54 | min_risk = alpha_list[-1] 55 | if alpha_list[-1] > max_risk: 56 | max_risk = alpha_list[-1] 57 | alphas[method] = alpha_list 58 | 59 | 60 | data = { 61 | 'N_rag': n_rag_list, 62 | } 63 | 64 | for idx in range(len(retrieval_methods)): 65 | data[retrieval_methods[idx]] = alphas[retrieval_methods[idx]] 66 | 67 | 68 | df = pd.DataFrame(data) 69 | 70 | # Set the style 71 | plt.style.use('seaborn-v0_8-darkgrid') 72 | plt.figure(figsize=(10,8)) 73 | 74 | color_dict = get_color_dict(retrieval_methods) 75 | 76 | for metric, color in color_dict.items(): 77 | plt.plot(df['N_rag'], df[metric], marker='*', color=color, label=method2name[metric], markersize=25) 78 | 79 | fontsize=45 80 | plt.ylim([min_risk-0.02, max_risk+0.02]) 81 | plt.xticks(fontsize=fontsize) 82 | plt.yticks(fontsize=fontsize) 83 | 84 | ax = plt.gca() 85 | ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 86 | 87 | plt.tight_layout() 88 | plt.savefig(f'./figs/{dataset}_conformal_risk.jpg',dpi=800) -------------------------------------------------------------------------------- /conformal_risk_computation/conformal_generation_risk_multi_dim_config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import numpy as np 6 | from matplotlib.ticker import MaxNLocator 7 | from tqdm import tqdm 8 | from scipy.stats import binom 9 | from typing import List, Union, Optional, Tuple, Mapping, Dict 10 | import os 11 | from numpy import linalg 12 | import numpy as np 13 | import argparse 14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func 15 | 16 | if __name__=='__main__': 17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model under multi-dimensional RAG configurations.') 18 | 19 | parser.add_argument('--llm_retrievalmodel', type=str, choices=['llama-7b_BM25', 'llama-7b_BAAI', 'llama-7b_OpenAI', 'llama-7b_triever-base']) 20 | parser.add_argument('--datasets', nargs='+') 21 | parser.add_argument('--gen_thresh', type=int, default=0) 22 | parser.add_argument('--n_rag_list', nargs='+') 23 | parser.add_argument('--lambda_g_list', nargs='+') 24 | 25 | args = parser.parse_args() 26 | 27 | method2name = get_method2name() 28 | method2method = get_method2method() 29 | 30 | datasets = args.datasets 31 | method = args.llm_retrievalmodel 32 | 33 | n_rag_list = np.array(args.n_rag_list) 34 | lambda_g_list = np.array(args.lambda_g_list) 35 | 36 | num_point_dict, sample_size_dict = get_num_points() 37 | 38 | for dataset in datasets: 39 | 40 | fig = plt.figure(figsize=(10, 8)) 41 | ax = plt.axes(projection='3d') 42 | 43 | xdata, ydata, zdata = [], [], [] 44 | c = [] 45 | random.seed(1) 46 | rand_list = [] 47 | 48 | num_list = [] 49 | for idx in range(len(lambda_g_list)): 50 | num_list.append(int(lambda_g_list[idx])) 51 | max_num_gen = max(num_list) 52 | 53 | 54 | for n_rag in n_rag_list: 55 | for num_gen in lambda_g_list: 56 | num_gen = int(num_gen) 57 | path = 
f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{max_num_gen}_{args.gen_thresh}_calibration_result.json' 58 | 59 | result = json.load(open(path, 'r', encoding='utf-8')) 60 | result = list(result.values())[0] 61 | result = np.array(result) 62 | rand_indx = random.sample(tuple(list(range(0, max_num_gen))), num_gen) 63 | 64 | result = result[rand_indx, :] 65 | rand_list.append(rand_indx) 66 | result = result.max(axis=0) 67 | results = 1. - result 68 | 69 | alpha_rag = empirical_risk2alpha_func(np.mean(results), N_cal=len(results), delta=0.1) 70 | 71 | xdata.append(n_rag) 72 | ydata.append(num_gen) 73 | zdata.append(alpha_rag) 74 | c.append('red') 75 | 76 | simulation_points_temp = [] 77 | num_simulation_point = num_point_dict[dataset] 78 | sample_size_per_point = sample_size_dict[dataset] 79 | for idx_point in range(num_simulation_point): 80 | test_list = random.sample(list(range(len(results))), sample_size_per_point) 81 | new_res = np.mean(results[test_list]) 82 | simulation_points_temp.append(new_res) 83 | simulation_points_temp = np.array(simulation_points_temp) 84 | 85 | ax.scatter3D([n_rag] * len(simulation_points_temp), [num_gen] * len(simulation_points_temp), simulation_points_temp, c='gray', alpha=0.3, s=25) 86 | 87 | ax.scatter3D(xdata, ydata, zdata, c=c, marker='^', s=80) 88 | 89 | labelsize = 20 90 | fontsize = 30 91 | ax.tick_params(axis='x', labelsize=labelsize,pad=-3) 92 | ax.tick_params(axis='y', labelsize=labelsize,pad=-3) 93 | ax.tick_params(axis='z', labelsize=labelsize,pad=8) 94 | ax.set_xlabel(r'$N_{rag}$', fontsize=fontsize, labelpad=10.0) 95 | ax.set_ylabel(r'$\lambda_g$', fontsize=fontsize, labelpad=10.0) 96 | ax.set_zlabel(r'Risk', fontsize=fontsize, labelpad=25.0, rotation=90) 97 | 98 | plt.tight_layout() 99 | plt.savefig(f'./figs/{dataset}_{method2method[method]}_nrag_lambdag_conformal_risk.jpg',dpi=800) -------------------------------------------------------------------------------- /conformal_risk_computation/conformal_generation_risk_num_gen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import numpy as np 6 | from matplotlib.ticker import MaxNLocator 7 | from tqdm import tqdm 8 | from scipy.stats import binom 9 | from typing import List, Union, Optional, Tuple, Mapping, Dict 10 | import os 11 | from numpy import linalg 12 | import numpy as np 13 | import argparse 14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func, compute_conformal_risk 15 | 16 | if __name__=='__main__': 17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.') 18 | 19 | parser.add_argument('--llm_retrievalmodel', type=str, choices=['llama-7b_BM25', 'llama-7b_BAAI', 'llama-7b_OpenAI', 'llama-7b_triever-base']) 20 | parser.add_argument('--datasets', nargs='+') 21 | parser.add_argument('--n_rag', type=int, default=5) 22 | parser.add_argument('--gen_thresh', type=int, default=0) 23 | parser.add_argument('--num_gen_list', nargs='+') 24 | 25 | args = parser.parse_args() 26 | 27 | method2name = get_method2name() 28 | method2method = get_method2method() 29 | 30 | datasets = args.datasets 31 | method = args.llm_retrievalmodel 32 | 33 | num_point_dict, sample_size_dict = get_num_points() 34 | 35 | num_list = [] 36 | for idx in range(len(args.num_gen_list)): 37 | num_list.append(int(args.num_gen_list[idx])) 38 | max_num_gen = 
max(num_list) 39 | 40 | random.seed(1) 41 | 42 | for dataset in datasets: 43 | num_simulation_point = num_point_dict[dataset] 44 | sample_size_per_point = sample_size_dict[dataset] 45 | 46 | num_gen_list = args.num_gen_list 47 | num_gen_list = [int(item) for item in num_gen_list] 48 | gen_thresh = args.gen_thresh 49 | n_rag = args.n_rag 50 | 51 | alphas = {} 52 | min_risk, max_risk = 10e5, -10e5 53 | rand_list = [] 54 | alpha_list = [] 55 | for num_gen in num_gen_list: 56 | num_gen = int(num_gen) 57 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{max_num_gen}_{gen_thresh}_calibration_result.json' 58 | result = json.load(open(path, 'r', encoding='utf-8')) 59 | result = list(result.values())[0] 60 | result = np.array(result) 61 | rand_indx = random.sample(tuple(list(range(0,max_num_gen))), num_gen) 62 | result = result[rand_indx,:] 63 | rand_list.append(rand_indx) 64 | result = result.max(axis=0) 65 | results = 1. - result 66 | 67 | alpha_list.append(compute_conformal_risk(results)) 68 | if alpha_list[-1] < min_risk: 69 | min_risk = alpha_list[-1] 70 | if alpha_list[-1] > max_risk: 71 | max_risk = alpha_list[-1] 72 | alphas[method] = alpha_list 73 | 74 | data = { 75 | 'num_gen': num_gen_list, 76 | method: alphas[method], 77 | } 78 | 79 | df = pd.DataFrame(data) 80 | 81 | plt.style.use('seaborn-v0_8-darkgrid') 82 | plt.figure(figsize=(10, 8)) 83 | 84 | color_dict = {method: 'black'} 85 | 86 | for metric, color in color_dict.items(): 87 | plt.plot(df['num_gen'], df[metric], marker='^', color=color, label=r'Certified Conformal Risk $\alpha_{\text{rag}}$', markersize=20) 88 | 89 | simulation_points = [] 90 | 91 | for num_gen in num_gen_list: 92 | num_gen = int(num_gen) 93 | 94 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{max_num_gen}_{gen_thresh}_calibration_result.json' 95 | 96 | result = json.load(open(path, 'r', encoding='utf-8')) 97 | result = list(result.values())[0] 98 | result = np.array(result) 99 | result = result[rand_list[num_gen-1], :] 100 | result = result.max(axis=0) 101 | results = 1. 
- result 102 | 103 | simulation_points_temp = [] 104 | for idx_point in range(num_simulation_point): 105 | test_list = random.sample(list(range(len(results))), sample_size_per_point) 106 | res = np.mean(results[test_list]) 107 | simulation_points_temp.append(res) 108 | simulation_points_temp = np.array(simulation_points_temp) 109 | for x in simulation_points_temp: 110 | if x < min_risk: 111 | min_risk = x 112 | 113 | plt.scatter(x=[num_gen] * len(simulation_points_temp), y=simulation_points_temp, color='gray', alpha=0.7, s=100) 114 | 115 | fontsize = 45 116 | 117 | plt.ylim([min_risk - 0.02, max_risk + 0.02]) 118 | plt.xticks(fontsize=fontsize) 119 | plt.yticks(fontsize=fontsize) 120 | 121 | ax = plt.gca() 122 | # ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 123 | ax.yaxis.set_major_formatter('{x:3<0.2f}') 124 | 125 | plt.tight_layout() 126 | print('save figure at {}'.format({f'./figs/{dataset}_{method2name[method]}_multi_gen_conformal_bound_simulation.jpg'})) 127 | plt.savefig(f'./figs/{dataset}_{method2name[method]}_multi_gen_conformal_bound_simulation.jpg', dpi=800) -------------------------------------------------------------------------------- /conformal_risk_computation/conformal_generation_risk_similarity_threshold.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import numpy as np 6 | from matplotlib.ticker import MaxNLocator 7 | from tqdm import tqdm 8 | from scipy.stats import binom 9 | from typing import List, Union, Optional, Tuple, Mapping, Dict 10 | import os 11 | from numpy import linalg 12 | import numpy as np 13 | import argparse 14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func, compute_conformal_risk 15 | 16 | if __name__=='__main__': 17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.') 18 | 19 | parser.add_argument('--llm_retrievalmodel', type=str, choices=['llama-7b_BM25', 'llama-7b_BAAI', 'llama-7b_OpenAI', 'llama-7b_triever-base']) 20 | parser.add_argument('--datasets', nargs='+') 21 | parser.add_argument('--num_gen', type=int, default=20) 22 | parser.add_argument('--n_rag', type=int, default=5) 23 | parser.add_argument('--thresh_list', nargs='+') 24 | 25 | args = parser.parse_args() 26 | 27 | method2name = get_method2name() 28 | method2method = get_method2method() 29 | 30 | datasets = args.datasets 31 | method = args.llm_retrievalmodel 32 | 33 | num_point_dict, sample_size_dict = get_num_points() 34 | 35 | random.seed(1) 36 | 37 | for dataset in datasets: 38 | num_simulation_point = num_point_dict[dataset] 39 | sample_size_per_point = sample_size_dict[dataset] 40 | 41 | thresh_list = args.thresh_list 42 | num_gen = args.num_gen 43 | n_rag = args.n_rag 44 | 45 | alphas = {} 46 | min_risk, max_risk = 10e5, -10e5 47 | rand_list = [] 48 | alpha_list = [] 49 | for idx, thresh in enumerate(thresh_list): 50 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{thresh}_calibration_result.json' 51 | 52 | result = json.load(open(path, 'r', encoding='utf-8')) 53 | result = list(result.values())[0] 54 | result = np.array(result) 55 | result = result.max(axis=0) 56 | results = 1. 
- result 57 | 58 | alpha_list.append(compute_conformal_risk(results)) 59 | if alpha_list[-1] < min_risk: 60 | min_risk = alpha_list[-1] 61 | if alpha_list[-1] > max_risk: 62 | max_risk = alpha_list[-1] 63 | alphas[method] = alpha_list 64 | 65 | data = { 66 | 'thresh': [ int(item) / 100. for item in thresh_list], 67 | method: alphas[method], 68 | } 69 | 70 | df = pd.DataFrame(data) 71 | plt.style.use('seaborn-v0_8-darkgrid') 72 | plt.figure(figsize=(10, 8)) 73 | 74 | color_dict = {method: 'black'} 75 | 76 | for metric, color in color_dict.items(): 77 | plt.plot(df['thresh'], df[metric], marker='^', color=color, label=r'Certified Conformal Risk $\alpha_{\text{rag}}$', markersize=20) 78 | 79 | simulation_points = [] 80 | for idx_,thresh in enumerate(thresh_list): 81 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{thresh}_calibration_result.json' 82 | result = json.load(open(path, 'r', encoding='utf-8')) 83 | result = list(result.values())[0] 84 | result = np.array(result) 85 | result = result.max(axis=0) 86 | results = 1. - result 87 | 88 | simulation_points_temp = [] 89 | for idx_point in range(num_simulation_point): 90 | test_list = random.sample(list(range(len(results))), sample_size_per_point) 91 | res = np.mean(results[test_list]) 92 | simulation_points_temp.append(res) 93 | simulation_points_temp = np.array(simulation_points_temp) 94 | for x in simulation_points_temp: 95 | if x < min_risk: 96 | min_risk = x 97 | plt.scatter(x=[int(thresh) / 100.0] * len(simulation_points_temp), y=simulation_points_temp, color='gray', alpha=0.7, s=100) 98 | 99 | fontsize = 32 100 | plt.ylim([min_risk - 0.02, max_risk + 0.02]) 101 | plt.xticks(fontsize=fontsize) 102 | plt.yticks(fontsize=fontsize) 103 | 104 | ax = plt.gca() 105 | # ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 106 | ax.yaxis.set_major_formatter('{x:3<0.2f}') 107 | 108 | plt.gca().invert_xaxis() 109 | plt.tight_layout() 110 | print('save figure at {}'.format({f'./figs/{dataset}_{method2name[method]}_multi_thresh_conformal_bound_simulation.jpg'})) 111 | plt.savefig(f'./figs/{dataset}_{method2name[method]}_multi_thresh_conformal_bound_simulation.jpg', dpi=800) -------------------------------------------------------------------------------- /conformal_risk_computation/conformal_generation_risk_valid_config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import numpy as np 6 | from matplotlib.ticker import MaxNLocator 7 | from tqdm import tqdm 8 | from scipy.stats import binom 9 | from typing import List, Union, Optional, Tuple, Mapping, Dict 10 | import os 11 | from numpy import linalg 12 | import numpy as np 13 | import argparse 14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func, compute_p_value, FWER 15 | 16 | if __name__=='__main__': 17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.') 18 | 19 | parser.add_argument('--llm_retrievalmodel', type=str, choices=['llama-7b_BM25', 'llama-7b_BAAI', 'llama-7b_OpenAI', 'llama-7b_triever-base']) 20 | parser.add_argument('--datasets', nargs='+') 21 | parser.add_argument('--gen_thresh', type=int, default=0) 22 | parser.add_argument('--n_rag_list', nargs='+') 23 | parser.add_argument('--lambda_g_list', nargs='+') 24 | parser.add_argument('--alpha_desired', type=float, default=0.6) 25 | 26 | 
args = parser.parse_args() 27 | 28 | method2name = get_method2name() 29 | method2method = get_method2method() 30 | 31 | datasets = args.datasets 32 | method = args.llm_retrievalmodel 33 | 34 | n_rag_list = np.array(args.n_rag_list) 35 | lambda_g_list = np.array(args.lambda_g_list) 36 | 37 | alpha_desired = args.alpha_desired 38 | delta = 0.1 39 | 40 | num_point_dict, sample_size_dict = get_num_points() 41 | 42 | num_list = [] 43 | for idx in range(len(lambda_g_list)): 44 | num_list.append(int(lambda_g_list[idx])) 45 | max_num_gen = max(num_list) 46 | 47 | for dataset in datasets: 48 | fig = plt.figure(figsize=(10, 8)) 49 | ax = plt.axes(projection='3d') 50 | 51 | xdata = [] 52 | ydata = [] 53 | zdata = [] 54 | c = [] 55 | random.seed(1) 56 | rand_list = [] 57 | p_values = {} 58 | for n_rag in n_rag_list: 59 | for num_gen in lambda_g_list: 60 | num_gen = int(num_gen) 61 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{max_num_gen}_0_calibration_result.json' 62 | 63 | result = json.load(open(path, 'r', encoding='utf-8')) 64 | result = list(result.values())[0] 65 | result = np.array(result) 66 | rand_indx = random.sample(tuple(list(range(0, max_num_gen))), num_gen) 67 | result = result[rand_indx, :] 68 | rand_list.append(rand_indx) 69 | result = result.max(axis=0) 70 | results = 1. - result 71 | 72 | p_ = compute_p_value(alpha_desired, np.mean(results), len(results)) 73 | p_values[(n_rag, num_gen)] = p_ 74 | 75 | valid_or_not = FWER(p_values, delta) 76 | 77 | for n_rag in n_rag_list: 78 | for num_gen in lambda_g_list: 79 | num_gen = int(num_gen) 80 | if (not valid_or_not[(n_rag, num_gen)]): 81 | continue 82 | 83 | ax.scatter3D([n_rag], [num_gen], 0.01, c='green', alpha=0.9, s=200) 84 | 85 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{max_num_gen}_0_calibration_result.json' 86 | 87 | result = json.load(open(path, 'r', encoding='utf-8')) 88 | result = list(result.values())[0] 89 | result = np.array(result) 90 | rand_indx = random.sample(tuple(list(range(0, max_num_gen))), num_gen) 91 | result = result[rand_indx, :] 92 | rand_list.append(rand_indx) 93 | result = result.max(axis=0) 94 | results = 1. 
- result 95 | 96 | simulation_points_temp = [] 97 | num_simulation_point = num_point_dict[dataset] 98 | sample_size_per_point = sample_size_dict[dataset] 99 | for idx_point in range(num_simulation_point): 100 | test_list = random.sample(list(range(len(results))), sample_size_per_point) 101 | new_res = np.mean(results[test_list]) 102 | simulation_points_temp.append(new_res) 103 | simulation_points_temp = np.array(simulation_points_temp) 104 | ax.scatter3D([n_rag] * len(simulation_points_temp), [num_gen] * len(simulation_points_temp), simulation_points_temp, c='gray', alpha=0.3, s=25) 105 | 106 | 107 | n_rag_list = [int(item) for item in n_rag_list] 108 | lambda_g_list = [int(item) for item in lambda_g_list] 109 | 110 | N_rag_grid, lambda_g_grid = np.meshgrid(n_rag_list, lambda_g_list) 111 | risk_values = np.ones_like(N_rag_grid, dtype=float) * alpha_desired 112 | 113 | 114 | 115 | surf = ax.plot_surface(N_rag_grid, lambda_g_grid, risk_values, color='red', edgecolor='none', alpha=0.2) 116 | 117 | labelsize = 20 118 | fontsize = 30 119 | ax.tick_params(axis='x', labelsize=labelsize,pad=-3) 120 | ax.tick_params(axis='y', labelsize=labelsize,pad=-3) 121 | ax.tick_params(axis='z', labelsize=labelsize,pad=8) 122 | ax.set_xlabel(r'$N_{rag}$', fontsize=fontsize, labelpad=10.0) 123 | ax.set_ylabel(r'$\lambda_g$', fontsize=fontsize, labelpad=10.0) 124 | ax.set_zlabel(r'Risk', fontsize=fontsize, labelpad=25.0, rotation=90) 125 | 126 | plt.tight_layout() 127 | print(f'save fig at ./figs/{dataset}_{method2name[method]}_valid_config.jpg') 128 | plt.savefig(f'./figs/{dataset}_{method2name[method]}_valid_config.jpg',dpi=800) -------------------------------------------------------------------------------- /conformal_risk_computation/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import numpy as np 4 | from matplotlib.ticker import MaxNLocator 5 | from tqdm import tqdm 6 | from scipy.stats import binom 7 | from typing import List, Union, Optional, Tuple, Mapping, Dict 8 | import os 9 | import json 10 | from scipy.optimize import fsolve 11 | 12 | def save_json_to_file(objects: Union[List, dict], path: str, line_by_line: bool = False): 13 | if line_by_line: 14 | assert isinstance(objects, list), 'Only list can be saved in line by line format' 15 | 16 | os.makedirs(os.path.dirname(path), exist_ok=True) 17 | with open(path, 'w', encoding='utf-8') as writer: 18 | if not line_by_line: 19 | json.dump(objects, writer, ensure_ascii=False, indent=4, separators=(',', ':')) 20 | else: 21 | for obj in objects: 22 | writer.write(json.dumps(obj, ensure_ascii=False, separators=(',', ':'))) 23 | writer.write('\n') 24 | 25 | def get_method2name(): 26 | method2name = {} 27 | method2name['llama-7b_BM25'] = 'BM25' 28 | method2name['llama-7b_triever-base'] = 'LLM-R' 29 | method2name['llama-7b_OpenAI'] = 'openai' 30 | method2name['llama-7b_BAAI'] = 'baai' 31 | return method2name 32 | 33 | 34 | def get_method2method(): 35 | method2method = {} 36 | method2method['llama-7b_triever-base'] = 'llm-retriever-base' 37 | method2method['llama-7b_BM25'] = 'bm25' 38 | method2method['llama-7b_OpenAI'] = 'openai' 39 | method2method['llama-7b_BAAI'] = 'baai' 40 | return method2method 41 | 42 | def get_num_points(): 43 | num_point_dict = {'aeslc': 100, 'common_gen': 100, 'dart': 100, 'e2e_nlg': 100} 44 | sample_size_dict = {'aeslc': 400, 'common_gen': 400, 'dart': 400, 'e2e_nlg': 400} 45 | return num_point_dict, sample_size_dict 46 | 47 | 
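# --- Illustrative example (added for exposition; not part of the original utilities) ---
# A minimal sketch of how the two bound helpers defined below relate to each other:
# with a Hellinger distance of 0, cal_dist_shift_bound reduces to the i.i.d. bound of
# empirical_risk2alpha_func applied to the empirical risk plus finite-sample correction
# terms, so its output should be at least as large. The numbers here are made up.
def _example_bound_sanity_check(r_hat=0.35, n_cal=500, var=0.05, delta=0.1):
    alpha_iid = empirical_risk2alpha_func(r_hat, N_cal=n_cal, delta=delta)
    alpha_shift0 = cal_dist_shift_bound(r_hat, dist=0.0, N=n_cal, var=var, delta=delta)
    return alpha_iid, alpha_shift0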
def get_color_dict(retrieval_methods): 48 | colors = {'llama-7b_BM25': 'royalblue', 'llama-7b_triever-base': 'lightcoral', 'llama-7b_OpenAI': 'darkviolet', 'llama-7b_BAAI': 'olivedrab'} 49 | color_dict = {} 50 | for method in retrieval_methods: 51 | color_dict[method] = colors[method] 52 | return color_dict 53 | 54 | def get_max_x_all(): 55 | max_x_all = {'aeslc': 11, 'common_gen': 21, 'dart': 21, 'e2e_nlg': 21} 56 | return max_x_all 57 | 58 | def get_num_points_distribution_shift(): 59 | num_point_dict = {'aeslc': 100, 'common_gen': 100, 'dart': 100, 'e2e_nlg': 100} 60 | sample_size_dict = {'aeslc': 30, 'common_gen': 30, 'dart': 30, 'e2e_nlg': 30} 61 | simulate_num = 30000 62 | return num_point_dict, sample_size_dict, simulate_num 63 | 64 | def h1(a,b): 65 | return a * np.log(a/b) + (1-a) * np.log((1-a)/(1-b)) 66 | 67 | def solve_inverse_1_func(r_hat, N_cal, delta): 68 | def func(a): 69 | ret = np.exp(-N_cal * h1(min(a,r_hat), a)) - delta 70 | return ret 71 | root = fsolve(func, [r_hat+0.02])[0] 72 | return root 73 | 74 | def solve_inverse_2_func(r_hat, N_cal, delta): 75 | def func(a): 76 | ret = binom.cdf(np.ceil(N_cal * r_hat), N_cal, a) - delta / np.exp(1) 77 | return ret 78 | root = fsolve(func, [r_hat+0.02])[0] 79 | return root 80 | 81 | def empirical_risk2alpha_func(r_hat, N_cal, delta): 82 | delta = delta 83 | alpha_term_1 = solve_inverse_1_func(r_hat, N_cal, delta) 84 | alpha_term_2 = solve_inverse_2_func(r_hat, N_cal, delta) 85 | alpha = min(alpha_term_1, alpha_term_2) 86 | if alpha>1.0: 87 | alpha=1.0 88 | return alpha 89 | 90 | def cal_dist_shift_bound(r_hat, dist, N, var, delta): 91 | r_hat_overline = r_hat + dist**2 * (2-dist**2) * (1-r_hat) + 2 * dist * (1-dist**2) * np.sqrt(2-dist**2) * np.sqrt(var) 92 | # print(f'r_hat_overline 1: {r_hat_overline}') 93 | 94 | # finite-sample error 95 | r_hat_overline += (1-dist**2) * ((1-dist**2) / np.sqrt(2*N) + 2 * dist * np.sqrt(2 * (2 - dist**2)) / np.sqrt(N-1)) * np.sqrt(np.log(4/delta)) + np.sqrt(np.log(8/delta) / 2 / N) 96 | 97 | # print(f'r_hat_overline 2: {r_hat_overline}') 98 | # risk_shift = empirical_risk2alpha_newton(r_hat_overline, N, delta) 99 | risk_shift = empirical_risk2alpha_func(r_hat_overline, N, delta) 100 | return risk_shift 101 | 102 | def compute_hellinger_distance(vec1, vec2): 103 | dist = 1. 
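    # Hellinger distance between two discrete distributions:
    # H(P, Q) = sqrt(1 - sum_i sqrt(p_i * q_i)); the loop below subtracts the
    # Bhattacharyya coefficient from 1 before taking the square root.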
104 | for idx in range(len(vec1)): 105 | dist -= np.sqrt(vec1[idx]) * np.sqrt(vec2[idx]) 106 | dist = np.sqrt(dist) 107 | return dist 108 | 109 | def compute_p_value(alpha_desired, r_hat, N_cal): 110 | term1 = np.exp(-N_cal * h1(min(alpha_desired,r_hat), alpha_desired)) 111 | term2 = binom.cdf(np.ceil(N_cal * r_hat), N_cal, alpha_desired) * np.exp(1) 112 | p = min(term1, term2) 113 | if r_hat > alpha_desired: 114 | return 1.0 115 | return p 116 | 117 | def FWER(p_values, delta): 118 | keys = p_values.keys() 119 | n_tol = len(keys) 120 | valid_or_not = {} 121 | # print(p_values) 122 | for key in p_values.keys(): 123 | if p_values[key] < delta / n_tol: 124 | valid_or_not[key] = True 125 | else: 126 | valid_or_not[key] = False 127 | return valid_or_not 128 | 129 | def compute_conformal_risk(results): 130 | return empirical_risk2alpha_func(np.mean(results), len(results), delta=0.1) -------------------------------------------------------------------------------- /ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 10, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "optimizer": { 12 | "type": "AdamW", 13 | "params": { 14 | "lr": "auto", 15 | "betas": "auto", 16 | "eps": "auto", 17 | "weight_decay": "auto" 18 | } 19 | }, 20 | 21 | "scheduler": { 22 | "type": "WarmupDecayLR", 23 | "params": { 24 | "warmup_min_lr": "auto", 25 | "warmup_max_lr": "auto", 26 | "warmup_num_steps": "auto", 27 | "total_num_steps": "auto" 28 | } 29 | }, 30 | 31 | "zero_optimization": { 32 | "stage": 2, 33 | "allgather_partitions": true, 34 | "allgather_bucket_size": 2e8, 35 | "overlap_comm": true, 36 | "reduce_scatter": true, 37 | "reduce_bucket_size": 2e8, 38 | "contiguous_gradients": true 39 | }, 40 | 41 | "gradient_accumulation_steps": "auto", 42 | "gradient_clipping": "auto", 43 | "steps_per_print": 2000, 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "wall_clock_breakdown": false 47 | } 48 | -------------------------------------------------------------------------------- /figs/README.md: -------------------------------------------------------------------------------- 1 | The generated images will be stored in this directory. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.28 2 | datasets==2.3.0 3 | deepspeed==0.8.3 4 | tqdm 5 | rouge 6 | faiss-cpu 7 | matplotlib 8 | pandas 9 | scipy 10 | gdown 11 | rank_bm25 12 | tiktoken 13 | openai 14 | scikit-learn 15 | absl-py -------------------------------------------------------------------------------- /run/conformal_parallel.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | set -e 5 | 6 | DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )" 7 | echo "working directory: ${DIR}" 8 | 9 | MODEL_NAME_OR_PATH="random" 10 | if [[ $# -ge 1 && ! "$1" == "--"* ]]; then 11 | MODEL_NAME_OR_PATH=$1 12 | shift 13 | fi 14 | 15 | LLM_MODEL_NAME_OR_PATH="huggyllama/llama-7b" 16 | if [[ $# -ge 1 && ! 
"$1" == "--"* ]]; then 17 | LLM_MODEL_NAME_OR_PATH=$1 18 | shift 19 | fi 20 | 21 | if [ -z "$OUTPUT_DIR" ]; then 22 | OUTPUT_DIR="${MODEL_NAME_OR_PATH}" 23 | fi 24 | if [ -z "$DATA_DIR" ]; then 25 | DATA_DIR="${DIR}/data/tasks/" 26 | fi 27 | 28 | 29 | EVAL_TASKS=('aeslc' 'common_gen' 'dart' 'e2e_nlg') 30 | 31 | # generate data: input prompt (retrieved examples) and query for the specified task 32 | PYTHONPATH=src/ python -u src/inference/generate_few_shot_prompt.py \ 33 | --model_name_or_path "${MODEL_NAME_OR_PATH}" \ 34 | --seed 1234 \ 35 | --fp16 \ 36 | --llm_eval_tasks "${EVAL_TASKS[@]}" \ 37 | --llm_eval_split test \ 38 | --llm_k_shot "${N_rag}" \ 39 | --output_dir "${OUTPUT_DIR}" \ 40 | --data_dir "${DATA_DIR}" 41 | 42 | 43 | nproc=$(echo -n "$GPUs" | wc -m) 44 | nproc=`expr $nproc + 1` 45 | nproc=`expr $nproc / 2` 46 | 47 | # evaluate the empirical risks 48 | CUDA_VISIBLE_DEVICES="${GPUs}" torchrun --nproc_per_node "${nproc}" --master_port=38765 src/conformal_calibration_empirical_risk.py \ 49 | --llm_batch_size_per_device 4 \ 50 | --llm_k_shot "${N_rag}" \ 51 | --llm_num_gen "${llm_num_gen}" \ 52 | --llm_gen_threshold "${llm_gen_threshold}" \ 53 | --cali_ratio 0.5 \ 54 | --model_name_or_path "${MODEL_NAME_OR_PATH}" \ 55 | --seed 1234 \ 56 | --fp16 \ 57 | --do_llm_eval \ 58 | --llm_model_name_or_path "${LLM_MODEL_NAME_OR_PATH}" \ 59 | --llm_eval_tasks "${EVAL_TASKS[@]}" \ 60 | --llm_eval_split test \ 61 | --llm_max_input_length 1024 \ 62 | --llm_max_decode_length 64 \ 63 | --output_dir "${OUTPUT_DIR}" \ 64 | --data_dir "${DATA_DIR}" \ 65 | --overwrite_output_dir \ 66 | --disable_tqdm True \ 67 | --report_to none "$@" -------------------------------------------------------------------------------- /scripts_conformal_generation_risk/run_conformal_distribution_shift_generation_risk.sh: -------------------------------------------------------------------------------- 1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base 2 | do 3 | python conformal_risk_computation/conformal_distribution_shift_generation_risk.py --llm_retrievalmodel $method --datasets aeslc common_gen dart e2e_nlg --n_rag 15 --num_gen 1 --gen_thresh 0 4 | done -------------------------------------------------------------------------------- /scripts_conformal_generation_risk/run_conformal_distribution_shift_generation_risk_comparisons.sh: -------------------------------------------------------------------------------- 1 | python conformal_risk_computation/conformal_distribution_shift_generation_risk_comparison.py \ 2 | --llm_retrieval_methods llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base \ 3 | --n_rag 15 --num_gen 1 --gen_thresh 0 --datasets aeslc common_gen dart e2e_nlg -------------------------------------------------------------------------------- /scripts_conformal_generation_risk/run_conformal_generation_risk.sh: -------------------------------------------------------------------------------- 1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base 2 | do 3 | python conformal_risk_computation/conformal_generation_risk.py --llm_retrievalmodel $method \ 4 | --datasets aeslc common_gen dart e2e_nlg \ 5 | --num_gen 1 --gen_thresh 0 --n_rag_list 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 6 | done -------------------------------------------------------------------------------- /scripts_conformal_generation_risk/run_conformal_generation_risk_comparisons.sh: -------------------------------------------------------------------------------- 1 | 
python conformal_risk_computation/conformal_generation_risk_comparison.py \ 2 | --llm_retrieval_methods llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base --datasets aeslc common_gen dart e2e_nlg \ 3 | --num_gen 1 --gen_thresh 0 --n_rag_list 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 -------------------------------------------------------------------------------- /scripts_conformal_generation_risk/run_conformal_generation_risk_multi_dim_config.sh: -------------------------------------------------------------------------------- 1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base 2 | do 3 | python conformal_risk_computation/conformal_generation_risk_multi_dim_config.py \ 4 | --llm_retrievalmodel $method --datasets aeslc common_gen dart e2e_nlg --n_rag_list 0 1 2 3 4 5 6 7 8 9 \ 5 | --lambda_g_list 1 2 3 4 5 6 7 8 9 10 --gen_thresh 0 6 | done -------------------------------------------------------------------------------- /scripts_conformal_generation_risk/run_conformal_generation_risk_num_gen.sh: -------------------------------------------------------------------------------- 1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base 2 | do 3 | python conformal_risk_computation/conformal_generation_risk_num_gen.py \ 4 | --llm_retrievalmodel $method --datasets aeslc common_gen dart e2e_nlg \ 5 | --n_rag 5 --gen_thresh 0 --num_gen_list 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 6 | done -------------------------------------------------------------------------------- /scripts_conformal_generation_risk/run_conformal_generation_risk_similarity_threshold.sh: -------------------------------------------------------------------------------- 1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base 2 | do 3 | python conformal_risk_computation/conformal_generation_risk_similarity_threshold.py \ 4 | --llm_retrievalmodel $method --datasets aeslc common_gen dart e2e_nlg --n_rag 5 --num_gen 20 --thresh_list 10 9 8 7 6 5 | done -------------------------------------------------------------------------------- /scripts_conformal_generation_risk/run_conformal_generation_risk_valid_config.sh: -------------------------------------------------------------------------------- 1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base 2 | do 3 | python conformal_risk_computation/conformal_generation_risk_valid_config.py \ 4 | --llm_retrievalmodel $method --datasets aeslc common_gen dart e2e_nlg --n_rag_list 0 1 2 3 4 5 6 7 8 9 --lambda_g_list 1 2 3 4 5 6 7 8 9 10 5 | done -------------------------------------------------------------------------------- /scripts_raw_risk_scores/baai.sh: -------------------------------------------------------------------------------- 1 | # OUTPUT_DIR: results stored at the dir 2 | # N_rag: number of retrieved examples 3 | # llm_num_gen: number of generations 4 | # llm_gen_threshold: accept similarity threshold during generations (Note: it is scaled by 100 and represented by an integer) 5 | 6 | for n in 0 1 2 3 4 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 7 | do 8 | for ngen in 1 9 | do 10 | GPUs=0,1,2,3 OUTPUT_DIR=outputs/baai N_rag=$n llm_num_gen=$ngen llm_gen_threshold=0 bash run/conformal_parallel.sh BAAI 11 | done 12 | done -------------------------------------------------------------------------------- /scripts_raw_risk_scores/bm25.sh: -------------------------------------------------------------------------------- 1 | # OUTPUT_DIR: results 
stored at the dir 2 | # N_rag: number of retrieved examples 3 | # llm_num_gen: number of generations 4 | # llm_gen_threshold: accept similarity threshold during generations (Note: it is scaled by 100 and represented by an integer) 5 | 6 | for n in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 7 | do 8 | for ngen in 1 9 | do 10 | GPUs=0,1,2,3 OUTPUT_DIR=outputs/bm25 N_rag=$n llm_num_gen=$ngen llm_gen_threshold=0 bash run/conformal_parallel.sh BM25 11 | done 12 | done -------------------------------------------------------------------------------- /scripts_raw_risk_scores/llm-r.sh: -------------------------------------------------------------------------------- 1 | # OUTPUT_DIR: results stored at the dir 2 | # N_rag: number of retrieved examples 3 | # llm_num_gen: number of generations 4 | # llm_gen_threshold: accept similarity threshold during generations (Note: it is scaled by 100 and represented by an integer) 5 | 6 | # path_retrieval_model specifies path of locally trained retrieval model 7 | # Huggingface transformer downloads or loads from cache from ~/.cache/huggingface/hub/models--llm_dir--llm_name (including blobs,refs,snapshots) 8 | 9 | for n in 0 1 2 3 # 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 10 | do 11 | for ngen in 1 12 | do 13 | GPUs=0,1,2,3 OUTPUT_DIR=outputs/llm-retriever-base N_rag=$n llm_num_gen=$ngen llm_gen_threshold=0 bash run/conformal_parallel.sh trained_retrieval_models/LLM-R/llm-retriever-base 14 | done 15 | done -------------------------------------------------------------------------------- /scripts_raw_risk_scores/openai.sh: -------------------------------------------------------------------------------- 1 | # OUTPUT_DIR: results stored at the dir 2 | # N_rag: number of retrieved examples 3 | # llm_num_gen: number of generations 4 | # llm_gen_threshold: accept similarity threshold during generations (Note: it is scaled by 100 and represented by an integer) 5 | 6 | for n in 0 1 2 3 4 6 7 8 9 # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 7 | do 8 | for ngen in 1 9 | do 10 | GPUs=0,1,2,3 OUTPUT_DIR=outputs/openai N_rag=$n llm_num_gen=$ngen llm_gen_threshold=0 bash run/conformal_parallel.sh OpenAI 11 | done 12 | done -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kangmintong/C-RAG/5bda4382bebf6453b3b8dccb8a52f4d015e95fee/src/.DS_Store -------------------------------------------------------------------------------- /src/collators/__init__.py: -------------------------------------------------------------------------------- 1 | from .biencoder_collator import BiencoderCollator 2 | from .cross_encoder_collator import CrossEncoderCollator 3 | -------------------------------------------------------------------------------- /src/collators/biencoder_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Dict, Any, Union, Optional 5 | from transformers import BatchEncoding, PreTrainedTokenizerBase 6 | from transformers.file_utils import PaddingStrategy 7 | 8 | from config import Arguments 9 | 10 | 11 | @dataclass 12 | class BiencoderCollator: 13 | 14 | args: Arguments 15 | tokenizer: PreTrainedTokenizerBase 16 | padding: Union[bool, str, PaddingStrategy] = True 17 | max_length: Optional[int] = None 18 | pad_to_multiple_of: Optional[int] = None 19 | return_tensors: str 
= "pt" 20 | 21 | def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding: 22 | queries: List[str] = [f['query'] for f in features] 23 | passages: List[str] = sum([f['passages'] for f in features], []) 24 | 25 | input_texts = queries + passages 26 | 27 | merged_batch_dict = self.tokenizer( 28 | input_texts, 29 | max_length=self.args.max_len, 30 | truncation=True, 31 | padding=self.padding, 32 | return_token_type_ids=False, 33 | pad_to_multiple_of=self.pad_to_multiple_of, 34 | return_tensors=self.return_tensors) 35 | 36 | # dummy placeholder for field "labels", won't use it to compute loss 37 | labels = torch.zeros(len(queries), dtype=torch.long) 38 | merged_batch_dict['labels'] = labels 39 | 40 | if 'kd_labels' in features[0]: 41 | kd_labels = torch.stack([torch.tensor(f['kd_labels']) for f in features], dim=0).float() 42 | merged_batch_dict['kd_labels'] = kd_labels 43 | return merged_batch_dict 44 | -------------------------------------------------------------------------------- /src/collators/cross_encoder_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Dict, Any 5 | from transformers import BatchEncoding, DataCollatorWithPadding 6 | 7 | 8 | @dataclass 9 | class CrossEncoderCollator(DataCollatorWithPadding): 10 | 11 | def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding: 12 | unpack_features = [] 13 | for ex in features: 14 | keys = list(ex.keys()) 15 | for idx in range(len(ex[keys[0]])): 16 | unpack_features.append({k: ex[k][idx] for k in keys}) 17 | 18 | collated_batch_dict = self.tokenizer.pad( 19 | unpack_features, 20 | padding=self.padding, 21 | pad_to_multiple_of=self.pad_to_multiple_of, 22 | return_tensors=self.return_tensors) 23 | 24 | collated_batch_dict['labels'] = torch.zeros(len(features), dtype=torch.long) 25 | return collated_batch_dict 26 | -------------------------------------------------------------------------------- /src/collators/gpt2_collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Dict, Any, Union, Optional 5 | from transformers import BatchEncoding, PreTrainedTokenizerBase 6 | from transformers.file_utils import PaddingStrategy 7 | 8 | from logger_config import logger 9 | 10 | 11 | @dataclass 12 | class ScoreCollator: 13 | tokenizer: PreTrainedTokenizerBase 14 | padding: Union[bool, str, PaddingStrategy] = True 15 | max_length: Optional[int] = None 16 | pad_to_multiple_of: Optional[int] = None 17 | return_tensors: str = "pt" 18 | delimiter: str = '\n' 19 | 20 | def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding: 21 | self.tokenizer.padding_side = 'right' 22 | input_texts = [f['input_texts'] for f in features] 23 | output_texts = [f['output_texts'] for f in features] 24 | assert all(not text.endswith(self.delimiter) for text in input_texts) 25 | assert all(not text.startswith(self.delimiter) for text in output_texts) 26 | concat_texts: List[str] = [self.delimiter.join([inp, out]) for inp, out in zip(input_texts, output_texts)] 27 | 28 | batch_dict = self.tokenizer( 29 | concat_texts, 30 | max_length=self.max_length, 31 | truncation=True, 32 | padding=self.padding, 33 | pad_to_multiple_of=self.pad_to_multiple_of, 34 | return_tensors=self.return_tensors) 35 | 36 | labels = batch_dict['input_ids'].clone() 37 | if self.tokenizer.pad_token_id is not None: 38 | 
labels[labels == self.tokenizer.pad_token_id] = -100 39 | num_valid_tokens = torch.cumsum(batch_dict['attention_mask'], dim=1) 40 | output_lengths: torch.LongTensor = torch.LongTensor(self._get_output_lengths(output_texts)) 41 | logger.debug('output lengths: {}'.format(output_lengths)) 42 | input_lengths: torch.LongTensor = torch.sum(batch_dict['attention_mask'], dim=1) - output_lengths 43 | labels[num_valid_tokens <= input_lengths[:, None]] = -100 44 | batch_dict['labels'] = labels 45 | 46 | return batch_dict 47 | 48 | def _get_output_lengths(self, output_texts: List[str]) -> List[int]: 49 | output_ids: List[List[int]] = self.tokenizer( 50 | output_texts, max_length=self.max_length, truncation=True, padding=False 51 | )['input_ids'] 52 | 53 | for idx in range(len(output_ids)): 54 | # llama tokenizer prepend a bos token 55 | if output_ids[idx][0] == self.tokenizer.bos_token_id: 56 | output_ids[idx] = output_ids[idx][1:] 57 | 58 | lengths: List[int] = [len(output_id) for output_id in output_ids] 59 | assert all(length > 0 for length in lengths), lengths 60 | 61 | return lengths 62 | 63 | 64 | @dataclass 65 | class DecodeCollator: 66 | tokenizer: PreTrainedTokenizerBase 67 | padding: Union[bool, str, PaddingStrategy] = True 68 | max_length: Optional[int] = None 69 | pad_to_multiple_of: Optional[int] = None 70 | return_tensors: str = "pt" 71 | 72 | def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding: 73 | # batch_score requires right padding, but generate requires left padding 74 | self.tokenizer.padding_side = 'left' 75 | input_texts = [f['input_texts'] for f in features] 76 | 77 | batch_dict = self.tokenizer( 78 | input_texts, 79 | max_length=self.max_length, 80 | truncation=True, 81 | padding=self.padding, 82 | pad_to_multiple_of=self.pad_to_multiple_of, 83 | return_tensors=self.return_tensors) 84 | return batch_dict 85 | -------------------------------------------------------------------------------- /src/compute_controlled_risk.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import numpy as np 3 | 4 | def binomcdf(p,n): 5 | p = 0.5 6 | n = 100 7 | x = 0 8 | for a in range(10): 9 | print(scipy.stats.binom.cdf(x, n, p)) 10 | x += 10 11 | 12 | delta = 0.1 13 | e = 2.718 14 | 15 | target_prob = delta / e 16 | 17 | N= 1750 18 | Score_list = [0.0576 ,0.1236, 0.2412, 0.254, 0.2643, 0.2679, 0.2596, 0.272, 0.2721, 0.2729, 0.2702, 0.2732, 0.272] 19 | R_hat_list = np.array([1.]*len(Score_list)) - np.array(Score_list) 20 | alpha_list = [] 21 | for R_hat in R_hat_list: 22 | t = np.ceil(N*R_hat) 23 | res = 0. 
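    # The loop below grid-searches alpha over [0, 1) in steps of 0.001 and keeps the value whose
    # binomial CDF at t = ceil(N * R_hat) is closest to target_prob = delta / e, i.e. it numerically
    # inverts P(Binom(N, alpha) <= t) ~= delta / e to map the empirical risk R_hat measured on the
    # N calibration samples to a corresponding risk level alpha.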
24 | res_gap = 1e10 25 | for ind in range(1000): 26 | alpha = ind * 0.001 27 | cdf = scipy.stats.binom.cdf(t, N, alpha) 28 | if abs(cdf-target_prob) Dataset: 14 | assert path.endswith('.jsonl') or path.endswith('.jsonl.gz') 15 | 16 | # two fields: id, contents 17 | corpus = load_dataset('json', data_files=path)['train'] 18 | logger.info('Load {} documents from {} with columns {}'.format(len(corpus), path, corpus.column_names)) 19 | logger.info('A random document: {}'.format(random.choice(corpus))) 20 | return corpus 21 | 22 | 23 | def to_positive_negative_format(example: Dict, topk_as_positive: int = 1, bottomk_as_negative: int = -1) -> Dict: 24 | # query_id / query / answers / doc_ids / doc_scores 25 | assert len(example['doc_ids']) == len(example['doc_scores']) 26 | sorted_indices: List[int] = np.argsort(example['doc_scores'])[::-1] 27 | positive_indices: List[int] = sorted_indices[:topk_as_positive] 28 | negative_indices: List[int] = sorted_indices[topk_as_positive:] if bottomk_as_negative <= 0 else sorted_indices[-bottomk_as_negative:] 29 | negative_indices = [idx for idx in negative_indices if idx not in positive_indices] 30 | np.random.shuffle(positive_indices) 31 | np.random.shuffle(negative_indices) 32 | return { 33 | 'positives': { 34 | 'doc_id': [example['doc_ids'][idx] for idx in positive_indices], 35 | 'score': [example['doc_scores'][idx] for idx in positive_indices], 36 | }, 37 | 'negatives': { 38 | 'doc_id': [example['doc_ids'][idx] for idx in negative_indices], 39 | 'score': [example['doc_scores'][idx] for idx in negative_indices], 40 | }, 41 | } 42 | 43 | 44 | def save_to_readable_format(in_path: str, corpus: Dataset, shuffle: bool = False, max_num_samples: int = 10000): 45 | out_path = '{}/readable_{}'.format(os.path.dirname(in_path), os.path.basename(in_path)) 46 | out_path = out_path.replace('.jsonl.gz', '.json') 47 | dataset: Dataset = load_dataset('json', data_files=in_path, split='train', download_mode=DownloadMode.FORCE_REDOWNLOAD) 48 | if shuffle: 49 | dataset = dataset.shuffle() 50 | if len(dataset) > max_num_samples: 51 | dataset = dataset.select(range(max_num_samples)) 52 | dataset = dataset.map( 53 | to_positive_negative_format, 54 | remove_columns=['doc_ids', 'doc_scores'], 55 | desc='to positive negative format' 56 | ) 57 | 58 | max_to_keep = 5 59 | 60 | def _create_readable_field(samples: Dict[str, List]) -> List: 61 | readable_ex = [] 62 | for idx in range(min(len(samples['doc_id']), max_to_keep)): 63 | doc_id = samples['doc_id'][idx] 64 | readable_ex.append({'doc_id': doc_id, 65 | 'contents': corpus[int(doc_id)]['contents'], 66 | 'score': samples['score'][idx], 67 | 'task_name': corpus[int(doc_id)]['task_name'], 68 | }) 69 | return readable_ex 70 | 71 | def _mp_func(ex: Dict) -> Dict: 72 | ex['positives'] = _create_readable_field(ex['positives']) 73 | ex['negatives'] = _create_readable_field(ex['negatives']) 74 | return ex 75 | dataset = dataset.map(_mp_func, desc='to readable format') 76 | 77 | dataset.to_json(out_path, force_ascii=False, lines=False, indent=4) 78 | logger.info('Done convert {} to readable format in {}'.format(in_path, out_path)) 79 | 80 | 81 | def save_llm_decoding_results( 82 | out_path: str, 83 | input_texts: List[str], 84 | decoded_texts: List[str], 85 | parsed_decoded_texts: List[str], 86 | options_list: List[List[str]], 87 | answer_texts: List[str]): 88 | assert len(input_texts) == len(decoded_texts) 89 | dataset = Dataset.from_dict({ 90 | 'input_text': input_texts, 91 | 'decoded_text': decoded_texts, 92 | 'parsed_decoded_text': 
parsed_decoded_texts, 93 | 'options': options_list, 94 | 'answer_text': answer_texts 95 | }) 96 | save_dataset(dataset, out_path, shuffle=True) 97 | logger.info('Successfully save decoding results to {}'.format(out_path)) 98 | 99 | 100 | def log_task_statistics(ds: Dataset, split: str = 'train'): 101 | task_name_counter = Counter() 102 | for task_name in ds['task_name']: 103 | task_name_counter[task_name] += 1 104 | # log the number of examples per task 105 | for task_name, count in task_name_counter.most_common(): 106 | logger.info('{} ({}): {}'.format(task_name, split, count)) 107 | logger.info('{}: {} tasks, {} examples in total'.format(split, len(task_name_counter), len(ds))) 108 | -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_eval import BaseEval 2 | from .random_eval import RandomEval 3 | from .dense_eval import DenseEval 4 | from .bm25_eval import BM25Eval 5 | from .openai_eval import OpenaiEval 6 | from .baai_eval import BAAIEval 7 | -------------------------------------------------------------------------------- /src/evaluation/baai_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Dict, Tuple 3 | from datasets import Dataset 4 | from evaluation.base_eval import BaseEval 5 | from config import Arguments 6 | from logger_config import logger 7 | import tiktoken 8 | from tqdm import tqdm 9 | import numpy as np 10 | import functools 11 | import signal 12 | import time 13 | import numpy as np 14 | from tqdm import tqdm 15 | import os 16 | import logging 17 | from numpy.linalg import norm 18 | from transformers import AutoTokenizer, AutoModel 19 | import torch 20 | 21 | class BAAIEval(BaseEval): 22 | 23 | def __init__(self, args: Arguments, corpus: Dataset, **kwargs): 24 | super().__init__(args, corpus, **kwargs) 25 | 26 | self.corpus = corpus 27 | self.cache_file_dir = 'outputs/baai/' 28 | 29 | def get_baai_embeddings(self, cur_corpus, task_name): 30 | out_path = os.path.join(self.cache_file_dir,f'{task_name}_baai_embedding.npy') 31 | if os.path.exists(out_path): 32 | logger.info(f'{out_path} exists, directly loading it') 33 | embeddings = np.load(out_path) 34 | else: 35 | tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-en-v1.5') 36 | model = AutoModel.from_pretrained('BAAI/bge-large-en-v1.5') 37 | model.eval() 38 | 39 | embeddings = [] 40 | with torch.no_grad(): 41 | for entry in tqdm(cur_corpus): 42 | doc = [entry['contents']] 43 | encoded_input = tokenizer(doc, padding=True, truncation=True, return_tensors='pt') 44 | model_output = model(**encoded_input) 45 | embedding = model_output[0][:, 0] 46 | embedding = torch.nn.functional.normalize(embedding, p=2, dim=1) 47 | embeddings.append(embedding[0]) 48 | embeddings = np.array(embeddings) 49 | np.save(out_path,embeddings) 50 | logger.info(f'save embedding to {out_path}') 51 | logger.info(f'embedding {out_path} dimension: {embeddings.shape}') 52 | return embeddings 53 | 54 | def compute_similarity(self, query_embedding, embeddings): 55 | similarity_scores = np.matmul(embeddings, np.transpose(query_embedding)) 56 | return similarity_scores 57 | 58 | 59 | def get_topk_score_doc_ids(self, queries: List[str], k: int, task_names: List[str]) -> List[List[Tuple[float, str]]]: 60 | 61 | cur_corpus = self.corpus.filter(lambda x: x['task_name'] == task_names[0]) 62 | embeddings = 
self.get_baai_embeddings(cur_corpus, task_names[0]) 63 | 64 | tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-en-v1.5') 65 | model = AutoModel.from_pretrained('BAAI/bge-large-en-v1.5') 66 | model.eval() 67 | 68 | scores_all = [] 69 | 70 | logger.info(f'len(queries): {len(queries)}') 71 | # logger.info(f'query_embedding.shape: {query_embeddings.shape}') 72 | logger.info(f'embeddings.shape: {embeddings.shape}') 73 | 74 | start_time = time.time() 75 | query_embeddings = [] 76 | with torch.no_grad(): 77 | for query in tqdm(queries): 78 | encoded_input = tokenizer([query], padding=True, truncation=True, return_tensors='pt') 79 | model_output = model(**encoded_input) 80 | sentence_embeddings = model_output[0][:, 0][0] 81 | query_embeddings.append(sentence_embeddings) 82 | end_time = time.time() 83 | 84 | logger.info(f'Wall clock time of query embeddings computation: {end_time-start_time}') 85 | 86 | scores_all = np.matmul(query_embeddings, embeddings.transpose()) 87 | 88 | # for idx, query in tqdm(enumerate(queries)): 89 | # query_embedding = query_embeddings[idx] 90 | # scores = self.compute_similarity(query_embedding, embeddings) 91 | # scores_all.append(scores) 92 | # scores_all = np.array(scores_all) 93 | 94 | doc_ids = [] 95 | for entry in cur_corpus: 96 | doc_ids.append(entry['id']) 97 | 98 | topk_index = scores_all.argsort(axis=1)[:, -k:] 99 | result = [] 100 | for idx in range(len(topk_index)): 101 | cur_list = [] 102 | for j in range(len(topk_index[idx])): 103 | doc_id_ = doc_ids[topk_index[idx][j]] 104 | doc_id_score = scores_all[idx][topk_index[idx][j]] 105 | cur_list.append((doc_id_score, doc_id_)) 106 | result.append(cur_list) 107 | return result -------------------------------------------------------------------------------- /src/evaluation/base_eval.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import List, Dict, Set, Tuple 3 | from collections import defaultdict 4 | from datasets import Dataset 5 | 6 | from config import Arguments 7 | 8 | 9 | class BaseEval: 10 | 11 | def __init__(self, args: Arguments, corpus: Dataset, **kwargs): 12 | self.args: Arguments = args 13 | # id / contents / task_name 14 | self.corpus: Dataset = corpus 15 | 16 | self.task_name_to_doc_ids: Dict[str, Set[str]] = defaultdict(set) 17 | for doc_id, task_name in zip(self.corpus['id'], self.corpus['task_name']): 18 | self.task_name_to_doc_ids[task_name].add(doc_id) 19 | 20 | @abstractmethod 21 | def get_topk_score_doc_ids( 22 | self, queries: List[str], k: int, task_names: List[str] 23 | ) -> List[List[Tuple[float, str]]]: 24 | raise NotImplementedError 25 | 26 | def get_doc_ids_by_task_name(self, task_name: str) -> List[str]: 27 | return list(self.task_name_to_doc_ids[task_name]) 28 | 29 | def get_prompt_by_doc_ids(self, doc_ids: List[str]) -> str: 30 | return '\n\n'.join([self.corpus[int(doc_id)]['contents'] for doc_id in doc_ids]) 31 | -------------------------------------------------------------------------------- /src/evaluation/bm25_eval.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple 2 | from datasets import Dataset 3 | 4 | from evaluation.base_eval import BaseEval 5 | from config import Arguments 6 | from logger_config import logger 7 | from rank_bm25 import BM25Okapi 8 | from tqdm import tqdm 9 | import os 10 | import numpy as np 11 | 12 | class BM25Eval(BaseEval): 13 | 14 | def __init__(self, args: Arguments, corpus: 
Dataset, **kwargs): 15 | super().__init__(args, corpus, **kwargs) 16 | 17 | self.corpus = corpus 18 | # initializing punctuations string 19 | self.punc = '''!()-[]{};:'"\,<>./?@#$%^&*_~''' 20 | 21 | self.cache_file_dir = 'outputs/bm25/' 22 | 23 | def tokenize(self, str): 24 | str = str.lower() 25 | for ele in str: 26 | if ele in self.punc: 27 | str = str.replace(ele, "") 28 | doc_tokens = str.split(" ") 29 | return doc_tokens 30 | 31 | 32 | 33 | def get_topk_score_doc_ids(self, queries: List[str], k: int, task_names: List[str]) -> List[List[Tuple[float, str]]]: 34 | tokenized_corpus = [] 35 | doc_ids = [] 36 | logger.info('Start tokenizing corpus for BM25') 37 | cur_corpus = self.corpus.filter(lambda x: x['task_name'] == task_names[0]) 38 | 39 | outpath_scores = os.path.join(self.cache_file_dir,f'{task_names[0]}_bm25_scores.npy') 40 | outpath_ids = os.path.join(self.cache_file_dir, f'{task_names[0]}_bm25_docids.npy') 41 | 42 | if os.path.exists(outpath_scores): 43 | logger.info(f'score file {outpath_scores} already exists, skip calculating') 44 | doc_scores_all = np.load(outpath_scores) 45 | doc_ids = np.load(outpath_ids) 46 | else: 47 | for entry in tqdm(cur_corpus): 48 | doc = entry['contents'] 49 | doc_tokens = self.tokenize(doc) 50 | tokenized_corpus.append(doc_tokens) 51 | doc_ids.append(entry['id']) 52 | 53 | bm25 = BM25Okapi(tokenized_corpus) 54 | doc_scores_all = [] 55 | for query in tqdm(queries): 56 | tokenized_query = self.tokenize(query) 57 | doc_scores = bm25.get_scores(tokenized_query) 58 | doc_scores_all.append(doc_scores) 59 | 60 | doc_scores_all = np.array(doc_scores_all) 61 | doc_ids = np.array(doc_ids) 62 | 63 | np.save(outpath_scores, doc_scores_all) 64 | np.save(outpath_ids, doc_ids) 65 | 66 | topk_index = doc_scores_all.argsort(axis=1)[:,-k:] 67 | result = [] 68 | for idx in range(len(topk_index)): 69 | cur_list = [] 70 | for j in range(len(topk_index[idx])): 71 | doc_id_ = doc_ids[topk_index[idx][j]] 72 | doc_id_score = doc_scores_all[idx][topk_index[idx][j]] 73 | cur_list.append((doc_id_score, doc_id_)) 74 | result.append(cur_list) 75 | return result 76 | -------------------------------------------------------------------------------- /src/evaluation/dense_eval.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple 2 | from datasets import Dataset 3 | 4 | from evaluation.base_eval import BaseEval 5 | from config import Arguments 6 | from logger_config import logger 7 | 8 | 9 | class DenseEval(BaseEval): 10 | 11 | def __init__(self, args: Arguments, corpus: Dataset, **kwargs): 12 | super().__init__(args, corpus, **kwargs) 13 | 14 | input_prefix = 'query: ' if args.add_qd_prompt else '' 15 | # TODO: Hack 16 | is_e5_model = any(e5_name in args.model_name_or_path for e5_name in ['intfloat/e5', 'intfloat/multilingual-e5']) 17 | if is_e5_model and not input_prefix: 18 | logger.warning('E5 models need input prefix, set input_prefix = "query: "') 19 | input_prefix = 'query: ' 20 | 21 | from models import SimpleEncoder, SimpleRetriever 22 | encoder: SimpleEncoder = SimpleEncoder( 23 | model_name_or_path=args.model_name_or_path, 24 | l2_normalize=args.l2_normalize, 25 | prompt=input_prefix, 26 | ) 27 | cache_dir = '{}/embeddings/'.format(args.output_dir) 28 | 29 | self.retriever: SimpleRetriever = SimpleRetriever( 30 | encoder=encoder, 31 | corpus=corpus, 32 | cache_dir=cache_dir, 33 | ) 34 | 35 | def get_topk_score_doc_ids(self, queries: List[str], k: int, task_names: List[str]) -> List[List[Tuple[float, 
str]]]: 36 | assert len(queries) == len(task_names) 37 | 38 | query_idx_to_topk: Dict[int, List[Tuple]] = self.retriever.search_topk(queries=queries, top_k=k) 39 | for idx in range(len(queries)): 40 | q_task_name = task_names[idx] 41 | for j, (score, doc_id) in enumerate(query_idx_to_topk[idx]): 42 | if str(doc_id) not in self.task_name_to_doc_ids[q_task_name]: 43 | query_idx_to_topk[idx][j] = (score - 100., doc_id) 44 | query_idx_to_topk[idx] = sorted(query_idx_to_topk[idx], key=lambda x: x[0], reverse=True) 45 | 46 | topk_score_doc_ids: List[List[Tuple[float, str]]] = [] 47 | for idx in range(len(queries)): 48 | score_doc_ids: List[Tuple[float, str]] = query_idx_to_topk[idx][:k] 49 | score_doc_ids = [(score, str(doc_id)) for score, doc_id in score_doc_ids] 50 | topk_score_doc_ids.append(score_doc_ids) 51 | 52 | return topk_score_doc_ids 53 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from typing import List 4 | from rouge import Rouge 5 | from sklearn.metrics import f1_score 6 | 7 | from evaluation import qa_utils 8 | from logger_config import logger 9 | 10 | 11 | @torch.no_grad() 12 | def accuracy(output: torch.tensor, target: torch.tensor, topk=(1,)) -> List[float]: 13 | """Computes the accuracy over the k top predictions for the specified values of k""" 14 | maxk = max(topk) 15 | batch_size = target.size(0) 16 | 17 | _, pred = output.topk(maxk, 1, True, True) 18 | pred = pred.t() 19 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 20 | 21 | res = [] 22 | for k in topk: 23 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 24 | res.append(correct_k.mul_(100.0 / batch_size).item()) 25 | return res 26 | 27 | 28 | @torch.no_grad() 29 | def batch_mrr(output: torch.tensor, target: torch.tensor) -> float: 30 | assert len(output.shape) == 2 31 | assert len(target.shape) == 1 32 | sorted_score, sorted_indices = torch.sort(output, dim=-1, descending=True) 33 | _, rank = torch.nonzero(sorted_indices.eq(target.unsqueeze(-1)).long(), as_tuple=True) 34 | assert rank.shape[0] == output.shape[0] 35 | 36 | rank = rank + 1 37 | mrr = torch.sum(100 / rank.float()) / rank.shape[0] 38 | return mrr.item() 39 | 40 | 41 | ## =========================================================================== ## 42 | 43 | # Copy from https://github.com/microsoft/LMOps/tree/main/uprise/src/utils 44 | 45 | 46 | def rouge(preds, labels): 47 | # https://github.com/pltrdy/rouge 48 | r1s, r2s, rls = [], [], [] 49 | r = Rouge() 50 | for i in range(len(labels)): 51 | if '\n' not in preds[i]: preds[i] += '\n' 52 | if '\n' not in labels[i]: labels[i] += '\n' # avoid empty string 53 | scores = r.get_scores(preds[i], labels[i])[0] 54 | r1s.append(scores["rouge-1"]['f']) 55 | r2s.append(scores["rouge-2"]['f']) 56 | rls.append(scores["rouge-l"]['f']) 57 | r1 = sum(r1s) / len(r1s) 58 | r2 = sum(r2s) / len(r2s) 59 | rl = sum(rls) / len(rls) 60 | return r1, r2, rl 61 | 62 | 63 | def squad(labels, preds): 64 | """Computes SQuAD metrics, maximizing over answers per question. 
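    Each prediction is scored against every reference answer for its question, and the best
    exact-match and token-level F1 values are kept.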
65 | Args: 66 | labels: list of lists of strings 67 | preds: list of strings 68 | Returns: 69 | dict with score_key: squad score across all labels and predictions 70 | """ 71 | labels = [[qa_utils.normalize_squad(t) for t in u] for u in labels] 72 | preds = [qa_utils.normalize_squad(p) for p in preds] 73 | em, f1 = qa_utils.qa_metrics(labels, preds) # em,f1 74 | return em, f1 75 | 76 | 77 | def trivia_qa(labels, preds): 78 | """Computes TriviaQA metrics, maximizing over answers per question. 79 | Args: 80 | labels: list of lists of strings 81 | preds: list of strings 82 | Returns: 83 | dict with score_key: squad score across all labels and preds 84 | """ 85 | labels = [[qa_utils.normalize_trivia_qa(t) for t in u] for u in labels] 86 | preds = [qa_utils.normalize_trivia_qa(p) for p in preds] 87 | em, f1 = qa_utils.qa_metrics(labels, preds) # em,f1 88 | return em, f1 89 | 90 | 91 | def simple_accuracy(preds, labels): 92 | if isinstance(preds[0], str): 93 | labels = [label.lower().strip() for label in labels] 94 | preds = [pred.lower().strip() for pred in preds] 95 | res = [int(preds[i] == labels[i]) for i in range(len(preds))] 96 | acc = sum(res) / len(res) 97 | return acc 98 | 99 | 100 | def acc_and_f1(preds, labels): 101 | acc = simple_accuracy(preds, labels) 102 | # Currently only MRPC & QQP use this metric 103 | f1 = f1_score(y_true=labels, y_pred=preds, pos_label='Yes') 104 | return acc, f1, (acc + f1) / 2 105 | 106 | 107 | def compute_metrics(metric, labels, preds): 108 | assert len(preds) == len(labels) 109 | if metric == 'simple_accuracy': 110 | return {'acc': simple_accuracy(preds, labels) * 100} 111 | elif metric == 'rouge': 112 | r1, r2, rl = rouge(preds, labels) 113 | return {'r1': r1 * 100, 'r2': r2 * 100, 'rl': rl * 100} 114 | elif metric == 'acc_and_f1': 115 | acc, f1, acc_f1 = acc_and_f1(preds, labels) 116 | return {'acc': acc * 100, 'f1': f1 * 100, 'acc_and_f1': acc_f1 * 100} 117 | elif metric == 'f1': 118 | f1 = f1_score(y_true=labels, y_pred=preds, pos_label='Yes') 119 | return {'f1': f1 * 100} 120 | elif metric == 'squad': 121 | em, f1 = squad(labels=labels, preds=preds) 122 | return {'em': em, 'f1': f1} 123 | elif metric == 'trivia_qa': 124 | em, f1 = trivia_qa(labels=labels, preds=preds) 125 | return {'em': em, 'f1': f1} 126 | else: 127 | raise ValueError(metric) 128 | -------------------------------------------------------------------------------- /src/evaluation/qa_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5 Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for Question Answering (QA) evaluation. 16 | 17 | Matches results on the SQuAD (v1.1) and TriviaQA (v1.0) evaluation scripts. 
18 | """ 19 | 20 | import collections 21 | import re 22 | import string 23 | 24 | from absl import logging 25 | import numpy as np 26 | 27 | 28 | def _normalize_answer(text, punc_chars, punc_repl): 29 | """Lower text and remove punctuation, articles and extra whitespace.""" 30 | 31 | def remove_articles(s): 32 | return re.sub(r"\b(a|an|the)\b", " ", s) 33 | 34 | def replace_punctuation(s): 35 | to_replace = set(punc_chars) 36 | return "".join(punc_repl if ch in to_replace else ch for ch in s) 37 | 38 | def white_space_fix(s): 39 | return " ".join(s.split()) 40 | 41 | text = text.lower() 42 | text = replace_punctuation(text) 43 | text = remove_articles(text) 44 | text = white_space_fix(text) 45 | 46 | return text 47 | 48 | 49 | def normalize_trivia_qa(answer): 50 | """Normalization used in official TriviaQA evaluation script.""" 51 | return _normalize_answer( 52 | answer, punc_chars=string.punctuation + "‘’´`_", punc_repl=" ").strip() 53 | 54 | 55 | def normalize_squad(answer): 56 | """Normalization used in official SQuAD evaluation script.""" 57 | return _normalize_answer(answer, punc_chars=string.punctuation, punc_repl="") 58 | 59 | 60 | def _metric_max_over_ground_truths(metric_fn, ground_truths, prediction): 61 | """Computes the maximum of the metric over all ground truths.""" 62 | return max( 63 | metric_fn(ground_truth, prediction) for ground_truth in ground_truths 64 | ) 65 | 66 | 67 | def _exact_match_score(target, prediction): 68 | return target == prediction 69 | 70 | 71 | def _f1_score(target, prediction): 72 | """Computes token f1 score for a single target and prediction.""" 73 | prediction_tokens = prediction.split() 74 | target_tokens = target.split() 75 | common = (collections.Counter(prediction_tokens) & 76 | collections.Counter(target_tokens)) 77 | num_same = sum(common.values()) 78 | if num_same == 0: 79 | return 0 80 | precision = 1.0 * num_same / len(prediction_tokens) 81 | recall = 1.0 * num_same / len(target_tokens) 82 | f1 = (2 * precision * recall) / (precision + recall) 83 | return f1 84 | 85 | def qa_metrics(targets, predictions): 86 | """Computes exact match and f1 QA scores, expecting pre-normalized text.""" 87 | if len(targets) != len(predictions): 88 | raise ValueError("Number of targets and predictions must match.") 89 | em = np.mean([ 90 | _metric_max_over_ground_truths(_exact_match_score, t, p) 91 | for p, t in zip(predictions, targets) 92 | ]) 93 | f1 = np.mean([ 94 | _metric_max_over_ground_truths(_f1_score, t, p) 95 | for p, t in zip(predictions, targets) 96 | ]) 97 | em *= 100 98 | f1 *= 100 99 | logging.info("EM = %.2f, F1 = %.2f", em, f1) 100 | return em, f1 -------------------------------------------------------------------------------- /src/evaluation/random_eval.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from typing import List, Tuple, Dict 4 | from datasets import Dataset 5 | 6 | from evaluation.base_eval import BaseEval 7 | from config import Arguments 8 | from logger_config import logger 9 | 10 | 11 | # This class randomly selects k-shot examples from the training set 12 | class RandomEval(BaseEval): 13 | 14 | def __init__(self, args: Arguments, corpus: Dataset, **kwargs): 15 | super().__init__(args, corpus, **kwargs) 16 | self.all_doc_ids: List[str] = self.corpus['id'] 17 | self.cached_task_name_to_doc_ids: Dict[str, List[str]] = {} 18 | 19 | def get_topk_score_doc_ids(self, queries: List[str], k: int, task_names: List[str]) -> List[List[Tuple[float, str]]]: 20 | assert 
len(queries) == len(task_names) 21 | 22 | topk_score_doc_ids: List[List[Tuple[float, str]]] = [] 23 | for query, task_name in zip(queries, task_names): 24 | random_score_doc_ids: List[str] = self._single_get_topk_doc_ids(query, k, task_name) 25 | topk_score_doc_ids.append([(-1, doc_id) for doc_id in random_score_doc_ids]) 26 | 27 | return topk_score_doc_ids 28 | 29 | def _single_get_topk_doc_ids(self, query: str, k: int, task_name: str) -> List[str]: 30 | if task_name not in self.cached_task_name_to_doc_ids: 31 | self.cached_task_name_to_doc_ids[task_name] = self.get_doc_ids_by_task_name(task_name) 32 | doc_ids: List[str] = self.cached_task_name_to_doc_ids[task_name] 33 | # mnli_m & mnli_mm should retrieve from mnli training set 34 | if len(doc_ids) == 0 and task_name.startswith('mnli_'): 35 | if 'mnli' not in self.cached_task_name_to_doc_ids: 36 | self.cached_task_name_to_doc_ids['mnli'] = self.get_doc_ids_by_task_name('mnli') 37 | doc_ids = self.cached_task_name_to_doc_ids['mnli'] 38 | 39 | if len(doc_ids) == 0: 40 | logger.warning('Use the whole training set for task: {}'.format(task_name)) 41 | doc_ids = self.all_doc_ids 42 | 43 | if k >= len(doc_ids): 44 | logger.warning('k ({}) is larger than the number of examples ({})'.format(k, len(doc_ids))) 45 | k = min(k, len(doc_ids)) 46 | 47 | return random.sample(doc_ids, k) 48 | -------------------------------------------------------------------------------- /src/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kangmintong/C-RAG/5bda4382bebf6453b3b8dccb8a52f4d015e95fee/src/inference/__init__.py -------------------------------------------------------------------------------- /src/inference/gen_reward_scores.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import tqdm 4 | import torch 5 | 6 | from contextlib import nullcontext 7 | from torch.utils.data import DataLoader 8 | from functools import partial 9 | from datasets import Dataset, load_dataset 10 | from typing import Dict, List, Tuple 11 | from transformers.modeling_outputs import SequenceClassifierOutput 12 | from transformers import ( 13 | AutoTokenizer, 14 | PreTrainedTokenizerFast, 15 | DataCollatorWithPadding, 16 | HfArgumentParser, 17 | ) 18 | 19 | from config import Arguments 20 | from logger_config import logger 21 | from utils import move_to_device, save_dataset, wait_until_all_files_show_up 22 | from models import RerankerForInference 23 | from data_utils import load_corpus, save_to_readable_format 24 | from inference.inference_utils import reward_transform_func 25 | 26 | parser = HfArgumentParser((Arguments,)) 27 | args: Arguments = parser.parse_args_into_dataclasses()[0] 28 | kd_gen_score_in_path = os.path.join(args.data_dir, '{}.jsonl.gz'.format(args.kd_gen_score_split)) 29 | kd_gen_score_out_path = os.path.join(args.output_dir, 'kd_{}.jsonl.gz'.format(args.kd_gen_score_split)) 30 | 31 | 32 | def _get_shard_path(worker_idx: int) -> str: 33 | basename = os.path.basename(kd_gen_score_in_path) 34 | return '{}/shard_{}_{}'.format(args.output_dir, worker_idx, basename) 35 | 36 | 37 | @torch.no_grad() 38 | def _worker_gen_teacher_score(): 39 | gpu_idx: int = args.process_index 40 | dataset = load_dataset('json', data_files=kd_gen_score_in_path)['train'] 41 | if args.dry_run: 42 | dataset = dataset.select(range(100)) 43 | dataset = dataset.shard(num_shards=args.world_size, 44 | index=gpu_idx, 45 | contiguous=True) 46 
| 47 | qid_pids = [] 48 | for ex in tqdm.tqdm(dataset, desc='get qid-pid pairs', mininterval=3): 49 | for doc_id in ex['doc_ids']: 50 | qid_pids.append((ex['query_id'], doc_id, ex['query'], ex['answers'], ex['options'])) 51 | 52 | inference_dataset = Dataset.from_dict({ 53 | 'query_id': [t[0] for t in qid_pids], 54 | 'doc_id': [t[1] for t in qid_pids], 55 | 'query': [t[2] for t in qid_pids], 56 | 'answers': [t[3] for t in qid_pids], 57 | 'options': [t[4] for t in qid_pids], 58 | }) 59 | 60 | query_ids, doc_ids = inference_dataset['query_id'], inference_dataset['doc_id'] 61 | 62 | logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(inference_dataset))) 63 | 64 | tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path) 65 | model: RerankerForInference = RerankerForInference.from_pretrained(args.model_name_or_path) 66 | model.eval() 67 | model.to(gpu_idx) 68 | 69 | corpus: Dataset = load_corpus(path=os.path.join(args.data_dir, 'passages.jsonl.gz')) 70 | inference_dataset.set_transform(partial(reward_transform_func, tokenizer, args.reward_max_length, corpus)) 71 | 72 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if args.fp16 else None) 73 | data_loader = DataLoader( 74 | inference_dataset, 75 | batch_size=args.kd_gen_score_batch_size, 76 | shuffle=False, 77 | drop_last=False, 78 | num_workers=args.dataloader_num_workers, 79 | collate_fn=data_collator, 80 | pin_memory=True) 81 | 82 | scores = [] 83 | for batch_dict in tqdm.tqdm(data_loader, desc='generate teacher score', mininterval=5): 84 | batch_dict = move_to_device(batch_dict, device=gpu_idx) 85 | 86 | with torch.cuda.amp.autocast() if args.fp16 else nullcontext(): 87 | outputs: SequenceClassifierOutput = model(batch_dict) 88 | scores.append(outputs.logits.squeeze(dim=-1).cpu()) 89 | assert len(scores[-1].shape) == 1 90 | 91 | all_scores = torch.cat(scores, dim=-1) 92 | assert all_scores.shape[0] == len(inference_dataset), '{} != {}'.format(all_scores.shape[0], len(inference_dataset)) 93 | all_scores = all_scores.tolist() 94 | 95 | query_id_to_doc_id_scores: Dict[str, List[Tuple[str, float]]] = collections.defaultdict(list) 96 | for idx in range(len(query_ids)): 97 | query_id_to_doc_id_scores[query_ids[idx]].append((doc_ids[idx], round(all_scores[idx], 5))) 98 | 99 | def _update_score(ex: Dict) -> Dict: 100 | query_id = ex['query_id'] 101 | ex['doc_ids'] = [t[0] for t in query_id_to_doc_id_scores[query_id]] 102 | ex['doc_scores'] = [t[1] for t in query_id_to_doc_id_scores[query_id]] 103 | return ex 104 | 105 | dataset = dataset.map(_update_score, batched=False) 106 | save_dataset(dataset, _get_shard_path(gpu_idx)) 107 | 108 | logger.info('Done computing teacher score for worker {}'.format(gpu_idx)) 109 | 110 | 111 | def _merge_teacher_scores(): 112 | wait_until_all_files_show_up([_get_shard_path(worker_idx) for worker_idx in range(args.world_size)]) 113 | 114 | dataset = load_dataset( 115 | 'json', data_files=[_get_shard_path(worker_idx) for worker_idx in range(args.world_size)], split='train' 116 | ) 117 | 118 | save_dataset(dataset, kd_gen_score_out_path) 119 | logger.info('Writing teacher score to {}'.format(kd_gen_score_out_path)) 120 | 121 | logger.info('Done merge results') 122 | 123 | corpus = load_corpus(path=os.path.join(args.data_dir, 'passages.jsonl.gz')) 124 | save_to_readable_format(in_path=kd_gen_score_out_path, corpus=corpus, shuffle=True) 125 | 126 | for worker_idx in range(args.world_size): 127 | os.remove(_get_shard_path(worker_idx)) 128 | 129 
| 130 | def main(): 131 | logger.info('Args={}'.format(str(args))) 132 | if os.path.exists(kd_gen_score_out_path): 133 | logger.info('Found {}, skip'.format(kd_gen_score_out_path)) 134 | return 135 | 136 | logger.info('Use {} workers'.format(args.world_size)) 137 | _worker_gen_teacher_score() 138 | logger.info('Done batch generate teacher score') 139 | 140 | if args.process_index <= 0: 141 | _merge_teacher_scores() 142 | 143 | 144 | if __name__ == '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /src/inference/generate_few_shot_prompt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, 'src/') 4 | 5 | from datasets import Dataset, load_dataset, DownloadMode, concatenate_datasets 6 | from typing import List, Tuple 7 | from transformers import HfArgumentParser 8 | 9 | from config import Arguments 10 | from logger_config import logger 11 | from utils import save_dataset 12 | from evaluation import BaseEval 13 | from model_utils import build_eval_model 14 | from inference.inference_utils import get_prompt_save_path 15 | 16 | parser = HfArgumentParser((Arguments,)) 17 | args: Arguments = parser.parse_args_into_dataclasses()[0] 18 | 19 | 20 | def main(): 21 | out_path: str = get_prompt_save_path(args=args) 22 | 23 | if args.llm_k_shot==0: 24 | args.llm_k_shot=1 25 | 26 | if os.path.exists(out_path): 27 | logger.info('Prompt file {} exists. Skip.'.format(out_path)) 28 | return 29 | 30 | corpus: Dataset = load_dataset( 31 | 'json', data_files='{}/passages.jsonl.gz'.format(args.data_dir), split='train', 32 | download_mode=DownloadMode.FORCE_REDOWNLOAD 33 | ) 34 | # columns: query_id / query / answers / task_name 35 | eval_dataset: Dataset = load_dataset( 36 | 'json', data_files='{}/{}.jsonl.gz'.format(args.data_dir, args.llm_eval_split), split='train', 37 | download_mode=DownloadMode.FORCE_REDOWNLOAD 38 | ) 39 | 40 | 41 | if not args.llm_eval_tasks or args.llm_eval_tasks[0] == 'all': 42 | args.llm_eval_tasks = sorted(eval_dataset.unique('task_name')) 43 | logger.info('Eval all {} tasks'.format(len(args.llm_eval_tasks))) 44 | 45 | logger.info(args.llm_eval_tasks) 46 | 47 | model: BaseEval = build_eval_model(args=args, corpus=corpus) 48 | 49 | task_ds_list: List[Dataset] = [] 50 | for task_name in args.llm_eval_tasks: 51 | task_ds: Dataset = eval_dataset.filter(lambda x: x['task_name'] == task_name) 52 | 53 | if len(task_ds) > args.max_test_samples: 54 | logger.info('Task: {}, random sample {}/{} for evaluation'.format( 55 | task_name, args.max_test_samples, len(task_ds)) 56 | ) 57 | task_ds = task_ds.shuffle(seed=args.seed).select(range(args.max_test_samples)) 58 | logger.info('Task: {}, {} samples for evaluation'.format(task_name, len(task_ds))) 59 | 60 | if args.llm_k_shot <= 0: 61 | task_ds = task_ds.add_column('input_prompt', ['' for _ in range(len(task_ds))]) 62 | task_ds_list.append(task_ds) 63 | continue 64 | 65 | queries: List[str] = task_ds['query'] 66 | # Use a larger k in case we retrieve the docs from other tasks 67 | logger.info(f'args.llm_k_shot: {args.llm_k_shot}') 68 | topk_score_doc_ids: List[List[Tuple[float, str]]] = model.get_topk_score_doc_ids( 69 | queries, k=5 * args.llm_k_shot, task_names=task_ds['task_name'] 70 | ) 71 | # The most relevant doc should be close to the test example 72 | topk_score_doc_ids = [score_doc_ids[:args.llm_k_shot][::-1] for score_doc_ids in topk_score_doc_ids] 73 | topk_doc_ids: List[List[str]] = [ 74 | 
[doc_id for _, doc_id in score_doc_ids] for score_doc_ids in topk_score_doc_ids 75 | ] 76 | topk_scores: List[List[float]] = [ 77 | [round(score, 4) for score, _ in score_doc_ids] for score_doc_ids in topk_score_doc_ids 78 | ] 79 | input_prompts: List[str] = [model.get_prompt_by_doc_ids(doc_ids) for doc_ids in topk_doc_ids] 80 | task_ds = task_ds.add_column('input_prompt', input_prompts) 81 | task_ds = task_ds.add_column('topk_doc_ids', topk_doc_ids) 82 | task_ds = task_ds.add_column('topk_scores', topk_scores) 83 | task_ds_list.append(task_ds) 84 | 85 | few_shot_ds: Dataset = concatenate_datasets(task_ds_list) 86 | save_dataset(few_shot_ds, out_path) 87 | logger.info('Save {} examples to {}'.format(len(few_shot_ds), out_path)) 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /src/inference/inference_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from typing import Dict, List 4 | from datasets import Dataset 5 | from transformers.file_utils import PaddingStrategy 6 | from transformers import PreTrainedTokenizerFast, BatchEncoding 7 | 8 | from config import Arguments 9 | 10 | 11 | def get_prompt_save_path(args: Arguments) -> str: 12 | from model_utils import parse_model_id 13 | model_id: str = parse_model_id(args.model_name_or_path) 14 | out_path: str = '{}/{}_{}_k{}.jsonl.gz'.format( 15 | args.output_dir, model_id, args.llm_eval_split, args.llm_k_shot 16 | ) 17 | 18 | return out_path 19 | 20 | 21 | def reward_transform_func( 22 | tokenizer: PreTrainedTokenizerFast, 23 | reward_max_length: int, 24 | corpus: Dataset, 25 | examples: Dict[str, List]) -> BatchEncoding: 26 | input_docs: List[str] = [] 27 | 28 | # ATTENTION: this code should be consistent with RerankDataLoader 29 | for doc_id in examples['doc_id']: 30 | doc_id = int(doc_id) 31 | input_docs.append(corpus[doc_id]['contents'].strip()) 32 | 33 | input_queries = [] 34 | for query, answers, options in zip(examples['query'], examples['answers'], examples['options']): 35 | current_query = query 36 | if len(options) > 1: 37 | current_query += '\n' + options[ord(answers[0]) - ord('A')] 38 | else: 39 | current_query += '\n' + random.choice(answers) 40 | input_queries.append(current_query) 41 | 42 | batch_dict = tokenizer(input_queries, 43 | text_pair=input_docs, 44 | max_length=reward_max_length, 45 | padding=PaddingStrategy.DO_NOT_PAD, 46 | return_token_type_ids=False, 47 | truncation=True) 48 | 49 | return batch_dict 50 | -------------------------------------------------------------------------------- /src/inference/search_topk.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | sys.path.insert(0, 'src/') 5 | 6 | from datasets import Dataset, load_dataset, DownloadMode 7 | from typing import Dict, List, Tuple 8 | from transformers import HfArgumentParser 9 | 10 | from config import Arguments 11 | from logger_config import logger 12 | from utils import save_dataset 13 | from data_utils import save_to_readable_format 14 | from evaluation import BaseEval 15 | from model_utils import build_eval_model, parse_model_id 16 | 17 | parser = HfArgumentParser((Arguments,)) 18 | args: Arguments = parser.parse_args_into_dataclasses()[0] 19 | assert args.do_search, 'This script is only for search mode.' 
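# This script retrieves the top-k scored candidate demonstrations for every example in the
# chosen split, filters out the example itself from its own candidate list, and saves the
# resulting doc_ids / doc_scores; these scored candidates are later split into positives and
# negatives by to_positive_negative_format in data_utils.py.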
20 | 21 | 22 | def main(): 23 | model_id = parse_model_id(args.model_name_or_path) 24 | out_path: str = '{}/{}_{}.jsonl.gz'.format(args.output_dir, model_id, args.search_split) 25 | if os.path.exists(out_path): 26 | logger.info('Output file {} already exists. Skip.'.format(out_path)) 27 | return 28 | 29 | data_path: str = '{}/{}.jsonl.gz'.format(args.data_dir, args.search_split) 30 | assert os.path.exists(data_path), 'Data file {} does not exist.'.format(data_path) 31 | dataset: Dataset = load_dataset( 32 | 'json', data_files=data_path, split='train', download_mode=DownloadMode.FORCE_REDOWNLOAD 33 | ) 34 | if args.dry_run: 35 | dataset = dataset.shuffle(seed=args.seed).select(range(100)) 36 | logger.info('Load {} examples from {}'.format(len(dataset), data_path)) 37 | 38 | corpus_path: str = '{}/passages.jsonl.gz'.format(args.data_dir) 39 | corpus: Dataset = load_dataset( 40 | 'json', data_files=corpus_path, split='train', download_mode=DownloadMode.FORCE_REDOWNLOAD 41 | ) 42 | if args.dry_run: 43 | corpus = corpus.select(range(4096)) 44 | logger.info('Load {} candidates from {}'.format(len(corpus), corpus_path)) 45 | 46 | retriever: BaseEval = build_eval_model(args=args, corpus=corpus) 47 | 48 | logger.info('Search top {} candidates for each example.'.format(args.search_topk)) 49 | topk_score_doc_ids: List[List[Tuple[float, str]]] = retriever.get_topk_score_doc_ids( 50 | dataset['query'], k=args.search_topk, task_names=dataset['task_name'] 51 | ) 52 | all_contents: List[str] = corpus['contents'] 53 | 54 | def _map_func(example: Dict, idx: int) -> Dict: 55 | score_doc_ids: List[Tuple[float, str]] = topk_score_doc_ids[idx] 56 | # Exclude the example itself from the top-k candidates. 57 | score_doc_ids = [t for t in score_doc_ids if not all_contents[int(t[1])].startswith(example['query'])] 58 | np.random.shuffle(score_doc_ids) 59 | return { 60 | 'doc_ids': [doc_id for _, doc_id in score_doc_ids], 61 | 'doc_scores': [round(doc_score, 4) for doc_score, _ in score_doc_ids], 62 | } 63 | 64 | dataset = dataset.map(_map_func, with_indices=True, num_proc=1, desc='Add top-k candidates') 65 | dataset = dataset.filter(lambda example: len(example['doc_ids']) > 1) 66 | 67 | save_dataset(dataset, out_path, shuffle='train' in args.search_split) 68 | save_to_readable_format(in_path=out_path, corpus=corpus, shuffle=True) 69 | 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /src/llms/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_llm import BaseLLM 2 | from .gpt2 import GPT2 3 | from .gpt_neo import GPTNeo 4 | from .llama import Llama 5 | -------------------------------------------------------------------------------- /src/llms/base_llm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from abc import abstractmethod 5 | from typing import List, Optional, Union 6 | 7 | 8 | class BaseLLM(nn.Module): 9 | 10 | def __init__(self, model_name_or_path: str, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.model_name_or_path = model_name_or_path 13 | 14 | @abstractmethod 15 | def batch_score(self, input_texts: List[str], output_texts: List[str], **kwargs) -> List[float]: 16 | raise NotImplementedError 17 | 18 | def score(self, input_text: str, output_text: str, **kwargs) -> float: 19 | return self.batch_score([input_text], [output_text], **kwargs)[0] 20 | 21 | 
@abstractmethod 22 | def batch_decode(self, input_texts: List[str], **kwargs) -> List[str]: 23 | raise NotImplementedError 24 | 25 | def decode(self, input_text: str, **kwargs) -> str: 26 | return self.batch_decode([input_text], **kwargs)[0] 27 | 28 | def cuda(self, device: Optional[Union[int, torch.device]] = 0): 29 | self.model.to(device) 30 | return self 31 | -------------------------------------------------------------------------------- /src/llms/gpt_neo.py: -------------------------------------------------------------------------------- 1 | from logger_config import logger 2 | from config import Arguments 3 | from llms.gpt2 import GPT2 4 | 5 | 6 | class GPTNeo(GPT2): 7 | 8 | def __init__(self, args: Arguments, model_name_or_path: str = 'EleutherAI/gpt-neo-2.7B', **kwargs): 9 | super().__init__(args, model_name_or_path, **kwargs) 10 | -------------------------------------------------------------------------------- /src/llms/llama.py: -------------------------------------------------------------------------------- 1 | from logger_config import logger 2 | from config import Arguments 3 | from llms.gpt2 import GPT2 4 | 5 | 6 | class Llama(GPT2): 7 | 8 | def __init__(self, args: Arguments, model_name_or_path: str = 'huggyllama/llama-7b', **kwargs): 9 | super().__init__(args, model_name_or_path, **kwargs) 10 | -------------------------------------------------------------------------------- /src/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .biencoder_dataloader import RetrievalDataLoader 2 | from .cross_encoder_dataloader import CrossEncoderDataLoader 3 | -------------------------------------------------------------------------------- /src/loaders/cross_encoder_dataloader.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import torch 4 | 5 | from copy import deepcopy 6 | from functools import partial 7 | from typing import Dict, List, Optional 8 | from datasets import load_dataset, Dataset 9 | from transformers.file_utils import PaddingStrategy 10 | from transformers import PreTrainedTokenizerFast, Trainer 11 | 12 | from config import Arguments 13 | from logger_config import logger 14 | from utils import get_input_files 15 | from .loader_utils import group_doc_ids, filter_invalid_examples 16 | from data_utils import to_positive_negative_format 17 | 18 | 19 | class CrossEncoderDataset(torch.utils.data.Dataset): 20 | 21 | def __init__(self, input_files: List[str], args: Arguments, 22 | tokenizer: PreTrainedTokenizerFast): 23 | self.args = args 24 | self.input_files = input_files 25 | self.negative_size = args.train_n_passages - 1 26 | assert self.negative_size > 0 27 | self.tokenizer = tokenizer 28 | corpus_path = os.path.join(os.path.dirname(self.input_files[0]), 'passages.jsonl.gz') 29 | self.corpus: Dataset = load_dataset('json', data_files=corpus_path, split='train') 30 | 31 | self.dataset: Dataset = load_dataset('json', data_files=self.input_files, split='train') 32 | with self.args.main_process_first(desc="pre-processing"): 33 | self.dataset = filter_invalid_examples(args, self.dataset) 34 | self.dataset = self.dataset.map( 35 | partial(to_positive_negative_format, 36 | topk_as_positive=args.topk_as_positive, 37 | bottomk_as_negative=args.bottomk_as_negative), 38 | load_from_cache_file=args.world_size > 1, 39 | desc='to_positive_negative_format', 40 | remove_columns=['doc_ids', 'doc_scores'] 41 | ) 42 | 43 | if self.args.max_train_samples is not 
None: 44 | self.dataset = self.dataset.select(range(self.args.max_train_samples)) 45 | # Log a few random samples from the training set: 46 | for index in random.sample(range(len(self.dataset)), 1): 47 | logger.info(f"Sample {index} of the training set: {self.dataset[index]}.") 48 | 49 | self.dataset.set_transform(self._transform_func) 50 | 51 | # use its state to decide which positives/negatives to sample 52 | self.trainer: Optional[Trainer] = None 53 | 54 | def __len__(self): 55 | return len(self.dataset) 56 | 57 | def __getitem__(self, idx): 58 | return self.dataset[idx] 59 | 60 | def _transform_func(self, examples: Dict[str, List]) -> Dict[str, List]: 61 | current_epoch = int(self.trainer.state.epoch or 0) if self.trainer is not None else 0 62 | 63 | examples = deepcopy(examples) 64 | # add some random negatives if not enough 65 | for idx in range(len(examples['query_id'])): 66 | while len(examples['negatives'][idx]['doc_id']) < self.negative_size: 67 | random_doc_id = str(random.randint(0, len(self.corpus) - 1)) 68 | examples['negatives'][idx]['doc_id'].append(random_doc_id) 69 | examples['negatives'][idx]['score'].append(-100.) 70 | 71 | input_doc_ids = group_doc_ids( 72 | examples=examples, 73 | negative_size=self.negative_size, 74 | offset=current_epoch + self.args.seed 75 | ) 76 | assert len(input_doc_ids) == len(examples['query']) * self.args.train_n_passages 77 | 78 | input_queries, input_docs = [], [] 79 | for idx, doc_id in enumerate(input_doc_ids): 80 | input_docs.append(self.corpus[doc_id]['contents'].strip()) 81 | # For reward model, the left side is the query + ground truth answer 82 | q_idx = idx // self.args.train_n_passages 83 | current_query: str = examples['query'][q_idx] 84 | answers, options = examples['answers'][q_idx], examples['options'][q_idx] 85 | if len(options) > 1: 86 | current_query += '\n' + options[ord(answers[0]) - ord('A')] 87 | # logger.info('current_query: %s', current_query) 88 | else: 89 | current_query += '\n' + random.choice(answers) 90 | input_queries.append(current_query) 91 | 92 | batch_dict = self.tokenizer(input_queries, 93 | text_pair=input_docs, 94 | max_length=self.args.reward_max_length, 95 | return_token_type_ids=False, 96 | padding=PaddingStrategy.DO_NOT_PAD, 97 | truncation=True) 98 | 99 | packed_batch_dict = {} 100 | for k in batch_dict: 101 | packed_batch_dict[k] = [] 102 | assert len(examples['query']) * self.args.train_n_passages == len(batch_dict[k]) 103 | for idx in range(len(examples['query'])): 104 | start = idx * self.args.train_n_passages 105 | packed_batch_dict[k].append(batch_dict[k][start:(start + self.args.train_n_passages)]) 106 | 107 | return packed_batch_dict 108 | 109 | 110 | class CrossEncoderDataLoader: 111 | 112 | def __init__(self, args: Arguments, tokenizer: PreTrainedTokenizerFast): 113 | self.args = args 114 | self.tokenizer = tokenizer 115 | self.train_dataset = self._get_transformed_datasets() 116 | 117 | def set_trainer(self, trainer: Trainer): 118 | if self.train_dataset is not None: 119 | self.train_dataset.trainer = trainer 120 | 121 | def _get_transformed_datasets(self) -> CrossEncoderDataset: 122 | train_dataset = None 123 | 124 | if self.args.train_file is not None: 125 | train_input_files = get_input_files(self.args.train_file) 126 | logger.info("Train files: {}".format(train_input_files)) 127 | train_dataset = CrossEncoderDataset( 128 | args=self.args, 129 | tokenizer=self.tokenizer, 130 | input_files=train_input_files, 131 | ) 132 | 133 | if self.args.do_train: 134 | assert train_dataset is not 
None, "Training requires a train dataset" 135 | 136 | return train_dataset 137 | -------------------------------------------------------------------------------- /src/loaders/loader_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | from datasets import Dataset 3 | 4 | from config import Arguments 5 | 6 | 7 | def _slice_with_mod(elements: List, offset: int, cnt: int) -> List: 8 | return [elements[(offset + idx) % len(elements)] for idx in range(cnt)] 9 | 10 | 11 | def filter_invalid_examples(args: Arguments, dataset: Dataset) -> Dataset: 12 | def _filter_func(example: Dict) -> bool: 13 | if len(example['doc_ids']) <= args.topk_as_positive: 14 | return False 15 | if example['task_name'] in args.held_out_tasks: 16 | return False 17 | 18 | sorted_doc_scores = sorted(example['doc_scores'], reverse=True) 19 | if sorted_doc_scores[args.topk_as_positive - 1] <= -100.: 20 | return False 21 | 22 | return True 23 | 24 | return dataset.filter( 25 | _filter_func, 26 | load_from_cache_file=args.world_size > 1, 27 | ) 28 | 29 | 30 | def group_doc_ids(examples: Dict[str, List], 31 | negative_size: int, 32 | offset: int) -> List[int]: 33 | pos_doc_ids: List[int] = [] 34 | positives: List[Dict[str, List]] = examples['positives'] 35 | for idx, ex_pos in enumerate(positives): 36 | all_pos_doc_ids = ex_pos['doc_id'] 37 | cur_pos_doc_id = _slice_with_mod(all_pos_doc_ids, offset=offset, cnt=1)[0] 38 | pos_doc_ids.append(int(cur_pos_doc_id)) 39 | 40 | neg_doc_ids: List[List[int]] = [] 41 | negatives: List[Dict[str, List]] = examples['negatives'] 42 | for ex_neg in negatives: 43 | cur_neg_doc_ids = _slice_with_mod(ex_neg['doc_id'], 44 | offset=offset * negative_size, 45 | cnt=negative_size) 46 | cur_neg_doc_ids = [int(doc_id) for doc_id in cur_neg_doc_ids] 47 | neg_doc_ids.append(cur_neg_doc_ids) 48 | 49 | assert len(pos_doc_ids) == len(neg_doc_ids), '{} != {}'.format(len(pos_doc_ids), len(neg_doc_ids)) 50 | assert all(len(doc_ids) == negative_size for doc_ids in neg_doc_ids) 51 | 52 | input_doc_ids: List[int] = [] 53 | for pos_doc_id, neg_ids in zip(pos_doc_ids, neg_doc_ids): 54 | input_doc_ids.append(pos_doc_id) 55 | input_doc_ids += neg_ids 56 | 57 | return input_doc_ids 58 | -------------------------------------------------------------------------------- /src/logger_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from transformers.trainer_callback import TrainerCallback 5 | 6 | 7 | def _setup_logger(): 8 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 9 | logger = logging.getLogger() 10 | logger.setLevel(logging.INFO) 11 | # logger.setLevel(logging.DEBUG) 12 | 13 | console_handler = logging.StreamHandler() 14 | console_handler.setFormatter(log_format) 15 | 16 | data_dir = './data/' 17 | os.makedirs(data_dir, exist_ok=True) 18 | file_handler = logging.FileHandler('{}/log.txt'.format(data_dir)) 19 | file_handler.setFormatter(log_format) 20 | 21 | logger.handlers = [console_handler, file_handler] 22 | 23 | return logger 24 | 25 | 26 | logger = _setup_logger() 27 | 28 | 29 | class LoggerCallback(TrainerCallback): 30 | def on_log(self, args, state, control, logs=None, **kwargs): 31 | _ = logs.pop("total_flos", None) 32 | if state.is_world_process_zero: 33 | logger.info(logs) 34 | -------------------------------------------------------------------------------- /src/main_eval.py: 
-------------------------------------------------------------------------------- 1 | from datasets import Dataset, load_dataset, DownloadMode 2 | from transformers import HfArgumentParser 3 | 4 | from config import Arguments 5 | from logger_config import logger 6 | from llms import BaseLLM 7 | from model_utils import build_llm 8 | from data_utils import log_task_statistics 9 | from llm_evaluator import LLMEvaluator 10 | from inference.inference_utils import get_prompt_save_path 11 | 12 | parser = HfArgumentParser((Arguments,)) 13 | args: Arguments = parser.parse_args_into_dataclasses()[0] 14 | 15 | 16 | def main(): 17 | # columns: query_id / query / answers / task_name / input_prompt 18 | eval_dataset: Dataset = load_dataset( 19 | 'json', data_files=get_prompt_save_path(args), split='train', 20 | download_mode=DownloadMode.FORCE_REDOWNLOAD 21 | ) 22 | if not args.llm_eval_tasks or args.llm_eval_tasks[0] == 'all': 23 | args.llm_eval_tasks = sorted(eval_dataset.unique('task_name')) 24 | logger.info('Eval all {} tasks'.format(len(args.llm_eval_tasks))) 25 | 26 | 27 | if args.process_index <= 0: 28 | log_task_statistics(eval_dataset, split=args.llm_eval_split) 29 | logger.info('{} tasks to evaluate: {}'.format(len(args.llm_eval_tasks), args.llm_eval_tasks)) 30 | 31 | 32 | llm: BaseLLM = build_llm(args=args) 33 | llm.cuda(args.process_index) 34 | 35 | evaluator: LLMEvaluator = LLMEvaluator(args=args, llm=llm) 36 | 37 | for task_name in args.llm_eval_tasks: 38 | logger.info('Evaluating task: {}'.format(task_name)) 39 | evaluator.eval_single_task(eval_dataset, task_name) 40 | if args.dry_run: 41 | break 42 | 43 | logger.info('Done') 44 | 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /src/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from datasets import Dataset 4 | 5 | from llms import BaseLLM, GPT2, GPTNeo, Llama 6 | from evaluation import BaseEval, RandomEval, DenseEval, OpenaiEval, BM25Eval, BAAIEval 7 | from config import Arguments 8 | from logger_config import logger 9 | 10 | 11 | def build_llm(args: Arguments) -> BaseLLM: 12 | model_name_or_path: str = args.llm_model_name_or_path 13 | if 'gpt2' in model_name_or_path: 14 | if args.llm_max_input_length >= 1024: 15 | args.llm_max_input_length -= max(args.llm_max_decode_length, 128) 16 | logger.warning('GPT2 models cannot handle sequences longer than 1024. 
' 17 | 'set to {}'.format(args.llm_max_input_length)) 18 | llm = GPT2(args=args, model_name_or_path=model_name_or_path) 19 | elif 'gpt-neo' in model_name_or_path: 20 | llm = GPTNeo(args=args, model_name_or_path=model_name_or_path) 21 | elif 'llama' in model_name_or_path: 22 | llm = Llama(args=args, model_name_or_path=model_name_or_path) 23 | else: 24 | raise ValueError('Invalid model name or path: {}'.format(model_name_or_path)) 25 | 26 | return llm 27 | 28 | 29 | def build_eval_model(args: Arguments, corpus: Dataset) -> BaseEval: 30 | model_name_or_path: str = args.model_name_or_path 31 | if model_name_or_path == 'random': 32 | return RandomEval(args=args, corpus=corpus) 33 | elif model_name_or_path == 'BM25': 34 | return BM25Eval(args=args, corpus=corpus) 35 | elif model_name_or_path == 'OpenAI': 36 | return OpenaiEval(args=args, corpus=corpus) 37 | elif model_name_or_path == 'BAAI': 38 | return BAAIEval(args=args, corpus=corpus) 39 | elif 'llm-retriever-base' in model_name_or_path: 40 | # LLM-R 41 | return DenseEval(args=args, corpus=corpus) 42 | else: 43 | raise TypeError(f'retrieval method {model_name_or_path} not implemented') 44 | 45 | 46 | def parse_model_id(model_name_or_path: str) -> str: 47 | if model_name_or_path in ['random', 'BM25', 'Openai']: 48 | return model_name_or_path 49 | return os.path.basename(model_name_or_path.strip('/'))[-12:] 50 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .biencoder_model import BiencoderModel, BiencoderOutput 2 | from .cross_encoder_model import Reranker, RerankerForInference 3 | from .simple_encoder import SimpleEncoder 4 | from .simple_retriever import SimpleRetriever 5 | -------------------------------------------------------------------------------- /src/models/biencoder_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from dataclasses import dataclass 7 | from typing import Optional, Dict, Tuple 8 | from torch import Tensor 9 | from transformers import PreTrainedModel, AutoModel 10 | from transformers.modeling_outputs import ModelOutput 11 | 12 | from config import Arguments 13 | from logger_config import logger 14 | from utils import dist_gather_tensor, select_grouped_indices, full_contrastive_scores_and_labels, pool 15 | 16 | 17 | @dataclass 18 | class BiencoderOutput(ModelOutput): 19 | q_reps: Optional[Tensor] = None 20 | p_reps: Optional[Tensor] = None 21 | loss: Optional[Tensor] = None 22 | labels: Optional[Tensor] = None 23 | scores: Optional[Tensor] = None 24 | 25 | 26 | class BiencoderModel(nn.Module): 27 | def __init__(self, args: Arguments, 28 | lm_q: PreTrainedModel, 29 | lm_p: PreTrainedModel): 30 | super().__init__() 31 | self.lm_q = lm_q 32 | self.lm_p = lm_p 33 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') 34 | self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) 35 | self.args = args 36 | 37 | from trainers import BiencoderTrainer 38 | self.trainer: Optional[BiencoderTrainer] = None 39 | 40 | self._freeze_position_embedding_if_needed(self.lm_q) 41 | self._freeze_position_embedding_if_needed(self.lm_p) 42 | 43 | def forward(self, batch_dict: Dict[str, Tensor]) -> BiencoderOutput: 44 | assert self.args.process_index >= 0 45 | 46 | scores, labels, q_reps, p_reps, all_scores, all_labels = 
self._compute_scores(batch_dict) 47 | 48 | start = self.args.process_index * q_reps.shape[0] 49 | group_indices = select_grouped_indices(scores=scores, 50 | group_size=self.args.train_n_passages, 51 | start=start * self.args.train_n_passages) 52 | 53 | if not self.args.do_kd_biencoder: 54 | # training biencoder from scratch 55 | loss = self.cross_entropy(scores, labels) 56 | else: 57 | # training biencoder with kd 58 | # batch_size x train_n_passage 59 | group_scores = torch.gather(input=scores, dim=1, index=group_indices) 60 | assert group_scores.shape[1] == self.args.train_n_passages 61 | group_log_scores = torch.log_softmax(group_scores, dim=-1) 62 | kd_log_target = torch.log_softmax(batch_dict['kd_labels'], dim=-1) 63 | 64 | kd_loss = self.kl_loss_fn(input=group_log_scores, target=kd_log_target) 65 | ce_loss = self.cross_entropy(scores, labels) 66 | loss = self.args.kd_cont_loss_weight * ce_loss + kd_loss 67 | 68 | total_n_psg = self.args.world_size * q_reps.shape[0] * self.args.train_n_passages 69 | 70 | return BiencoderOutput(loss=loss, q_reps=q_reps, p_reps=p_reps, 71 | labels=labels.contiguous(), 72 | scores=scores[:, :total_n_psg].contiguous()) 73 | 74 | def _compute_scores(self, batch_dict: Dict[str, Tensor]) -> Tuple: 75 | embeds = self._encode(self.lm_p, batch_dict) 76 | batch_size = batch_dict['input_ids'].shape[0] // (self.args.train_n_passages + 1) 77 | q_reps = embeds[:batch_size] 78 | p_reps = embeds[batch_size:] 79 | assert p_reps.shape[0] == q_reps.shape[0] * self.args.train_n_passages 80 | 81 | all_q_reps = dist_gather_tensor(q_reps) 82 | all_p_reps = dist_gather_tensor(p_reps) 83 | assert all_p_reps.shape[0] == self.args.world_size * q_reps.shape[0] * self.args.train_n_passages 84 | 85 | all_scores, all_labels = full_contrastive_scores_and_labels( 86 | query=all_q_reps, key=all_p_reps, 87 | use_all_pairs=self.args.full_contrastive_loss 88 | ) 89 | if self.args.l2_normalize: 90 | all_scores = all_scores / self.args.t 91 | 92 | start = self.args.process_index * q_reps.shape[0] 93 | local_query_indices = torch.arange(start, start + q_reps.shape[0], dtype=torch.long).to(q_reps.device) 94 | # batch_size x (world_size x batch_size x train_n_passage) 95 | scores = all_scores.index_select(dim=0, index=local_query_indices) 96 | labels = all_labels.index_select(dim=0, index=local_query_indices) 97 | 98 | return scores, labels, q_reps, p_reps, all_scores, all_labels 99 | 100 | def _encode(self, encoder: PreTrainedModel, input_dict: dict) -> Optional[torch.Tensor]: 101 | if not input_dict: 102 | return None 103 | outputs = encoder(**{k: v for k, v in input_dict.items() if k not in ['labels', 'kd_labels']}, return_dict=True) 104 | embeds = pool(last_hidden_states=outputs.last_hidden_state, 105 | attention_mask=input_dict['attention_mask'], 106 | pool_type=self.args.pool_type) 107 | if self.args.l2_normalize: 108 | embeds = F.normalize(embeds, dim=-1) 109 | return embeds.contiguous() 110 | 111 | def _freeze_position_embedding_if_needed(self, model: nn.Module): 112 | if self.args.freeze_position_embedding: 113 | for name, param in model.named_parameters(): 114 | if 'position_embeddings' in name: 115 | param.requires_grad = False 116 | logger.info('Freeze {}'.format(name)) 117 | 118 | def gradient_checkpointing_enable(self): 119 | self.lm_q.gradient_checkpointing_enable() 120 | 121 | @classmethod 122 | def build(cls, args: Arguments, **hf_kwargs): 123 | if os.path.isdir(args.model_name_or_path): 124 | logger.info(f'loading shared model weight from {args.model_name_or_path}') 125 | 
lm_q = AutoModel.from_pretrained(args.model_name_or_path) 126 | lm_p = lm_q 127 | 128 | model = cls(args=args, lm_q=lm_q, lm_p=lm_p) 129 | return model 130 | 131 | def save(self, output_dir: str): 132 | self.lm_q.save_pretrained(output_dir) 133 | -------------------------------------------------------------------------------- /src/models/cross_encoder_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import Optional, Dict 5 | from transformers import ( 6 | PreTrainedModel, 7 | AutoModelForSequenceClassification 8 | ) 9 | from transformers.modeling_outputs import SequenceClassifierOutput 10 | 11 | from config import Arguments 12 | 13 | 14 | class Reranker(nn.Module): 15 | def __init__(self, hf_model: PreTrainedModel, args: Arguments): 16 | super().__init__() 17 | self.hf_model = hf_model 18 | self.args = args 19 | 20 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') 21 | 22 | def forward(self, batch: Dict[str, torch.Tensor]) -> SequenceClassifierOutput: 23 | input_batch_dict = {k: v for k, v in batch.items() if k != 'labels'} 24 | 25 | outputs: SequenceClassifierOutput = self.hf_model(**input_batch_dict, return_dict=True) 26 | outputs.logits = outputs.logits.view(-1, self.args.train_n_passages) 27 | loss = self.cross_entropy(outputs.logits, batch['labels']) 28 | outputs.loss = loss 29 | 30 | return outputs 31 | 32 | def gradient_checkpointing_enable(self): 33 | self.hf_model.gradient_checkpointing_enable() 34 | 35 | @classmethod 36 | def from_pretrained(cls, all_args: Arguments, *args, **kwargs): 37 | hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs) 38 | return cls(hf_model, all_args) 39 | 40 | def save_pretrained(self, output_dir: str): 41 | self.hf_model.save_pretrained(output_dir) 42 | 43 | 44 | class RerankerForInference(nn.Module): 45 | def __init__(self, hf_model: Optional[PreTrainedModel] = None): 46 | super().__init__() 47 | self.hf_model = hf_model 48 | self.hf_model.eval() 49 | 50 | @torch.no_grad() 51 | def forward(self, batch) -> SequenceClassifierOutput: 52 | return self.hf_model(**batch) 53 | 54 | @classmethod 55 | def from_pretrained(cls, pretrained_model_name_or_path: str): 56 | hf_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path) 57 | return cls(hf_model) 58 | -------------------------------------------------------------------------------- /src/models/simple_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import tqdm 4 | 5 | from functools import partial 6 | from torch.utils.data import DataLoader 7 | from datasets import Dataset 8 | from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, PreTrainedTokenizerFast, BatchEncoding 9 | from transformers.modeling_outputs import BaseModelOutput 10 | from typing import List, Dict 11 | 12 | from utils import pool, move_to_cuda 13 | 14 | 15 | def _transform_func(tokenizer: PreTrainedTokenizerFast, 16 | examples: Dict[str, List], 17 | prompt: str = None) -> BatchEncoding: 18 | if prompt: 19 | examples['input_texts'] = [prompt + t for t in examples['input_texts']] 20 | batch_dict = tokenizer( 21 | examples['input_texts'], 22 | max_length=256, 23 | return_token_type_ids=False, 24 | padding=True, 25 | truncation=True, 26 | ) 27 | 28 | return batch_dict 29 | 30 | 31 | class SimpleEncoder(torch.nn.Module): 32 | def __init__(self, 
model_name_or_path: str, 33 | l2_normalize: bool = True, 34 | pool_type: str = 'avg', 35 | prompt: str = 'query: '): 36 | super().__init__() 37 | self.model_name_or_path = model_name_or_path 38 | self.encoder = AutoModel.from_pretrained(model_name_or_path) 39 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 40 | self.gpu_count = torch.cuda.device_count() 41 | 42 | self.l2_normalize = l2_normalize 43 | self.pool_type = pool_type 44 | self.prompt = prompt 45 | assert self.prompt in ['', 'query: ', 'passage: '] 46 | 47 | self.encoder.eval() 48 | self.encoder.cuda() 49 | 50 | 51 | self.gpu_count = 1 52 | # if self.gpu_count > 1: 53 | # self.encoder = torch.nn.DataParallel(self.encoder) 54 | 55 | @torch.no_grad() 56 | def encode(self, sentences: List[str], **kwargs) -> torch.Tensor: 57 | dataset: Dataset = Dataset.from_dict({'input_texts': sentences}) 58 | dataset.set_transform(partial(_transform_func, self.tokenizer, prompt=self.prompt)) 59 | 60 | data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8) 61 | data_loader = DataLoader( 62 | dataset, 63 | batch_size=128 * self.gpu_count, 64 | shuffle=False, 65 | drop_last=False, 66 | num_workers=2, 67 | collate_fn=data_collator, 68 | pin_memory=True) 69 | 70 | encoded_embeds = [] 71 | for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10, disable=len(sentences) < 128): 72 | batch_dict = move_to_cuda(batch_dict) 73 | 74 | with torch.cuda.amp.autocast(): 75 | outputs: BaseModelOutput = self.encoder(**batch_dict) 76 | embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], self.pool_type) 77 | if self.l2_normalize: 78 | embeds = F.normalize(embeds, p=2, dim=-1) 79 | encoded_embeds.append(embeds.cpu()) 80 | 81 | return torch.cat(encoded_embeds, dim=0) 82 | -------------------------------------------------------------------------------- /src/models/simple_retriever.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import torch 4 | 5 | from typing import List, Dict, Union 6 | from datasets import Dataset 7 | from collections import defaultdict 8 | 9 | from models.simple_encoder import SimpleEncoder 10 | from logger_config import logger 11 | 12 | 13 | def _sharded_search_topk( 14 | query_embeds: torch.Tensor, top_k: int, 15 | shard_embed: torch.Tensor, shard_idx: int, 16 | idx_offset: int) -> Dict[int, List]: 17 | query_idx_to_topk: Dict[int, List] = defaultdict(list) 18 | search_batch_size = 256 19 | query_indices = list(range(query_embeds.shape[0])) 20 | 21 | for start in tqdm.tqdm(range(0, query_embeds.shape[0], search_batch_size), 22 | desc="search shard {}".format(shard_idx), 23 | mininterval=5): 24 | batch_query_embed = query_embeds[start:(start + search_batch_size)] 25 | batch_query_indices = query_indices[start:(start + search_batch_size)] 26 | batch_score = torch.mm(batch_query_embed, shard_embed.t()) 27 | batch_sorted_score, batch_sorted_indices = torch.topk(batch_score, k=top_k, dim=-1, largest=True) 28 | for batch_idx, query_idx in enumerate(batch_query_indices): 29 | cur_scores = batch_sorted_score[batch_idx].cpu().tolist() 30 | cur_indices = [str(idx + idx_offset) for idx in batch_sorted_indices[batch_idx].cpu().tolist()] 31 | query_idx_to_topk[query_idx] += list(zip(cur_scores, cur_indices)) 32 | query_idx_to_topk[query_idx] = sorted(query_idx_to_topk[query_idx], key=lambda t: -t[0])[:top_k] 33 | 34 | return query_idx_to_topk 35 | 36 | 37 | class SimpleRetriever: 38 | 39 | def __init__(self, 
encoder: SimpleEncoder, 40 | corpus: Union[Dataset, List[str]], 41 | cache_dir: str = None): 42 | self.encoder = encoder 43 | 44 | # Encode the "contents" column of the corpus 45 | if isinstance(corpus, List): 46 | corpus = Dataset.from_dict({'contents': corpus}) 47 | self.corpus: Dataset = corpus 48 | logger.info(f"Corpus size: {len(self.corpus)}") 49 | 50 | self.cache_dir = cache_dir or 'tmp-{}/'.format(len(corpus)) 51 | os.makedirs(self.cache_dir, exist_ok=True) 52 | logger.info(f"Cache dir: {self.cache_dir}") 53 | self.encode_shard_size = 2_000_000 54 | 55 | def search_topk(self, queries: List[str], top_k: int = 10) -> Dict[int, List]: 56 | # encode the corpus or load from cache if it already exists 57 | self._encode_corpus_if_necessary(shard_size=self.encode_shard_size) 58 | 59 | # encode the queries 60 | query_embeds = self._encode_queries(queries) 61 | if torch.cuda.is_available(): 62 | query_embeds = query_embeds.cuda() 63 | 64 | # search the top-k results 65 | query_idx_to_topk: Dict[int, List] = defaultdict(list) 66 | num_shards = (len(self.corpus) + self.encode_shard_size - 1) // self.encode_shard_size 67 | idx_offset = 0 68 | for shard_idx in range(num_shards): 69 | out_path: str = self._get_out_path(shard_idx) 70 | shard_embeds = torch.load(out_path, map_location=lambda storage, loc: storage) 71 | shard_embeds = shard_embeds.to(query_embeds.device) 72 | shard_query_idx_to_topk = _sharded_search_topk( 73 | query_embeds=query_embeds, 74 | top_k=top_k, 75 | shard_embed=shard_embeds, 76 | shard_idx=shard_idx, 77 | idx_offset=idx_offset 78 | ) 79 | for query_idx, shard_topk in shard_query_idx_to_topk.items(): 80 | query_idx_to_topk[query_idx] += shard_topk 81 | query_idx_to_topk[query_idx] = sorted(query_idx_to_topk[query_idx], key=lambda t: -t[0])[:top_k] 82 | 83 | idx_offset += shard_embeds.shape[0] 84 | 85 | return query_idx_to_topk 86 | 87 | def encode_corpus(self): 88 | self._encode_corpus_if_necessary(shard_size=self.encode_shard_size) 89 | logger.info('Done encoding corpus') 90 | 91 | def _get_out_path(self, shard_idx: int) -> str: 92 | return '{}/shard_{}'.format(self.cache_dir, shard_idx) 93 | 94 | def _encode_corpus_if_necessary(self, shard_size: int): 95 | num_shards = (len(self.corpus) + shard_size - 1) // shard_size 96 | num_examples = 0 97 | for shard_idx in range(num_shards): 98 | out_path: str = self._get_out_path(shard_idx) 99 | if os.path.exists(out_path): 100 | logger.info('{} already exists, will skip encoding'.format(out_path)) 101 | num_examples += len(torch.load(out_path, map_location=lambda storage, loc: storage)) 102 | continue 103 | shard_dataset: Dataset = self.corpus.shard( 104 | num_shards=num_shards, 105 | index=shard_idx, 106 | contiguous=True 107 | ) 108 | shard_embeds: torch.Tensor = self.encoder.encode( 109 | sentences=shard_dataset['contents'] 110 | ) 111 | 112 | num_examples += shard_embeds.shape[0] 113 | logger.info('Saving shard {} ({} examples) to {}'.format(shard_idx, len(shard_dataset), out_path)) 114 | torch.save(shard_embeds, out_path) 115 | 116 | assert num_examples == len(self.corpus), \ 117 | f"Number of examples in the corpus ({len(self.corpus)}) " \ 118 | f"does not match the number of examples in the shards ({num_examples})" 119 | 120 | def _encode_queries(self, queries: List[str]) -> torch.Tensor: 121 | return self.encoder.encode( 122 | sentences=queries 123 | ) 124 | -------------------------------------------------------------------------------- /src/tasks/__init__.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Union, Optional 2 | 3 | 4 | def to_letter(key: Union[str, int]) -> str: 5 | key = str(key).upper().strip() 6 | num_to_letter = {"0": "A", "1": "B", "2": "C", "3": "D"} 7 | assert key in num_to_letter or key in ['A', 'B', 'C', 'D'], f'Unknown answer key: {key}' 8 | return num_to_letter.get(key, key) 9 | 10 | 11 | def format_options(options: List[str]) -> str: 12 | assert len(options) <= 4, f'Number of options should be less than 4, but got {len(options)}' 13 | res = 'OPTIONS: ' 14 | for letter, option in zip(['A', 'B', 'C', 'D'], options): 15 | res += f'{letter}) {option} ' 16 | 17 | return res.strip() 18 | 19 | 20 | # Based on https://github.com/microsoft/LMOps/blob/main/uprise/DPR/dpr/utils/tasks.py 21 | 22 | TASK_TYPE_TO_TASK_NAME = { 23 | 'close_qa': ['natural_questions', 'arc_c', 'arc_e'], 24 | 'common_reason': ['copa', 'piqa', 'hellaswag'], 25 | 'coreference': ['winogrande', 'wsc', 'wsc273'], 26 | 'nli': ['rte', 'mnli', 'mnli_m', 'mnli_mm', 'qnli', 'snli'], 27 | 'paraphrase': ['mrpc', 'paws', 'qqp'], 28 | 'reading': ['multirc', 'openbookqa', 'squad_v1', 'boolq'], 29 | 'sentiment': ['yelp', 'sentiment140', 'sst2'], 30 | 'struct2text': ['common_gen', 'e2e_nlg', 'dart'], 31 | 'summarize': ['aeslc', 'ag_news', 'gigaword'], 32 | } 33 | 34 | 35 | class App: 36 | def __init__(self): 37 | self.cls_dic = {} 38 | 39 | def add(self, key): 40 | def adder(cls): 41 | self.cls_dic[key] = cls 42 | return cls 43 | 44 | return adder 45 | 46 | 47 | task_map = App() 48 | 49 | 50 | from .base_task import BaseTask 51 | from .aeslc import Aeslc 52 | from .agnews import Ag_news 53 | from .arc import Arc_c, Arc_e 54 | from .boolq import Boolq 55 | from .common_gen import Common_gen 56 | from .copa import Copa 57 | from .dart import Dart 58 | from .e2e_nlg import E2e_nlg 59 | from .gigaword import Gigaword 60 | from .hellaswag import Hellaswag 61 | from .mnli import Mnli, Mnli_m, Mnli_mm 62 | from .mrpc import Mrpc 63 | from .multirc import Multirc 64 | from .nq import Natural_questions 65 | from .openbookqa import Openbookqa 66 | from .paws import Paws 67 | from .piqa import Piqa 68 | from .qnli import Qnli 69 | from .qqp import Qqp 70 | from .rte import Rte 71 | from .sentiment140 import Sentiment140 72 | from .snli import Snli 73 | from .squad_v1 import Squad_v1 74 | from .sst2 import Sst2 75 | from .winogrande import Winogrande 76 | from .wsc import Wsc 77 | from .wsc273 import Wsc273 78 | from .yelp import Yelp 79 | 80 | 81 | def get_metric_name_by_task_name(task_name: str) -> str: 82 | assert task_name in task_map.cls_dic, f'Unknown task name: {task_name}' 83 | return task_map.cls_dic[task_name]().metric_name 84 | 85 | 86 | def get_possible_answers_by_task_name(task_name: str) -> Optional[List[str]]: 87 | assert task_name in task_map.cls_dic, f'Unknown task name: {task_name}' 88 | return task_map.cls_dic[task_name]().possible_answers 89 | 90 | 91 | def parse_decoded_text_by_task(decoded_text: str, task_name: str) -> str: 92 | # TODO: maybe add some task-specific logics here 93 | return decoded_text.strip().split('\n')[0] 94 | -------------------------------------------------------------------------------- /src/tasks/aeslc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict, Union 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | 
@task_map.add("aeslc") 9 | class Aeslc(BaseTask): 10 | 11 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 12 | split = split if split == 'train' else 'test' 13 | # For some reason, huggingface reports "checksum mismatch", so we ignore the checksum for now 14 | dataset: Dataset = load_dataset('aeslc', split=split, ignore_verifications=True) 15 | dataset = dataset.rename_column('email_body', 'body') 16 | dataset = dataset.rename_column('subject_line', 'subject') 17 | 18 | def _remove_newlines(example: Dict) -> Dict: 19 | example['body'] = ' '.join(example['body'].split()) 20 | example['subject'] = ' '.join(example['subject'].split()) 21 | return example 22 | 23 | dataset = dataset.map(_remove_newlines, desc='remove newlines') 24 | 25 | # filter logic from uprise. For FLAN, it filters out empty examples 26 | def _filter_func(example: Dict) -> bool: 27 | return 0 < len(example['body'].split()) <= 256 and 0 < len(example['subject'].split()) <= 256 28 | 29 | dataset = dataset.filter(_filter_func) 30 | 31 | return dataset 32 | 33 | @property 34 | def templates(self) -> List[Tuple[str, str]]: 35 | return [ 36 | ("What is the subject line for this email? {body}", "{subject}"), 37 | ("Write a subject line for this message: {body}", "{subject}"), 38 | ("{body} Write a subject line for this email.", "{subject}"), 39 | ("Here is an email: {body} What is a potential subject line for this email?", "{subject}"), 40 | ("{body} Propose a subject line for this email?", "{subject}"), 41 | ("This is the content of an email: {body} What was the subject line for this email?", "{subject}"), 42 | ("This is an email {body} What is the subject of this email?", "{subject}"), 43 | ("{body} Generate a subject line for this email.", "{subject}"), 44 | ] 45 | 46 | @property 47 | def possible_answers(self) -> Optional[List[str]]: 48 | return None 49 | 50 | @property 51 | def metric_name(self) -> str: 52 | return 'rouge' 53 | 54 | @property 55 | def task_name(self) -> str: 56 | return 'aeslc' 57 | 58 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 59 | return example['subject'] 60 | 61 | -------------------------------------------------------------------------------- /src/tasks/agnews.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("ag_news") 9 | class Ag_news(BaseTask): 10 | 11 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 12 | split = split if split == 'train' else 'test' 13 | dataset = load_dataset('ag_news', split=split) 14 | return dataset 15 | 16 | @property 17 | def templates(self) -> List[Tuple[str, str]]: 18 | return [ 19 | ("\"{text}\" What is this text about? World, Sports, Business, or Technology?", "{answer}"), 20 | ("\"{text}\" Which topic is this article about? World, Sports, Business, or Technology?", "{answer}"), 21 | ("\"{text}\" Which is the best summary of this article? World, Sports, Business, or Technology?", 22 | "{answer}"), 23 | ("\"{text}\" What is this text about? World, Sports, Business, or Technology?", "{answer}"), 24 | ( 25 | "\"{text}\" What best summarizes the content of the above article? World, Sports, Business, or Technology?", 26 | "{answer}"), 27 | ("Which is this about? \"{text}\" World, Sports, Business, or Technology?", "{answer}"), 28 | ("Which is an appropriate title for this article? 
\"{text}\" World, Sports, Business, or Technology?", 29 | "{answer}"), 30 | ("Select the topic that this about: \"{text}\" World, Sports, Business, or Technology?", "{answer}"), 31 | ] 32 | 33 | @property 34 | def possible_answers(self) -> Optional[List[str]]: 35 | return ['World', 'Sports', 'Business', 'Technology'] 36 | 37 | @property 38 | def metric_name(self) -> str: 39 | return 'simple_accuracy' 40 | 41 | @property 42 | def task_name(self) -> str: 43 | return 'ag_news' 44 | -------------------------------------------------------------------------------- /src/tasks/arc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict, Union 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map, to_letter 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("arc_c") 9 | class Arc_c(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'test' 12 | dataset = load_dataset('ai2_arc', 'ARC-Challenge', split=split) 13 | 14 | # Both FLAN & uprise & structured prompting have this filter logic 15 | dataset = dataset.filter(lambda ex: len(ex['choices']['text']) == 4) 16 | 17 | def _map_func(ex: Dict) -> Dict: 18 | if ex['answerKey'] not in ['A', 'B', 'C', 'D']: 19 | ex["answerKey"] = to_letter(int(ex['answerKey']) - 1) 20 | ex['options'] = ex['choices']['text'] 21 | return ex 22 | 23 | dataset = dataset.map(_map_func) 24 | 25 | return dataset 26 | 27 | @property 28 | def templates(self) -> List[Tuple[str, str]]: 29 | return [ 30 | ("{question}", "{answer}"), 31 | ("Question: {question} Answer:", "{answer}"), 32 | ("Question: {question} What is the correct answer to the question from the following choices?", "{answer}"), 33 | ("Q: {question} What is the correct answer to this question?", "{answer}"), 34 | ("What is the answer? 
{question}", "{answer}"), 35 | ("Answer the question {question}", "{answer}"), 36 | ("{question} Pick the answer from these options.", "{answer}"), 37 | ] 38 | 39 | @property 40 | def possible_answers(self) -> Optional[List[str]]: 41 | return ['A', 'B', 'C', 'D'] 42 | 43 | @property 44 | def metric_name(self) -> str: 45 | return 'simple_accuracy' 46 | 47 | @property 48 | def task_name(self) -> str: 49 | return 'arc_c' 50 | 51 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 52 | return example['answerKey'] 53 | 54 | 55 | @task_map.add("arc_e") 56 | class Arc_e(Arc_c): 57 | 58 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 59 | split = split if split == 'train' else 'test' 60 | dataset = load_dataset('ai2_arc', 'ARC-Easy', split=split) 61 | dataset = dataset.filter(lambda ex: len(ex['choices']['text']) == 4) 62 | 63 | def _map_func(ex: Dict) -> Dict: 64 | if ex['answerKey'] not in ['A', 'B', 'C', 'D']: 65 | ex["answerKey"] = to_letter(int(ex['answerKey']) - 1) 66 | ex['options'] = ex['choices']['text'] 67 | return ex 68 | 69 | dataset = dataset.map(_map_func) 70 | return dataset 71 | 72 | @property 73 | def task_name(self) -> str: 74 | return 'arc_e' 75 | -------------------------------------------------------------------------------- /src/tasks/base_task.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from typing import List, Optional, Tuple, Dict, Union 4 | from datasets import Dataset 5 | 6 | from logger_config import logger 7 | 8 | 9 | class BaseTask(object): 10 | 11 | def __init__(self, template_idx: int = 0, **kwargs): 12 | self.template_idx = template_idx 13 | 14 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 15 | raise NotImplementedError 16 | 17 | def get_task_data(self, split: str) -> Optional[Dataset]: 18 | # columns: query_id / query / answers / task_name 19 | dataset = self._load_raw_data(split) 20 | if not dataset: 21 | return None 22 | 23 | logger.info('Load dataset: {}, split: {}'.format(self.task_name, split)) 24 | dataset = dataset.map(self.map_single, num_proc=4) 25 | dataset = dataset.add_column( 26 | 'query_id', ['{}_{}_{}'.format(self.task_name, split, idx) for idx in range(len(dataset))] 27 | ) 28 | dataset = dataset.remove_columns( 29 | column_names=[col for col in dataset.column_names 30 | if col not in ['query_id', 'query', 'answers', 'options', 'task_name']] 31 | ) 32 | 33 | return dataset 34 | 35 | def get_corpus(self) -> Optional[Dataset]: 36 | # columns: contents / task_name 37 | corpus = self.get_task_data(split='train') 38 | if not corpus: 39 | return None 40 | 41 | def _map_func(example: Dict) -> Dict: 42 | answer = example['answers'][0] 43 | if len(example['options']) > 1: 44 | # multiple-choice tasks 45 | option_idx = self.possible_answers.index(answer) 46 | answer = example['options'][option_idx] 47 | return { 48 | 'contents': '\n'.join([example['query'], answer]), 49 | } 50 | 51 | corpus = corpus.map( 52 | _map_func, num_proc=4, 53 | remove_columns=[col for col in corpus.column_names if col not in ['contents', 'task_name']], 54 | desc='{} corpus'.format(self.task_name) 55 | ) 56 | 57 | return corpus 58 | 59 | def get_template(self) -> Tuple[str, str]: 60 | return self.templates[self.template_idx % len(self.templates)] 61 | 62 | def map_single(self, example: Dict) -> Dict: 63 | # ("If \"{premise}\", can we conclude that \"{hypothesis}\"? 
Yes, No, or Maybe?", "{answer}"), 64 | query_template, answer_template = self.get_template() 65 | 66 | # find the key in {key} format using regular expression 67 | query_keys = re.findall(r'\{(\w+)\}', query_template) 68 | answer_keys = re.findall(r'\{(\w+)\}', answer_template) 69 | 70 | # replace the key with the value in the example 71 | query: str = query_template.format(**{key: example[key] for key in query_keys}) 72 | assert len(answer_keys) == 1, "Only one answer key is allowed" 73 | answer_key = answer_keys[0] 74 | del answer_keys 75 | 76 | example[answer_key]: Union[str, List[str]] = self.get_answer(example) 77 | if isinstance(example[answer_key], str): 78 | answers: List[str] = [answer_template.format(**{answer_key: example[answer_key]})] 79 | elif isinstance(example[answer_key], list): 80 | answers: List[str] = [answer_template.format(**{answer_key: ans}) for ans in example[answer_key]] 81 | else: 82 | raise ValueError(f"Unknown answer type: {example[answer_key]}") 83 | 84 | return { 85 | 'query': query, 86 | 'options': example.get('options', ['']), 87 | 'answers': answers, 88 | 'task_name': self.task_name, 89 | } 90 | 91 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 92 | # Many tasks need to override this default implementation 93 | assert int(example['label']) >= 0, "label must be non-negative" 94 | return self.possible_answers[int(example['label'])] 95 | 96 | @property 97 | def templates(self) -> List[Tuple[str, str]]: 98 | raise NotImplementedError 99 | 100 | @property 101 | def possible_answers(self) -> Optional[List[str]]: 102 | raise NotImplementedError 103 | 104 | @property 105 | def metric_name(self) -> str: 106 | raise NotImplementedError 107 | 108 | @property 109 | def task_name(self) -> str: 110 | raise NotImplementedError 111 | -------------------------------------------------------------------------------- /src/tasks/boolq.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("boolq") 9 | class Boolq(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('super_glue', 'boolq', split=split) 13 | dataset = dataset.rename_column('passage', 'text') 14 | return dataset 15 | 16 | @property 17 | def templates(self) -> List[Tuple[str, str]]: 18 | return [ 19 | ("{text} Can we conclude that {question}?", "{answer}"), 20 | ("{text} Is it true that {question}?", "{answer}"), 21 | ("{text} {question}?", "{answer}"), 22 | ("Text: {text} Question: {question}?", "{answer}"), 23 | ("{text} What's the best answer to this question: {question}?", "{answer}"), 24 | ("{text} Based on the above text, what's the best answer to this question: {question}?", "{answer}"), 25 | ("{text} Answer this question, making sure that the answer is supposed by the text: {question}?", 26 | "{answer}"), 27 | ("{text} Is the following statement correct based on the text {question}", "{answer}"), 28 | ("{text} Is this statement correct \"{question}\"?", "{answer}"), 29 | ("Is it true that {question} based on the following text? 
{text}", "{answer}"), 30 | ] 31 | 32 | @property 33 | def possible_answers(self) -> Optional[List[str]]: 34 | return ['No', 'Yes'] 35 | 36 | @property 37 | def metric_name(self) -> str: 38 | return 'simple_accuracy' 39 | 40 | @property 41 | def task_name(self) -> str: 42 | return 'boolq' 43 | -------------------------------------------------------------------------------- /src/tasks/common_gen.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict, Union 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("common_gen") 9 | class Common_gen(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('common_gen', split=split) 13 | dataset = dataset.map(lambda ex: {'concepts': ", ".join(ex["concepts"])}) 14 | return dataset 15 | 16 | @property 17 | def templates(self) -> List[Tuple[str, str]]: 18 | return [ 19 | ("Concepts: {concepts}. Write a sentence that includes all these words.", "{target}"), 20 | ("Keywords: {concepts}. What is a sentence that includes all these keywords?", "{target}"), 21 | ("Here are some concepts: {concepts}. What is a sentence about these concepts?", "{target}"), 22 | ("Produce a sentence which mentions all of these concepts: {concepts}.", "{target}"), 23 | ("Write a sentence about the following things: {concepts}.", "{target}"), 24 | ("Generate a sentence that includes all the following words: {concepts}.", "{target}"), 25 | ] 26 | 27 | @property 28 | def possible_answers(self) -> Optional[List[str]]: 29 | return None 30 | 31 | @property 32 | def metric_name(self) -> str: 33 | return 'rouge' 34 | 35 | @property 36 | def task_name(self) -> str: 37 | return 'common_gen' 38 | 39 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 40 | return example['target'] 41 | -------------------------------------------------------------------------------- /src/tasks/copa.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict, Union 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map, to_letter 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("copa") 9 | class Copa(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('super_glue', 'copa', split=split) 13 | 14 | def _map_func(ex: Dict) -> Dict: 15 | ex['options'] = [ex['choice1'], ex['choice2']] 16 | return ex 17 | 18 | dataset = dataset.map(_map_func) 19 | 20 | return dataset 21 | 22 | @property 23 | def templates(self) -> List[Tuple[str, str]]: 24 | # question is either "cause" or "effect" 25 | return [ 26 | ("\"{premise}\" What is the {question}?", "{answer}"), 27 | ("Here is a premise: \"{premise}\" What is the {question}?", "{answer}"), 28 | ("\"{premise}\" What is the {question} of the preceding sentence?", "{answer}"), 29 | ("\"{premise}\" What is a plausible {question}?", "{answer}"), 30 | ("Based on the following sentence, what is the {question}? \"{premise}\"", "{answer}"), 31 | ("\"{premise}\" {question}:", "{answer}"), 32 | ("What is the {question} of the following sentence? 
\"{premise}\"", "{answer}"), 33 | ("Answer the following question about this sentence: \"{premise}\" What is the {question}?", "{answer}"), 34 | ] 35 | 36 | @property 37 | def possible_answers(self) -> Optional[List[str]]: 38 | return ['A', 'B'] 39 | 40 | @property 41 | def metric_name(self) -> str: 42 | return 'simple_accuracy' 43 | 44 | @property 45 | def task_name(self) -> str: 46 | return 'copa' 47 | 48 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 49 | return to_letter(str(example['label'])) 50 | -------------------------------------------------------------------------------- /src/tasks/dart.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from typing import Optional, List, Tuple, Dict, Union 4 | from datasets import load_dataset, Dataset 5 | 6 | from tasks import task_map 7 | from tasks.base_task import BaseTask 8 | 9 | 10 | @task_map.add("dart") 11 | class Dart(BaseTask): 12 | ''' 13 | https://huggingface.co/datasets/GEM/dart 14 | ''' 15 | 16 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 17 | split = split if split == 'train' else 'validation' 18 | dataset = load_dataset('GEM/dart', split=split) 19 | 20 | def _map_func(ex: Dict) -> Dict: 21 | tripleset = "; ".join([", ".join(triplet) for triplet in ex["tripleset"]]) 22 | # Get rid of some undesirable cells like "[TABLECONTEXT]", "[TITLE]" 23 | tripleset = re.sub(r'\[(.*?)\]', '', tripleset) 24 | return { 25 | 'tripleset': tripleset 26 | } 27 | 28 | dataset = dataset.map(_map_func) 29 | 30 | return dataset 31 | 32 | @property 33 | def templates(self) -> List[Tuple[str, str]]: 34 | return [ 35 | ("Triple: {tripleset} What is a sentence that describes this triple?", "{target}"), 36 | ("Data: {tripleset} What would a sentence about this data be like?", "{target}"), 37 | ("Generate an approximately fifteen-word sentence that describes all this data: {tripleset}", "{target}"), 38 | ("Here is some data: {tripleset}. Write a sentence that describes this data", "{target}"), 39 | ("This is some data: {tripleset}. 
Generate a detailed description of this data", "{target}"), 40 | ("Generate a sentence about this data: {tripleset}", "{target}"), 41 | ("Write a sentence that about [{tripleset}].", "{target}"), 42 | ("Produce a long descriptive sentence that uses all these words: {tripleset}", "{target}"), 43 | ] 44 | 45 | @property 46 | def possible_answers(self) -> Optional[List[str]]: 47 | return None 48 | 49 | @property 50 | def metric_name(self) -> str: 51 | return 'rouge' 52 | 53 | @property 54 | def task_name(self) -> str: 55 | return 'dart' 56 | 57 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 58 | return example['target'] 59 | -------------------------------------------------------------------------------- /src/tasks/e2e_nlg.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from typing import Optional, List, Tuple, Dict, Union 4 | from datasets import load_dataset, Dataset 5 | 6 | from tasks import task_map 7 | from tasks.base_task import BaseTask 8 | 9 | 10 | @task_map.add("e2e_nlg") 11 | class E2e_nlg(BaseTask): 12 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 13 | split = split if split == 'train' else 'test' 14 | dataset = load_dataset('GEM/e2e_nlg', split=split) 15 | 16 | def _map_func(ex: Dict) -> Dict: 17 | meaning_representation = re.sub(r'\[', ' = ', ex['meaning_representation']) 18 | meaning_representation = re.sub(r'\]', '', meaning_representation) 19 | return { 20 | 'meaning_representation': meaning_representation 21 | } 22 | 23 | dataset = dataset.map(_map_func) 24 | 25 | return dataset 26 | 27 | @property 28 | def templates(self) -> List[Tuple[str, str]]: 29 | return [ 30 | ("Attributes: {meaning_representation}. Produce a detailed sentence about this restaurant.", "{target}"), 31 | ("Data: {meaning_representation}. Can you generate a sentence about this data?", "{target}"), 32 | ("Data: {meaning_representation}. What is a sentence that describe this data?", "{target}"), 33 | ( 34 | "Here are some keywords about a restaurant: {meaning_representation}. Write a sentence that describes the following attributes of a restaurant.", 35 | "{target}"), 36 | ( 37 | "Here is some data about a restaurant: {meaning_representation}. Write a sentence that includes the following data about a restaurant.", 38 | "{target}"), 39 | ("Sentence: {meaning_representation}. 
Can you represent the content in this sentence in data form?", 40 | "{target}"), 41 | ("Write a sentence about a restaurant with all the following attributes: {meaning_representation}.", 42 | "{target}"), 43 | ("Write a sentence that is about a restaurant with all the following properties: {meaning_representation}.", 44 | "{target}"), 45 | ("Produce a detailed sentence about a restaurant using the following words: {meaning_representation}.", 46 | "{target}"), 47 | ("Generate a descriptive sentence about a restaurant using the following words: {meaning_representation}.", 48 | "{target}"), 49 | ] 50 | 51 | @property 52 | def possible_answers(self) -> Optional[List[str]]: 53 | return None 54 | 55 | @property 56 | def metric_name(self) -> str: 57 | return 'rouge' 58 | 59 | @property 60 | def task_name(self) -> str: 61 | return 'e2e_nlg' 62 | 63 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 64 | return example['target'] 65 | -------------------------------------------------------------------------------- /src/tasks/gigaword.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict, Union 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("gigaword") 9 | class Gigaword(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'test' 12 | dataset = load_dataset('gigaword', split=split) 13 | 14 | def _filter_func(ex: Dict) -> bool: 15 | text = ''.join([ex['document'], ex['summary']]) 16 | no_unk = 'UNK' not in text 17 | no_hashtag = '#' not in text 18 | return no_unk and no_hashtag 19 | 20 | dataset = dataset.filter(_filter_func) 21 | dataset = dataset.rename_column('document', 'text') 22 | 23 | return dataset 24 | 25 | @property 26 | def templates(self) -> List[Tuple[str, str]]: 27 | return [ 28 | ("Write a short summary for this text: {text}", "{summary}"), 29 | ("Briefly summarize this sentence: {text}", "{summary}"), 30 | ("Generate a short summary this sentence: {text}", "{summary}"), 31 | ("What is a shorter version of this: {text}", "{summary}"), 32 | ("{text} Write a brief summary in a sentence or less", "{summary}"), 33 | ("{text} What is a very short summary of the above text?", "{summary}"), 34 | ("{text} Summarize the aforementioned text in a single phrase.", "{summary}"), 35 | ("{text} Can you generate a short summary of the above paragraph?", "{summary}"), 36 | ] 37 | 38 | @property 39 | def possible_answers(self) -> Optional[List[str]]: 40 | return None 41 | 42 | @property 43 | def metric_name(self) -> str: 44 | return 'rouge' 45 | 46 | @property 47 | def task_name(self) -> str: 48 | return 'gigaword' 49 | 50 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 51 | return example['summary'] 52 | -------------------------------------------------------------------------------- /src/tasks/hellaswag.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from typing import Optional, List, Tuple, Dict, Union 4 | from datasets import load_dataset, Dataset 5 | 6 | from tasks import task_map, to_letter 7 | from tasks.base_task import BaseTask 8 | 9 | 10 | @task_map.add("hellaswag") 11 | class Hellaswag(BaseTask): 12 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 13 | split = split if split == 'train' else 'validation' 14 | dataset = load_dataset('hellaswag', 
split=split) 15 | 16 | def _map_func(ex: Dict) -> Dict: 17 | ex['ctx'] = re.sub(r'\[.*?\]\s', '', ex['ctx']).strip() 18 | ex['options'] = [re.sub(r'\[.*?\]\s', '', option) for option in ex['endings']] 19 | return ex 20 | 21 | dataset = dataset.map(_map_func) 22 | dataset = dataset.rename_column('ctx', 'context') 23 | 24 | return dataset 25 | 26 | @property 27 | def templates(self) -> List[Tuple[str, str]]: 28 | return [ 29 | ("What happens next in this paragraph? {context}", "{answer}"), 30 | ("Continue writing the next sentence in this paragraph: {context}", "{answer}"), 31 | ("Continue writing the next sentence. {context}", "{answer}"), 32 | ("This is a test of commonsense. Complete the next sentence: {context}", "{answer}"), 33 | ("Write the next sentence in this paragraph: {context}", "{answer}"), 34 | ("How does the next paragraph end? {context}", "{answer}"), 35 | ("What most naturally follows? {context}", "{answer}"), 36 | ("What happens next? {context}", "{answer}"), 37 | ("What is the most logical next event? {context}", "{answer}"), 38 | ("Write the next sentence in the following story. {context}", "{answer}"), 39 | ] 40 | 41 | @property 42 | def possible_answers(self) -> Optional[List[str]]: 43 | return ['A', 'B', 'C', 'D'] 44 | 45 | @property 46 | def metric_name(self) -> str: 47 | return 'simple_accuracy' 48 | 49 | @property 50 | def task_name(self) -> str: 51 | return 'hellaswag' 52 | 53 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 54 | return to_letter(example['label']) 55 | -------------------------------------------------------------------------------- /src/tasks/mnli.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("mnli") 9 | class Mnli(BaseTask): 10 | 11 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 12 | if split != 'train': 13 | return None 14 | dataset: Dataset = load_dataset('glue', 'mnli', split=split) 15 | return dataset 16 | 17 | @property 18 | def templates(self) -> List[Tuple[str, str]]: 19 | return [ 20 | ( 21 | "Premise: \"{premise}\" Hypothesis: \"{hypothesis}\" Does the premise entail the hypothesis? Yes, No, or Maybe?", 22 | "{answer}"), 23 | ( 24 | "Premise: \"{premise}\" Hypothesis: \"{hypothesis}\" Is the hypothesis entailed by the premise? Yes, No, or Maybe?", 25 | "{answer}"), 26 | ( 27 | "Here is a premise: \"{premise}\" Here is a hypothesis: \"{hypothesis}\" Is it possible to conclude that if the premise is true, then so is the hypothesis? Yes, No, or Maybe?", 28 | "{answer}"), 29 | ( 30 | "Sentence 1: \"{premise}\" Sentence 2: \"{hypothesis}\" Is this second sentence entailed by the first sentence? Yes, No, or Maybe?", 31 | "{answer}"), 32 | ( 33 | "Sentence 1: \"{premise}\" Sentence 2: \"{hypothesis}\" If the first sentence is true, then is the second sentence true? Yes, No, or Maybe?", 34 | "{answer}"), 35 | ( 36 | "Based on the premise \"{premise}\", can we conclude the hypothesis \"{hypothesis}\" is true? Yes, No, or Maybe?", 37 | "{answer}"), 38 | ( 39 | "Premise: \"{premise}\" If this premise is true, what does that tell us about whether it entails the hypothesis \"{hypothesis}\"? Yes, No, or Maybe?", 40 | "{answer}"), 41 | ( 42 | "Premise: \"{premise}\" Based on this premise, is the hypothesis \"{hypothesis}\" true? 
Yes, No, or Maybe?", 43 | "{answer}"), 44 | ("If \"{premise}\", can we conclude that \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"), 45 | ("\"{premise}\" Does it follow that \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"), 46 | ] 47 | 48 | @property 49 | def possible_answers(self) -> Optional[List[str]]: 50 | return ['Yes', 'Maybe', 'No'] 51 | 52 | @property 53 | def metric_name(self) -> str: 54 | return 'simple_accuracy' 55 | 56 | @property 57 | def task_name(self) -> str: 58 | return 'mnli' 59 | 60 | 61 | @task_map.add("mnli_m") 62 | class Mnli_m(Mnli): 63 | 64 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 65 | if split == 'train': 66 | return None 67 | 68 | return load_dataset('glue', 'mnli_matched', split='validation') 69 | 70 | @property 71 | def task_name(self) -> str: 72 | return 'mnli_m' 73 | 74 | 75 | @task_map.add("mnli_mm") 76 | class Mnli_mm(Mnli): 77 | 78 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 79 | if split == 'train': 80 | return None 81 | 82 | return load_dataset('glue', 'mnli_mismatched', split='validation') 83 | 84 | @property 85 | def task_name(self) -> str: 86 | return 'mnli_mm' 87 | -------------------------------------------------------------------------------- /src/tasks/mrpc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("mrpc") 9 | class Mrpc(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('glue', 'mrpc', split=split) 13 | return dataset 14 | 15 | @property 16 | def templates(self) -> List[Tuple[str, str]]: 17 | return [ 18 | ("Here are two sentences: {sentence1} {sentence2} Do they have the same meaning?", "{answer}"), 19 | ( 20 | "Here are two sentences: {sentence1} {sentence2} Are the two sentences saying the same thing?", 21 | "{answer}"), 22 | ("{sentence1} {sentence2} Do the above sentences mean the same thing?", "{answer}"), 23 | ("{sentence1} {sentence2} Please tell me if the sentences above mean the same.", "{answer}"), 24 | ("{sentence1} {sentence2} Are these sentences conveying the same meaning?", "{answer}"), 25 | ("{sentence1} {sentence2} If the first sentence is true, is the second one also true?", "{answer}"), 26 | ("{sentence1} {sentence2} Are these two sentences paraphrases of each other?", "{answer}"), 27 | ("Do the following two sentences have the same meaning? {sentence1} {sentence2}", "{answer}"), 28 | ("Do these two sentences mean the same thing? {sentence1} {sentence2}", "{answer}"), 29 | ("Do these sentences have the same meaning? 
{sentence1} {sentence2}", "{answer}"), 30 | ] 31 | 32 | @property 33 | def possible_answers(self) -> Optional[List[str]]: 34 | return ['No', 'Yes'] 35 | 36 | @property 37 | def metric_name(self) -> str: 38 | return 'acc_and_f1' 39 | 40 | @property 41 | def task_name(self) -> str: 42 | return 'mrpc' 43 | -------------------------------------------------------------------------------- /src/tasks/multirc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("multirc") 9 | class Multirc(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('super_glue', 'multirc', split=split) 13 | dataset = dataset.rename_column('answer', 'response') 14 | 15 | return dataset 16 | 17 | @property 18 | def templates(self) -> List[Tuple[str, str]]: 19 | return [ 20 | ( 21 | "{paragraph} Question: \"{question}\" Response: \"{response}\" Does the response correctly answer the question?", 22 | "{answer}"), 23 | ( 24 | "{paragraph} Question: \"{question}\" Response: \"{response}\" Based on the paragraph, is the response to the question is factually correct?", 25 | "{answer}"), 26 | ("{paragraph} Question: \"{question}\" Answer: \"{response}\" Is this answer correct?", "{answer}"), 27 | ( 28 | "Paragraph: {paragraph} Question: \"{question}\" Answer: \"{response}\" Based on the paragraph, is this answer correct", 29 | "{answer}"), 30 | ( 31 | "{paragraph} Based on the paragraph, does the response \"{response}\" correctly answer the question \"{question}\"?", 32 | "{answer}"), 33 | ( 34 | "{paragraph} According to the above paragraph, the correct answer to the question \"{question}\" is \"{response}\"?", 35 | "{answer}"), 36 | ( 37 | "{paragraph} After reading the above, is \"{response}\" the correct answer to the question \"{question}\"?", 38 | "{answer}"), 39 | ("{paragraph} Question: \"{question}\" Answer: \"{response}\" Is this answer to the question correct?", 40 | "{answer}"), 41 | ] 42 | 43 | @property 44 | def possible_answers(self) -> Optional[List[str]]: 45 | return ['No', 'Yes'] 46 | 47 | @property 48 | def metric_name(self) -> str: 49 | return 'f1' 50 | 51 | @property 52 | def task_name(self) -> str: 53 | return 'multirc' 54 | -------------------------------------------------------------------------------- /src/tasks/nq.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict, Union 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("natural_questions") 9 | class Natural_questions(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('nq_open', split=split) 13 | dataset = dataset.map(lambda ex: {'question': ex['question'] + '?'}) 14 | return dataset 15 | 16 | @property 17 | def templates(self) -> List[Tuple[str, str]]: 18 | return [ 19 | ("Question: {question} Answer:", "{answer}"), 20 | ("{question}", "{answer}"), 21 | ("Answer the following question: {question}", "{answer}"), 22 | ("Answer this question: {question}", "{answer}"), 23 | ("Please answer this question: {question}", "{answer}"), 24 | ("Answer 
the question...{question}", "{answer}"), 25 | ("What is the answer to this question? {question}", "{answer}"), 26 | ("Can you tell me the answer to {question}", "{answer}"), 27 | ("Next question: {question}", "{answer}"), 28 | ("Q: {question} A:", "{answer}"), 29 | ] 30 | 31 | @property 32 | def possible_answers(self) -> Optional[List[str]]: 33 | return None 34 | 35 | @property 36 | def metric_name(self) -> str: 37 | return 'trivia_qa' 38 | 39 | @property 40 | def task_name(self) -> str: 41 | return 'natural_questions' 42 | 43 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 44 | return example['answer'] 45 | -------------------------------------------------------------------------------- /src/tasks/openbookqa.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict, Union 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map, to_letter 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("openbookqa") 9 | class Openbookqa(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'test' 12 | dataset = load_dataset('openbookqa', 'additional', split=split) 13 | 14 | dataset = dataset.rename_column('fact1', 'fact') 15 | dataset = dataset.rename_column('question_stem', 'question') 16 | 17 | def _map_func(ex: Dict) -> Dict: 18 | ex['options'] = ex['choices']['text'] 19 | return ex 20 | 21 | dataset = dataset.map(_map_func) 22 | 23 | return dataset 24 | 25 | @property 26 | def templates(self) -> List[Tuple[str, str]]: 27 | return [ 28 | ("{fact} {question}", "{answer}"), 29 | ("Read this fact: \"{fact}\" Now answer this question: \"{question}\"", "{answer}"), 30 | ("Given the fact \"{fact}\", what is the answer to the question or completion \"{question}\"", "{answer}"), 31 | ("Knowing that \"{fact}\", how would one answer \"{question}\"", "{answer}"), 32 | ("Use evidence from the fact that {fact} to answer this question: \"{question}\"", "{answer}"), 33 | ("Fact: {fact} Question: {question} What's the answer?", "{answer}"), 34 | ("Use this fact to answer the question: {fact} {question}", "{answer}"), 35 | ] 36 | 37 | @property 38 | def possible_answers(self) -> Optional[List[str]]: 39 | return ['A', 'B', 'C', 'D'] 40 | 41 | @property 42 | def metric_name(self) -> str: 43 | return 'simple_accuracy' 44 | 45 | @property 46 | def task_name(self) -> str: 47 | return 'openbookqa' 48 | 49 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 50 | return to_letter(example['answerKey']) 51 | -------------------------------------------------------------------------------- /src/tasks/paws.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("paws") 9 | class Paws(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'test' 12 | dataset = load_dataset('paws', 'labeled_final', split=split) 13 | return dataset 14 | 15 | @property 16 | def templates(self) -> List[Tuple[str, str]]: 17 | return [ 18 | ("{sentence1} {sentence2} Do these sentences mean the same thing?", "{answer}"), 19 | ("{sentence1} {sentence2} Are these two sentences paraphrases of each other?", "{answer}"), 20 | ("1. {sentence1} 2. 
{sentence2} Are these two sentences paraphrases of each other?", "{answer}"), 21 | ("(1) {sentence1} (2) {sentence2} Do these two sentences mean the same thing?", "{answer}"), 22 | ("Sentence 1: {sentence1} Sentence 2: {sentence2} Do these two sentences convey the same information?", 23 | "{answer}"), 24 | ("Do these two sentences from wikipedia have the same meaning? {sentence1} {sentence2}", "{answer}"), 25 | ("Same meaning? {sentence1} {sentence2}", "{answer}"), 26 | ("Are these paraphrases? {sentence1} {sentence2}", "{answer}"), 27 | ("Do these mean the same? {sentence1} {sentence2}", "{answer}"), 28 | ( 29 | "Please check if these have the same meaning. Answer \"yes\" if they do, otherwise \"no\". {sentence1} {sentence2}", 30 | "{answer}"), 31 | ] 32 | 33 | @property 34 | def possible_answers(self) -> Optional[List[str]]: 35 | return ['No', 'Yes'] 36 | 37 | @property 38 | def metric_name(self) -> str: 39 | return 'simple_accuracy' 40 | 41 | @property 42 | def task_name(self) -> str: 43 | return 'paws' 44 | -------------------------------------------------------------------------------- /src/tasks/piqa.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict, Union 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map, to_letter 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("piqa") 9 | class Piqa(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('piqa', split=split) 13 | 14 | def _map_func(ex: Dict) -> Dict: 15 | ex['options'] = [ex['sol1'], ex['sol2']] 16 | return ex 17 | 18 | dataset = dataset.map(_map_func) 19 | 20 | return dataset 21 | 22 | @property 23 | def templates(self) -> List[Tuple[str, str]]: 24 | return [ 25 | ("Here is a goal: \"{goal}\" How would you accomplish this goal?", "{answer}"), 26 | ("Here is a goal: \"{goal}\" Which way makes more sense to accomplish this goal?", "{answer}"), 27 | ("Goal: \"{goal}\" Which of the following methods is more reasonable for accomplishing this goal?", 28 | "{answer}"), 29 | ("Objective: \"{goal}\" Which of the following solutions is more sound in terms of naive physics reasoning?", 30 | "{answer}"), 31 | ("How do you do this: \"{goal}\"", "{answer}"), 32 | ("What is the best way to: \"{goal}\"", "{answer}"), 33 | ("Which of the following solutions is better for the following goal: \"{goal}\"", "{answer}"), 34 | ("How would someone go about accomplishing this goal? 
\"{goal}\"", "{answer}"), 35 | ] 36 | 37 | @property 38 | def possible_answers(self) -> Optional[List[str]]: 39 | return ['A', 'B'] 40 | 41 | @property 42 | def metric_name(self) -> str: 43 | return 'simple_accuracy' 44 | 45 | @property 46 | def task_name(self) -> str: 47 | return 'piqa' 48 | 49 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 50 | return to_letter(example['label']) 51 | -------------------------------------------------------------------------------- /src/tasks/qnli.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("qnli") 9 | class Qnli(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('glue', 'qnli', split=split) 13 | return dataset 14 | 15 | @property 16 | def templates(self) -> List[Tuple[str, str]]: 17 | return [ 18 | ("Does the sentence \"{sentence}\" answer the question \"{question}\"?", "{answer}"), 19 | ("Does the sentence \"{sentence}\" provide a valid answer to the question \"{question}\"?", "{answer}"), 20 | ("Is \"{sentence}\" a good answer to the question \"{question}\"?", "{answer}"), 21 | ("Does \"{sentence}\" correctly answer the question of \"{question}\"?", "{answer}"), 22 | ("Does \"{sentence}\" contain the correct answer to \"{question}\"?", "{answer}"), 23 | ("Q: {question} A: {sentence} Does the answer correctly answer the question?", "{answer}"), 24 | ( 25 | "Question: {question} Answer: {sentence} Is the question answered in a satisfactory fashion?", 26 | "{answer}"), 27 | ("Question: {question} Is {sentence} a good answer to this question?", "{answer}"), 28 | ("Question: {question} Is \"{sentence}\" the correct answer?", "{answer}"), 29 | ] 30 | 31 | @property 32 | def possible_answers(self) -> Optional[List[str]]: 33 | return ['Yes', 'No'] 34 | 35 | @property 36 | def metric_name(self) -> str: 37 | return 'simple_accuracy' 38 | 39 | @property 40 | def task_name(self) -> str: 41 | return 'qnli' 42 | -------------------------------------------------------------------------------- /src/tasks/qqp.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("qqp") 9 | class Qqp(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('glue', 'qqp', split=split) 13 | 14 | def _map_func(ex: Dict) -> Dict: 15 | ex['question1'] = ex['question1'].replace('""', '\'') 16 | ex['question2'] = ex['question2'].replace('""', '\'') 17 | return ex 18 | 19 | dataset = dataset.map(_map_func) 20 | return dataset 21 | 22 | @property 23 | def templates(self) -> List[Tuple[str, str]]: 24 | return [ 25 | ("\"{question1}\" \"{question2}\" Would you say that these questions are the same?", "{answer}"), 26 | ("\"{question1}\" \"{question2}\" Do those questions have the same meaning?", "{answer}"), 27 | ("\"{question1}\" \"{question2}\" Are these two questions inquiring about the same information?", "{answer}"), 28 | ("\"{question1}\" \"{question2}\" Please tell me if those questions are the same.", 
"{answer}"), 29 | ("\"{question1}\" \"{question2}\" Are these two questions paraphrases of each other?", "{answer}"), 30 | ("First question: \"{question1}\" Second question: \"{question2}\" Are these two questions asking the same thing?", 31 | "{answer}"), 32 | ( 33 | "Question 1: \"{question1}\" Question 2: \"{question2}\" Are questions 1 and 2 asking the same thing?", 34 | "{answer}"), 35 | ("Question 1: \"{question1}\" Question 2: \"{question2}\" Would the answer to these two questions be the same?", 36 | "{answer}"), 37 | ("Are the following two questions the same? \"{question1}\" \"{question2}\"", "{answer}"), 38 | ("Do these questions have the same meaning? \"{question1}\" \"{question2}\"", "{answer}"), 39 | ] 40 | 41 | @property 42 | def possible_answers(self) -> Optional[List[str]]: 43 | return ['No', 'Yes'] 44 | 45 | @property 46 | def metric_name(self) -> str: 47 | return 'acc_and_f1' 48 | 49 | @property 50 | def task_name(self) -> str: 51 | return 'qqp' 52 | -------------------------------------------------------------------------------- /src/tasks/rte.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("rte") 9 | class Rte(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('super_glue', 'rte', split=split) 13 | return dataset 14 | 15 | @property 16 | def templates(self) -> List[Tuple[str, str]]: 17 | return [ 18 | ("{premise} Based on the paragraph above can we conclude that \"{hypothesis}\"? Yes or No?", 19 | "{answer}"), 20 | ( 21 | "{premise} Based on that paragraph can we conclude that this sentence is true? {hypothesis} Yes or No?", 22 | "{answer}"), 23 | ("{premise} Can we draw the following conclusion? {hypothesis} Yes or No?", "{answer}"), 24 | ("{premise} Does this next sentence follow, given the preceding text? {hypothesis} Yes or No?", 25 | "{answer}"), 26 | ("{premise} Can we infer the following? {hypothesis} Yes or No?", "{answer}"), 27 | ( 28 | "Read the following paragraph and determine if the hypothesis is true: {premise} Hypothesis: {hypothesis} Yes or No?", 29 | "{answer}"), 30 | ("Read the text and determine if the sentence is true: {premise} Sentence: {hypothesis} Yes or No?", 31 | "{answer}"), 32 | ( 33 | "Can we draw the following hypothesis from the context? 
Context: {premise} Hypothesis: {hypothesis} Yes or No?", 34 | "{answer}"), 35 | ("Determine if the sentence is true based on the text below: {hypothesis} {premise} Yes or No?", 36 | "{answer}"), 37 | ] 38 | 39 | @property 40 | def possible_answers(self) -> Optional[List[str]]: 41 | return ['Yes', 'No'] 42 | 43 | @property 44 | def metric_name(self) -> str: 45 | return 'simple_accuracy' 46 | 47 | @property 48 | def task_name(self) -> str: 49 | return 'rte' 50 | -------------------------------------------------------------------------------- /src/tasks/sentiment140.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("sentiment140") 9 | class Sentiment140(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'test' 12 | dataset = load_dataset('sentiment140', split=split) 13 | 14 | def _map_func(ex: Dict) -> Dict: 15 | ex['label'] = 0 if int(ex['sentiment']) == 0 else 1 16 | return ex 17 | 18 | dataset = dataset.filter(lambda ex: int(ex['sentiment']) in [0, 4]) 19 | dataset = dataset.map(_map_func) 20 | 21 | return dataset 22 | 23 | @property 24 | def templates(self) -> List[Tuple[str, str]]: 25 | return [ 26 | ("{text} What is the sentiment of this tweet?", "{answer}"), 27 | ("{text} How would the sentiment of this tweet be described?", "{answer}"), 28 | ("{text} Describe the sentiment embodied by this tweet.", "{answer}"), 29 | ("Tweet: {text} Predict the sentiment of this tweet.", "{answer}"), 30 | ("What is the sentiment of the following tweet? Tweet:{text}", "{answer}"), 31 | ("How would one describe the sentiment of this tweet? {text}", "{answer}"), 32 | ] 33 | 34 | @property 35 | def possible_answers(self) -> Optional[List[str]]: 36 | return ['Negative', 'Positive'] 37 | 38 | @property 39 | def metric_name(self) -> str: 40 | # Prior work uses two classes 41 | # (https://www.aclweb.org/anthology/C14-1008.pdf, 42 | # https://arxiv.org/pdf/1404.2188.pdf) 43 | return 'simple_accuracy' 44 | 45 | @property 46 | def task_name(self) -> str: 47 | return 'sentiment140' 48 | -------------------------------------------------------------------------------- /src/tasks/snli.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | # define your task class 9 | @task_map.add("snli") # add your task to the task map 10 | class Snli(BaseTask): 11 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 12 | split = split if split == 'train' else 'test' 13 | dataset = load_dataset('snli', split=split) 14 | dataset = dataset.filter(lambda ex: int(ex["label"]) >= 0) 15 | 16 | return dataset 17 | 18 | @property 19 | def templates(self) -> List[Tuple[str, str]]: 20 | return [ 21 | ("If \"{premise}\", does this mean that \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"), 22 | ("If \"{premise}\", can we conclude \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"), 23 | ("If \"{premise}\", does it logically follow that \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"), 24 | ( 25 | "Based on the sentence \"{premise}\", is the sentence \"{hypothesis}\" a true sentence? 
Yes, No, or Maybe?", 26 | "{answer}"), 27 | ( 28 | "Premise: {premise} Hypothesis: {hypothesis} Can we conclude that the hypothesis is true if the premise is true? Yes, No, or Maybe?", 29 | "{answer}"), 30 | ( 31 | "Premise: {premise} Hypothesis: {hypothesis} Given the premise, can we conclude the hypothesis? Yes, No, or Maybe?", 32 | "{answer}"), 33 | ( 34 | "Here is a premise: \"{premise}\" Here is a hypothesis: \"{hypothesis}\". Does the premise tell us whether the hypothesis is true? Yes, No, or Maybe?", 35 | "{answer}"), 36 | ("Is it possible to conclude that \"{premise}\" if \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"), 37 | ("Is the premise \"{premise}\" true if \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"), 38 | ] 39 | 40 | 41 | @property 42 | def possible_answers(self) -> Optional[List[str]]: 43 | return ['Yes', 'Maybe', 'No'] 44 | 45 | @property 46 | def metric_name(self) -> str: 47 | return 'simple_accuracy' 48 | 49 | @property 50 | def task_name(self) -> str: 51 | return 'snli' 52 | -------------------------------------------------------------------------------- /src/tasks/squad_v1.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from typing import Optional, List, Tuple, Dict, Union 4 | from datasets import load_dataset, Dataset 5 | 6 | from tasks import task_map 7 | from tasks.base_task import BaseTask 8 | 9 | 10 | @task_map.add("squad_v1") 11 | class Squad_v1(BaseTask): 12 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 13 | split = split if split == 'train' else 'validation' 14 | dataset = load_dataset('squad', split=split) 15 | 16 | def _map_func(ex: Dict) -> Dict: 17 | ex['title'] = re.sub(r'_', ' ', ex['title']) 18 | return ex 19 | 20 | dataset = dataset.map(_map_func) 21 | 22 | return dataset 23 | 24 | @property 25 | def templates(self) -> List[Tuple[str, str]]: 26 | return [ 27 | ("Please answer a question about the following article about {title}: {context} {question}", "{answer}"), 28 | ("Read this and answer the question {context} {question}", "{answer}"), 29 | ("{context} {question}", "{answer}"), 30 | ("Answer a question about this article: {context} {question}", "{answer}"), 31 | ("Here is a question about this article: {context} What is the answer to this question: {question}", 32 | "{answer}"), 33 | ("Article: {context} Question: {question}", "{answer}"), 34 | ("Article: {context} Now answer this question: {question}", "{answer}"), 35 | ("{title} {context} Q: {question}", "{answer}"), 36 | ] 37 | 38 | @property 39 | def possible_answers(self) -> Optional[List[str]]: 40 | return None 41 | 42 | @property 43 | def metric_name(self) -> str: 44 | return 'squad' 45 | 46 | @property 47 | def task_name(self) -> str: 48 | return 'squad_v1' 49 | 50 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 51 | return example['answers']['text'] 52 | -------------------------------------------------------------------------------- /src/tasks/sst2.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("sst2") 9 | class Sst2(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('sst2', split=split) 13 | return dataset 14 | 15 | @property 16 | def templates(self) -> 
List[Tuple[str, str]]: 17 | return [ 18 | ("Review: \"{sentence}\" Is this movie review sentence negative or positive?", "{answer}"), 19 | ("Short movie review: \"{sentence}\" Did the critic think positively or negatively of the movie?", 20 | "{answer}"), 21 | ( 22 | "Sentence from a movie review: \"{sentence}\" Was the movie seen positively or negatively based on the preceding review?", 23 | "{answer}"), 24 | ("\"{sentence}\" How would the sentiment of this sentence be perceived?", "{answer}"), 25 | ("Is the sentiment of the following sentence positive or negative? \"{sentence}\"", "{answer}"), 26 | ("What is the sentiment of the following movie review sentence? \"{sentence}\"", "{answer}"), 27 | ("Would the following phrase be considered positive or negative? \"{sentence}\"", "{answer}"), 28 | ("Does the following review have a positive or negative opinion of the movie? \"{sentence}\"", "{answer}"), 29 | ] 30 | 31 | @property 32 | def possible_answers(self) -> Optional[List[str]]: 33 | return ['Negative', 'Positive'] 34 | 35 | @property 36 | def metric_name(self) -> str: 37 | return 'simple_accuracy' 38 | 39 | @property 40 | def task_name(self) -> str: 41 | return 'sst2' 42 | -------------------------------------------------------------------------------- /src/tasks/winogrande.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict, Union 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map, to_letter 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("winogrande") 9 | class Winogrande(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('winogrande', 'winogrande_xl', split=split) 13 | 14 | def _map_func(ex: Dict) -> Dict: 15 | cut_index = ex["sentence"].index('_') 16 | context = ex["sentence"][:cut_index] 17 | ex['context'] = context.strip() 18 | 19 | text_second = ex["sentence"][cut_index + 1:] 20 | ex['options'] = [ex["option1"] + text_second, ex["option2"] + text_second] 21 | 22 | return ex 23 | 24 | dataset = dataset.map(_map_func) 25 | 26 | return dataset 27 | 28 | @property 29 | def templates(self) -> List[Tuple[str, str]]: 30 | return [ 31 | ("How does the sentence end? {context}", "{answer}"), 32 | ("Write the next sentence. {context}", "{answer}"), 33 | ("Continue the following story. {context}", "{answer}"), 34 | ("Complete the following sentence. {context}", "{answer}"), 35 | ("Continue writing the following text. {context}", "{answer}"), 36 | ("How does the sentence end? {context}", "{answer}"), 37 | ("Write the next sentence. {context}", "{answer}"), 38 | ("Continue the following story. {context}", "{answer}"), 39 | ("Complete the following sentence. {context}", "{answer}"), 40 | ("Continue writing the following text. 
{context}", "{answer}"), 41 | ] 42 | 43 | @property 44 | def possible_answers(self) -> Optional[List[str]]: 45 | return ['A', 'B'] 46 | 47 | @property 48 | def metric_name(self) -> str: 49 | return 'simple_accuracy' 50 | 51 | @property 52 | def task_name(self) -> str: 53 | return 'winogrande' 54 | 55 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 56 | label = int(example['answer']) - 1 57 | return to_letter(label) 58 | -------------------------------------------------------------------------------- /src/tasks/wsc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("wsc") 9 | class Wsc(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | split = split if split == 'train' else 'validation' 12 | dataset = load_dataset('super_glue', 'wsc', split=split) 13 | 14 | dataset = dataset.rename_column('text', 'context') 15 | dataset = dataset.rename_column('span1_text', 'text1') 16 | dataset = dataset.rename_column('span2_text', 'text2') 17 | 18 | return dataset 19 | 20 | @property 21 | def templates(self) -> List[Tuple[str, str]]: 22 | return [ 23 | ("{context} Are \"{text1}\" and \"{text2}\" the same entity?", "{answer}"), 24 | ("{context} Do \"{text1}\" and \"{text2}\" have the same meaning?", "{answer}"), 25 | ("Given the following context {context} Are \"{text1}\" and \"{text2}\" the same?", "{answer}"), 26 | ("{context} Do \"{text2}\" and \"{text1}\" mean the same thing?", "{answer}"), 27 | ("{context} Are \"{text2}\" and \"{text1}\" the same thing in the aforementioned sentence?", "{answer}"), 28 | ("Context:{context} Is \"{text2}\" the same as \"{text1}\"?", "{answer}"), 29 | ("Consider this sentence: {context} Are \"{text2}\" and \"{text1}\" the same?", "{answer}"), 30 | ("Are \"{text1}\" and \"{text2}\" the same in this sentence? {context}", "{answer}"), 31 | ("Is \"{text1}\" the same as \"{text2}\" in this sentence? {context}", "{answer}"), 32 | ("Do \"{text1}\" and \"{text2}\" point to the same thing in the following sentence? 
{context}", "{answer}"), 33 | ] 34 | 35 | @property 36 | def possible_answers(self) -> Optional[List[str]]: 37 | return ['No', 'Yes'] 38 | 39 | @property 40 | def metric_name(self) -> str: 41 | return 'simple_accuracy' 42 | 43 | @property 44 | def task_name(self) -> str: 45 | return 'wsc' 46 | -------------------------------------------------------------------------------- /src/tasks/wsc273.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Dict, Union 2 | from datasets import load_dataset, Dataset 3 | 4 | from tasks import task_map, to_letter 5 | from tasks.base_task import BaseTask 6 | 7 | 8 | @task_map.add("wsc273") 9 | class Wsc273(BaseTask): 10 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 11 | if split == 'train': 12 | return None 13 | 14 | dataset = load_dataset('winograd_wsc', 'wsc273', split='test') 15 | 16 | def _map_func(ex: Dict) -> Dict: 17 | text_first = ex["text"][:ex["pronoun_loc"]] 18 | ex['context'] = text_first 19 | 20 | text_second = ex["text"][ex["pronoun_loc"] + len(ex["pronoun"]):] 21 | ex['options'] = [ex["options"][0] + text_second, ex["options"][1] + text_second] 22 | 23 | return ex 24 | 25 | dataset = dataset.map(_map_func) 26 | 27 | return dataset 28 | 29 | @property 30 | def templates(self) -> List[Tuple[str, str]]: 31 | return [ 32 | ("{context}", "{answer}"), 33 | ("Complete the passage. {context}", "{answer}"), 34 | ("How does this following sentence end? {context}", "{answer}"), 35 | ("What is the most logical completion for the following text? {context}", "{answer}"), 36 | ("How does this text end? {context}", "{answer}"), 37 | ("What happens next? {context}", "{answer}"), 38 | ("Complete the following sentence. {context}", "{answer}"), 39 | ("Fill in the remainder of the sentence. {context}", "{answer}"), 40 | ("What is the next event? {context}", "{answer}"), 41 | ("Complete the rest of the sentence. 
{context}", "{answer}"), 42 | ] 43 | 44 | @property 45 | def possible_answers(self) -> Optional[List[str]]: 46 | return ['A', 'B'] 47 | 48 | @property 49 | def metric_name(self) -> str: 50 | return 'simple_accuracy' 51 | 52 | @property 53 | def task_name(self) -> str: 54 | return 'wsc273' 55 | 56 | def get_answer(self, example: Dict) -> Union[str, List[str]]: 57 | return to_letter(example['label']) 58 | -------------------------------------------------------------------------------- /src/tasks/yelp.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from typing import Optional, List, Tuple, Dict 4 | from datasets import load_dataset, Dataset 5 | 6 | from tasks import task_map 7 | from tasks.base_task import BaseTask 8 | 9 | 10 | @task_map.add("yelp") 11 | class Yelp(BaseTask): 12 | def _load_raw_data(self, split: str) -> Optional[Dataset]: 13 | split = split if split == 'train' else 'test' 14 | dataset = load_dataset('yelp_polarity', split=split) 15 | 16 | def _map_func(ex: Dict) -> Dict: 17 | ex['text'] = re.sub(r'\\\"', '', ex['text']) 18 | ex['text'] = re.sub(r'\\n\\n', ' ', ex['text']) 19 | return ex 20 | 21 | dataset = dataset.map(_map_func) 22 | dataset = dataset.filter(lambda ex: 0 < len(ex['text'].split()) <= 256) 23 | 24 | return dataset 25 | 26 | @property 27 | def templates(self) -> List[Tuple[str, str]]: 28 | return [ 29 | ("{text} Is this review positive or negative?", "{answer}"), 30 | ("{text} What is the sentiment of this review?", "{answer}"), 31 | ("{text} Was this review given positively or negatively?", "{answer}"), 32 | ("{text} How would this review be described in terms of sentiment?", "{answer}"), 33 | ("Is the following review positive or negative? {text}", "{answer}"), 34 | ("What is the sentiment of the following review? {text}", "{answer}"), 35 | ("How might one describe the sentiment of this review? 
{text}", "{answer}"), 36 | ] 37 | 38 | @property 39 | def possible_answers(self) -> Optional[List[str]]: 40 | return ['Negative', 'Positive'] 41 | 42 | @property 43 | def metric_name(self) -> str: 44 | return 'simple_accuracy' 45 | 46 | @property 47 | def task_name(self) -> str: 48 | return 'yelp' 49 | -------------------------------------------------------------------------------- /src/train_biencoder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from transformers.utils.logging import enable_explicit_format, set_verbosity_info, set_verbosity_warning 4 | from transformers.trainer_callback import PrinterCallback 5 | from transformers import ( 6 | AutoTokenizer, 7 | HfArgumentParser, 8 | Trainer, 9 | set_seed, 10 | PreTrainedTokenizerFast 11 | ) 12 | 13 | from logger_config import logger, LoggerCallback 14 | from config import Arguments 15 | from trainers import BiencoderTrainer 16 | from loaders import RetrievalDataLoader 17 | from collators import BiencoderCollator 18 | from models import BiencoderModel 19 | 20 | 21 | def _common_setup(args: Arguments): 22 | set_verbosity_info() 23 | if args.process_index > 0: 24 | logger.setLevel(logging.WARNING) 25 | set_verbosity_warning() 26 | enable_explicit_format() 27 | set_seed(args.seed) 28 | 29 | 30 | def main(): 31 | parser = HfArgumentParser((Arguments,)) 32 | args: Arguments = parser.parse_args_into_dataclasses()[0] 33 | _common_setup(args) 34 | logger.info('Args={}'.format(str(args))) 35 | 36 | tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path) 37 | model: BiencoderModel = BiencoderModel.build(args=args) 38 | logger.info(model) 39 | logger.info('Vocab size: {}'.format(len(tokenizer))) 40 | 41 | data_collator = BiencoderCollator( 42 | args=args, 43 | tokenizer=tokenizer, 44 | pad_to_multiple_of=8 if args.fp16 else None) 45 | 46 | retrieval_data_loader = RetrievalDataLoader(args=args, tokenizer=tokenizer) 47 | train_dataset = retrieval_data_loader.train_dataset 48 | 49 | trainer: Trainer = BiencoderTrainer( 50 | model=model, 51 | args=args, 52 | train_dataset=train_dataset if args.do_train else None, 53 | data_collator=data_collator, 54 | tokenizer=tokenizer, 55 | ) 56 | trainer.remove_callback(PrinterCallback) 57 | trainer.add_callback(LoggerCallback) 58 | retrieval_data_loader.set_trainer(trainer) 59 | model.trainer = trainer 60 | 61 | if args.do_train: 62 | train_result = trainer.train() 63 | trainer.save_model() 64 | 65 | metrics = train_result.metrics 66 | metrics["train_samples"] = len(train_dataset) 67 | 68 | trainer.log_metrics("train", metrics) 69 | trainer.save_metrics("train", metrics) 70 | 71 | return 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /src/train_cross_encoder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from transformers.utils.logging import enable_explicit_format, set_verbosity_info, set_verbosity_warning 4 | from transformers.trainer_callback import PrinterCallback 5 | from transformers import ( 6 | AutoTokenizer, 7 | HfArgumentParser, 8 | Trainer, 9 | set_seed, 10 | PreTrainedTokenizerFast 11 | ) 12 | 13 | from logger_config import logger, LoggerCallback 14 | from config import Arguments 15 | from trainers.reward_trainer import RewardTrainer 16 | from loaders import CrossEncoderDataLoader 17 | from collators import CrossEncoderCollator 18 | from models 
import Reranker 19 | 20 | 21 | def _common_setup(args: Arguments): 22 | set_verbosity_info() 23 | if args.process_index > 0: 24 | logger.setLevel(logging.WARNING) 25 | set_verbosity_warning() 26 | enable_explicit_format() 27 | set_seed(args.seed) 28 | 29 | 30 | def main(): 31 | parser = HfArgumentParser((Arguments,)) 32 | args: Arguments = parser.parse_args_into_dataclasses()[0] 33 | _common_setup(args) 34 | logger.info('Args={}'.format(str(args))) 35 | 36 | tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path) 37 | 38 | model: Reranker = Reranker.from_pretrained( 39 | all_args=args, 40 | pretrained_model_name_or_path=args.model_name_or_path, 41 | num_labels=1) 42 | 43 | logger.info(model) 44 | logger.info('Vocab size: {}'.format(len(tokenizer))) 45 | 46 | data_collator = CrossEncoderCollator( 47 | tokenizer=tokenizer, 48 | pad_to_multiple_of=8 if args.fp16 else None) 49 | 50 | reward_data_loader = CrossEncoderDataLoader(args=args, tokenizer=tokenizer) 51 | train_dataset = reward_data_loader.train_dataset 52 | 53 | trainer: Trainer = RewardTrainer( 54 | model=model, 55 | args=args, 56 | train_dataset=train_dataset if args.do_train else None, 57 | data_collator=data_collator, 58 | tokenizer=tokenizer, 59 | ) 60 | trainer.remove_callback(PrinterCallback) 61 | trainer.add_callback(LoggerCallback) 62 | reward_data_loader.trainer = trainer 63 | 64 | if args.do_train: 65 | train_result = trainer.train() 66 | trainer.save_model() 67 | 68 | metrics = train_result.metrics 69 | metrics["train_samples"] = len(train_dataset) 70 | 71 | trainer.log_metrics("train", metrics) 72 | trainer.save_metrics("train", metrics) 73 | 74 | return 75 | 76 | 77 | if __name__ == "__main__": 78 | main() 79 | -------------------------------------------------------------------------------- /src/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .biencoder_trainer import BiencoderTrainer 2 | from .reward_trainer import RewardTrainer 3 | -------------------------------------------------------------------------------- /src/trainers/biencoder_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from typing import Optional 4 | from transformers.trainer import Trainer 5 | 6 | from logger_config import logger 7 | from evaluation.metrics import accuracy, batch_mrr 8 | from models import BiencoderOutput, BiencoderModel 9 | from utils import AverageMeter 10 | 11 | 12 | class BiencoderTrainer(Trainer): 13 | def __init__(self, *pargs, **kwargs): 14 | super(BiencoderTrainer, self).__init__(*pargs, **kwargs) 15 | self.model: BiencoderModel 16 | 17 | self.acc1_meter = AverageMeter('Acc@1', round_digits=2) 18 | self.acc3_meter = AverageMeter('Acc@3', round_digits=2) 19 | self.mrr_meter = AverageMeter('mrr', round_digits=2) 20 | self.last_epoch = 0 21 | 22 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 23 | output_dir = output_dir if output_dir is not None else self.args.output_dir 24 | os.makedirs(output_dir, exist_ok=True) 25 | logger.info("Saving model checkpoint to {}".format(output_dir)) 26 | self.model.save(output_dir) 27 | if self.tokenizer is not None: 28 | self.tokenizer.save_pretrained(output_dir) 29 | 30 | def compute_loss(self, model, inputs, return_outputs=False): 31 | outputs: BiencoderOutput = model(inputs) 32 | loss = outputs.loss 33 | 34 | if self.model.training: 35 | step_acc1, step_acc3 = accuracy(output=outputs.scores.detach(), 
target=outputs.labels, topk=(1, 3)) 36 | step_mrr = batch_mrr(output=outputs.scores.detach(), target=outputs.labels) 37 | 38 | self.acc1_meter.update(step_acc1) 39 | self.acc3_meter.update(step_acc3) 40 | self.mrr_meter.update(step_mrr) 41 | 42 | if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0: 43 | logger.info('step: {}, {}, {}, {}'.format(self.state.global_step, self.mrr_meter, self.acc1_meter, self.acc3_meter)) 44 | 45 | self._reset_meters_if_needed() 46 | 47 | return (loss, outputs) if return_outputs else loss 48 | 49 | def _reset_meters_if_needed(self): 50 | if int(self.state.epoch) != self.last_epoch: 51 | self.last_epoch = int(self.state.epoch) 52 | self.acc1_meter.reset() 53 | self.acc3_meter.reset() 54 | self.mrr_meter.reset() 55 | -------------------------------------------------------------------------------- /src/trainers/reward_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from typing import Optional 4 | from transformers.trainer import Trainer, IntervalStrategy 5 | from transformers.modeling_outputs import SequenceClassifierOutput 6 | 7 | from logger_config import logger 8 | from evaluation.metrics import accuracy 9 | from utils import AverageMeter 10 | 11 | 12 | class RewardTrainer(Trainer): 13 | 14 | def __init__(self, *pargs, **kwargs): 15 | super(RewardTrainer, self).__init__(*pargs, **kwargs) 16 | 17 | self.acc_meter = AverageMeter('acc', round_digits=2) 18 | self.last_epoch = 0 19 | 20 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 21 | output_dir = output_dir if output_dir is not None else self.args.output_dir 22 | os.makedirs(output_dir, exist_ok=True) 23 | logger.info("Saving model checkpoint to {}".format(output_dir)) 24 | 25 | self.model.save_pretrained(output_dir) 26 | 27 | if self.tokenizer is not None and self.is_world_process_zero(): 28 | self.tokenizer.save_pretrained(output_dir) 29 | 30 | def compute_loss(self, model, inputs, return_outputs=False): 31 | outputs: SequenceClassifierOutput = model(inputs) 32 | loss = outputs.loss 33 | 34 | if self.model.training: 35 | labels = inputs['labels'] 36 | step_acc = accuracy(output=outputs.logits.detach(), target=labels)[0] 37 | self.acc_meter.update(step_acc) 38 | if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0: 39 | logger.info('step: {}, {}'.format(self.state.global_step, self.acc_meter)) 40 | 41 | self._reset_meters_if_needed() 42 | 43 | return (loss, outputs) if return_outputs else loss 44 | 45 | def _reset_meters_if_needed(self): 46 | if int(self.state.epoch) != self.last_epoch: 47 | self.last_epoch = int(self.state.epoch) 48 | self.acc_meter.reset() 49 | if self.args.save_strategy == IntervalStrategy.STEPS and self.state.global_step % self.args.save_steps == 0: 50 | self.acc_meter.reset() 51 | --------------------------------------------------------------------------------