├── .DS_Store
├── README.md
├── conformal_risk_computation
│   ├── conformal_distribution_shift_generation_risk.py
│   ├── conformal_distribution_shift_generation_risk_comparison.py
│   ├── conformal_generation_risk.py
│   ├── conformal_generation_risk_comparison.py
│   ├── conformal_generation_risk_multi_dim_config.py
│   ├── conformal_generation_risk_num_gen.py
│   ├── conformal_generation_risk_similarity_threshold.py
│   ├── conformal_generation_risk_valid_config.py
│   └── utils.py
├── ds_config.json
├── figs
│   └── README.md
├── requirements.txt
├── run
│   └── conformal_parallel.sh
├── scripts_conformal_generation_risk
│   ├── run_conformal_distribution_shift_generation_risk.sh
│   ├── run_conformal_distribution_shift_generation_risk_comparisons.sh
│   ├── run_conformal_generation_risk.sh
│   ├── run_conformal_generation_risk_comparisons.sh
│   ├── run_conformal_generation_risk_multi_dim_config.sh
│   ├── run_conformal_generation_risk_num_gen.sh
│   ├── run_conformal_generation_risk_similarity_threshold.sh
│   └── run_conformal_generation_risk_valid_config.sh
├── scripts_raw_risk_scores
│   ├── baai.sh
│   ├── bm25.sh
│   ├── llm-r.sh
│   └── openai.sh
└── src
    ├── .DS_Store
    ├── collators
    │   ├── __init__.py
    │   ├── biencoder_collator.py
    │   ├── cross_encoder_collator.py
    │   └── gpt2_collator.py
    ├── compute_controlled_risk.py
    ├── config.py
    ├── conformal_calibration_empirical_risk.py
    ├── conformal_calibration_guarantee.py
    ├── conformal_simulation_risks.py
    ├── data_utils.py
    ├── evaluation
    │   ├── __init__.py
    │   ├── baai_eval.py
    │   ├── base_eval.py
    │   ├── bm25_eval.py
    │   ├── dense_eval.py
    │   ├── metrics.py
    │   ├── openai_eval.py
    │   ├── qa_utils.py
    │   └── random_eval.py
    ├── inference
    │   ├── __init__.py
    │   ├── gen_llm_scores.py
    │   ├── gen_reward_scores.py
    │   ├── generate_few_shot_prompt.py
    │   ├── inference_utils.py
    │   └── search_topk.py
    ├── llm_calibrator.py
    ├── llm_evaluator.py
    ├── llm_simulator.py
    ├── llms
    │   ├── __init__.py
    │   ├── base_llm.py
    │   ├── gpt2.py
    │   ├── gpt_neo.py
    │   └── llama.py
    ├── loaders
    │   ├── __init__.py
    │   ├── biencoder_dataloader.py
    │   ├── cross_encoder_dataloader.py
    │   └── loader_utils.py
    ├── logger_config.py
    ├── main_eval.py
    ├── model_utils.py
    ├── models
    │   ├── __init__.py
    │   ├── biencoder_model.py
    │   ├── cross_encoder_model.py
    │   ├── simple_encoder.py
    │   └── simple_retriever.py
    ├── tasks
    │   ├── __init__.py
    │   ├── aeslc.py
    │   ├── agnews.py
    │   ├── arc.py
    │   ├── base_task.py
    │   ├── boolq.py
    │   ├── common_gen.py
    │   ├── copa.py
    │   ├── dart.py
    │   ├── e2e_nlg.py
    │   ├── gigaword.py
    │   ├── hellaswag.py
    │   ├── mnli.py
    │   ├── mrpc.py
    │   ├── multirc.py
    │   ├── nq.py
    │   ├── openbookqa.py
    │   ├── paws.py
    │   ├── piqa.py
    │   ├── qnli.py
    │   ├── qqp.py
    │   ├── rte.py
    │   ├── sentiment140.py
    │   ├── snli.py
    │   ├── squad_v1.py
    │   ├── sst2.py
    │   ├── winogrande.py
    │   ├── wsc.py
    │   ├── wsc273.py
    │   └── yelp.py
    ├── train_biencoder.py
    ├── train_cross_encoder.py
    ├── trainers
    │   ├── __init__.py
    │   ├── biencoder_trainer.py
    │   └── reward_trainer.py
    └── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kangmintong/C-RAG/5bda4382bebf6453b3b8dccb8a52f4d015e95fee/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C-RAG: Certified Generation Risks for Retrieval-Augmented Language Models [ICML 2024]
2 |
3 | We provide the implementation of [C-RAG](https://arxiv.org/abs/2402.03181) in this repository.
4 |
5 | C-RAG is the first framework to certify generation risks for RAG models. Specifically, C-RAG performs conformal risk analysis for RAG models and certifies an upper confidence bound on generation risks, which is referred to as the conformal generation risk.
6 | C-RAG also provides theoretical guarantees on conformal generation risks for general bounded risk functions under test distribution shifts.
7 | C-RAG proves that RAG achieves a lower conformal generation risk than a single LLM when the quality of the retrieval model and the transformer is non-trivial.
8 | Extensive empirical results demonstrate the soundness and tightness of the conformal generation risk guarantees across four widely used NLP datasets and four state-of-the-art retrieval models.
9 |
10 | ## Environment
11 |
12 | Install PyTorch for your environment and CUDA version following the [PyTorch installation guide](https://pytorch.org/get-started/locally/).
13 |
14 | Run ``pip install -r requirements.txt`` to install the other necessary packages.
15 |
16 | ## Pretrained models
17 | For the supervised fine-tuned biencoder-based retrieval model, we follow the implementation of [LLM-R](https://arxiv.org/abs/2307.07164) and provide the model checkpoint at [trained_retrieval_models](https://drive.google.com/file/d/1xOeCz3vt2piHuY00a5q4YCNhDkyCs0VF/view?usp=sharing).
18 |
19 | Alternatively, you can download it with the following command:
20 | ```
21 | gdown https://drive.google.com/uc?id=1xOeCz3vt2piHuY00a5q4YCNhDkyCs0VF
22 | ```
23 |
24 | Then, put the folder ``trained_retrieval_models/`` under ``C-RAG/``.
25 |
26 | ## Dataset preparation
27 | We evaluate C-RAG on four widely used NLP datasets: AESLC, CommonGen, DART, and E2E. We preprocess the data and provide it at [data](https://drive.google.com/file/d/1JJC192wdOmXYZy_hXcGVrXOtMK2LWsv7/view?usp=sharing).
28 |
29 | Alternatively, you can download it with the following command:
30 | ```
31 | gdown https://drive.google.com/uc?id=1JJC192wdOmXYZy_hXcGVrXOtMK2LWsv7
32 | ```
33 |
34 | Then, put the folder ``data/`` under ``C-RAG/``.
35 |
36 | ## Evaluate conformal generation risks in C-RAG
37 |
38 | To compute the conformal generation risk, we need to (1) evaluate the raw risk scores for calibration instances following our constrained generation protocol, and (2) compute the conformal generation risks based on empirical risk statistics.
39 |
40 | ### (1) Evaluate raw risk scores for calibration instances
41 |
42 | #### We compact the process into four scripts, one for each retrieval model
43 |
44 | Evaluate raw risk scores for the BM25 retrieval model:
45 | ```
46 | sh scripts_raw_risk_scores/bm25.sh
47 | ```
48 |
49 | Evaluate raw risk scores for the BAAI/bge retrieval model:
50 | ```
51 | sh scripts_raw_risk_scores/baai.sh
52 | ```
53 |
54 | Evaluate raw risk scores for the OpenAI/text-embedding-ada-002 retrieval model:
55 | ```
56 | sh scripts_raw_risk_scores/openai.sh
57 | ```
58 |
59 | Evaluate raw risk scores for the LLM-R fine-tuned biencoder-based retrieval model:
60 | ```
61 | sh scripts_raw_risk_scores/llm-r.sh
62 | ```
63 |
64 | #### Explanation: the scripts above compact the following two steps:
65 |
66 | 1. Prepare the prompts via ``src/inference/generate_few_shot_prompt.py``: retrieve relevant examples and store the prompts at ``outputs/{METHOD}/{METHOD}_test_k{N_RAG}.jsonl.gz``.
67 | 2. Evaluate the risks of the prompts on the calibration sets via ``src/conformal_calibration_empirical_risk.py``: evaluate the prompts and store the results in ``outputs/{METHOD}/{LLM}_{METHOD}/``.
68 |
69 |
70 | ### (2) Compute conformal generation risks
71 |
72 | The conformal generation risk computation is based on empirical risk statistics stored at ``outputs/{METHOD}/{LLM}_{METHOD}/`` in step (1).
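At its core, each script below converts these statistics into a conformal risk bound. Here is a minimal sketch of that computation, mirroring what the scripts do (the file name below is illustrative; ``empirical_risk2alpha_func`` lives in ``conformal_risk_computation/utils.py`` and is imported the same way the scripts import it):

```python
import json
import numpy as np
from utils import empirical_risk2alpha_func  # resolved as in the scripts under conformal_risk_computation/

# illustrative calibration file: BM25 retrieval, dataset aeslc, N_rag=15, num_gen=1, threshold 0
path = 'outputs/bm25/llama-7b_BM25/aeslc_15_1_0_calibration_result.json'
scores = np.array(list(json.load(open(path)).values())[0][0])  # per-sample similarity scores
r_hat = 1. - np.mean(scores)                                   # empirical risk on the calibration set
alpha_rag = empirical_risk2alpha_func(r_hat, N_cal=len(scores), delta=0.1)
print(f'conformal generation risk (confidence 0.9): {alpha_rag:.3f}')
```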
73 |
74 | #### Conformal generation risk without distribution shifts
75 | 1. Compute the conformal generation risks of a single retrieval model and compare them with the simulation results:
76 | ```
77 | sh scripts_conformal_generation_risk/run_conformal_generation_risk.sh
78 | ```
79 | 2. Compare conformal generation risks of different retrieval models (after running step 1 for corresponding methods):
80 | ```
81 | sh scripts_conformal_generation_risk/run_conformal_generation_risk_comparisons.sh
82 | ```
83 |
84 | #### Conformal generation risk with distribution shifts
85 | 1. Compute the conformal generation risks of a single retrieval model and compare them with the simulation results:
86 | ```
87 | sh scripts_conformal_generation_risk/run_conformal_distribution_shift_generation_risk.sh
88 | ```
89 | 2. Compare conformal generation risks of different retrieval models (after running step 1 for corresponding methods):
90 | ```
91 | sh scripts_conformal_generation_risk/run_conformal_distribution_shift_generation_risk_comparisons.sh
92 | ```
93 |
94 | #### Conformal generation risk with multi-dimensional RAG configurations
95 | ```
96 | sh scripts_conformal_generation_risk/run_conformal_generation_risk_multi_dim_config.sh
97 | ```
98 |
99 | #### Valid configurations given desired risk levels
100 | ```
101 | sh scripts_conformal_generation_risk/run_conformal_generation_risk_valid_config.sh
102 | ```
103 |
104 | #### Additional evaluations with varying RAG configurations
105 |
106 | Conformal generation risks with varying generation set sizes:
107 | ```
108 | sh scripts_conformal_generation_risk/run_conformal_generation_risk_num_gen.sh
109 | ```
110 |
111 | Conformal generation risks with varying generation similarity thresholds:
112 | ```
113 | sh scripts_conformal_generation_risk/run_conformal_generation_risk_similarity_threshold.sh
114 | ```
115 |
116 |
117 | ## Acknowledgement
118 |
119 | The inference part of this repo is built on the [LLM-R repo](https://github.com/microsoft/LMOps/tree/main/llm_retriever).
120 |
121 | For any related questions or discussion, please contact ``mintong2@illinois.edu``.
122 |
123 | If you find our paper and repo useful for your research, please consider citing:
124 | ```
125 | @article{kang2024c,
126 | title={C-RAG: Certified Generation Risks for Retrieval-Augmented Language Models},
127 | author={Kang, Mintong and G{\"u}rel, Nezihe Merve and Yu, Ning and Song, Dawn and Li, Bo},
128 | journal={arXiv preprint arXiv:2402.03181},
129 | year={2024}
130 | }
131 | ```
132 |
--------------------------------------------------------------------------------
/conformal_risk_computation/conformal_distribution_shift_generation_risk_comparison.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | import numpy as np
6 | from matplotlib.ticker import MaxNLocator
7 | from tqdm import tqdm
8 | from scipy.stats import binom
9 | from typing import List, Union, Optional, Tuple, Mapping, Dict
10 | import os
11 | from numpy import linalg
12 | import numpy as np
13 | import argparse
14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, cal_dist_shift_bound, get_color_dict, get_max_x_all
15 |
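# Descriptive note (added): for each dataset, this script sweeps a grid of Hellinger
# distances between the calibration and test distributions, converts each retrieval
# model's empirical calibration risk into a distribution-shift conformal risk bound
# via cal_dist_shift_bound (delta = 0.1), and plots the bounds of all models in one figure.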
16 | if __name__=='__main__':
17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.')
18 |
19 | parser.add_argument('--llm_retrieval_methods', '--names-list', nargs='+')
20 | parser.add_argument('--datasets', nargs='+')
21 | parser.add_argument('--n_rag', type=int, default=15)
22 | parser.add_argument('--num_gen', type=int, default=1)
23 | parser.add_argument('--gen_thresh', type=int, default=0)
24 |
25 | args = parser.parse_args()
26 |
27 | method2name = get_method2name()
28 | method2method = get_method2method()
29 |
30 | datasets = args.datasets
31 | retrieval_methods = args.llm_retrieval_methods
32 |
33 | max_x_all = get_max_x_all()
34 |
35 | n_rag = args.n_rag
36 | num_gen = args.num_gen
37 | gen_thresh = args.gen_thresh
38 |
39 | for dataset in datasets:
40 | max_x = max_x_all[dataset]
41 | hellinger_distances = np.array(list(range(0, max_x, 1)))
42 | hellinger_distances = hellinger_distances / 100.0
43 |
44 | alphas = {}
45 | min_risk, max_risk = 10e5, -10e5
46 |
47 | for method in retrieval_methods:
48 |
49 | alpha_list = []
50 | evaluate_results = json.load(open(f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{gen_thresh}_calibration_result.json','r', encoding='utf-8'))
51 |
52 |
53 | plot_dist_shift_results = {}
54 |
55 |
56 | results = list(evaluate_results.values())[0][0]
57 | results = np.array(results)
58 | r_hat = np.mean(results)
59 | r_hat = 1. - r_hat
60 |
61 | for dist in tqdm(hellinger_distances):
62 |
63 | temp_res = cal_dist_shift_bound(r_hat, dist, len(results), np.var(results), 0.1)
64 | plot_dist_shift_results[dist] = temp_res
65 | alpha_list.append(temp_res)
66 |
67 | if temp_res < min_risk:
68 | min_risk = temp_res
69 | if temp_res > max_risk:
70 | max_risk = temp_res
71 |
72 | alphas[method] = alpha_list
73 |
74 | data = {
75 | 'dist': hellinger_distances,
76 | }
77 | for method in retrieval_methods:
78 | data[method] = alphas[method]
79 |
80 | df = pd.DataFrame(data)
81 |
82 | plt.style.use('seaborn-v0_8-darkgrid')
83 | plt.figure(figsize=(10, 8))
84 |
85 | color_dict = get_color_dict(retrieval_methods)
86 |
87 | for metric, color in color_dict.items():
88 | plt.plot(df['dist'], df[metric], marker='*', color=color, label=method2name[metric], markersize=20)
89 |
90 | fontsize = 32
91 | plt.ylim([min_risk - 0.07, min(max_risk + 0.02, 1.02)])
92 | plt.xticks(fontsize=fontsize)
93 | plt.yticks(fontsize=fontsize)
94 |
95 | ax = plt.gca()
96 | ax.xaxis.set_major_locator(MaxNLocator(integer=True))
97 | ax.xaxis.set_major_formatter('{x:3<0.2f}')
98 | ax.yaxis.set_major_formatter('{x:3<0.2f}')
99 |
100 | # Show the plot
101 | plt.tight_layout()
102 | plt.savefig(f'./figs/{dataset}_distribution_shift_conformal_risk.jpg', dpi=800)
--------------------------------------------------------------------------------
/conformal_risk_computation/conformal_generation_risk.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | import numpy as np
6 | from matplotlib.ticker import MaxNLocator
7 | from tqdm import tqdm
8 | from scipy.stats import binom
9 | from typing import List, Union, Optional, Tuple, Mapping, Dict
10 | import os
11 | from numpy import linalg
12 | import numpy as np
13 | import argparse
14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func
15 |
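# Descriptive note (added): for each dataset and each N_rag in --n_rag_list, this script
# loads the calibration results produced in step (1), converts the empirical risk
# 1 - mean(similarity) into a conformal generation risk via empirical_risk2alpha_func
# (delta = 0.1), and plots the certified risks together with randomly subsampled
# empirical test risks (gray points).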
16 | if __name__=='__main__':
17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.')
18 |
19 | parser.add_argument('--llm_retrievalmodel', type=str, choices=['llama-7b_BM25', 'llama-7b_BAAI', 'llama-7b_OpenAI', 'llama-7b_triever-base'])
20 | parser.add_argument('--datasets', nargs='+')
21 | parser.add_argument('--num_gen', type=int, default=1)
22 | parser.add_argument('--gen_thresh', type=int, default=0)
23 | parser.add_argument('--n_rag_list', nargs='+')
24 |
25 | args = parser.parse_args()
26 |
27 | method2name = get_method2name()
28 | method2method = get_method2method()
29 |
30 | datasets = args.datasets
31 | method = args.llm_retrievalmodel
32 |
33 | num_point_dict, sample_size_dict = get_num_points()
34 |
35 |
36 | for dataset in datasets:
37 | num_simulation_point = num_point_dict[dataset]
38 | sample_size_per_point = sample_size_dict[dataset]
39 |
40 | num_gen = args.num_gen
41 | gen_thresh = args.gen_thresh
42 | n_rag_list = args.n_rag_list
43 |
44 | alphas = {}
45 | min_risk, max_risk = 10e5, -10e5
46 | alpha_list = []
47 |
48 | for n_rag in n_rag_list:
49 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{gen_thresh}_calibration_result.json'
50 | results = json.load(open(path, 'r', encoding='utf-8'))
51 | results = list(results.values())[0][0]
52 | results = np.array(results)
53 | alpha_list.append(empirical_risk2alpha_func(1.-np.mean(results), N_cal=len(results), delta=0.1))
54 | if alpha_list[-1] < min_risk:
55 | min_risk = alpha_list[-1]
56 | if alpha_list[-1] > max_risk:
57 | max_risk = alpha_list[-1]
58 | alphas[method] = alpha_list
59 |
60 | data = {
61 | 'N_rag': n_rag_list,
62 | method: alphas[method],
63 | }
64 |
65 | df = pd.DataFrame(data)
66 |
67 | # Set the style
68 | plt.style.use('seaborn-v0_8-darkgrid')
69 | plt.figure(figsize=(10,8))
70 |
71 | color_dict = {method: 'black'}
72 |
73 | for metric, color in color_dict.items():
74 | plt.plot(df['N_rag'], df[metric], marker='^', color=color, label=r'Certified Conformal Risk $\alpha_{\text{rag}}$', markersize=20)
75 |
76 |
77 | simulation_points = []
78 | for n_rag in n_rag_list:
79 | results = json.load(open(f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{gen_thresh}_calibration_result.json', 'r', encoding='utf-8'))
80 | results = list(results.values())[0][0]
81 | results = np.array(results)
82 |
83 | simulation_points_temp = []
84 | for idx_point in range(num_simulation_point):
85 | test_list = random.sample(list(range(len(results))),sample_size_per_point)
86 | new_res = 1. - np.mean(results[test_list])
87 | simulation_points_temp.append(new_res)
88 | simulation_points_temp = np.array(simulation_points_temp)
89 | for x in simulation_points_temp:
90 | if x < min_risk:
91 | min_risk = x
92 | plt.scatter(x=[n_rag]*len(simulation_points_temp),y=simulation_points_temp,color='gray',alpha=0.7,s=100)
93 |
94 | fontsize=45
95 |
96 | plt.ylim([min_risk-0.001, max_risk+0.02])
97 | plt.xticks(fontsize=fontsize)
98 | plt.yticks(fontsize=fontsize)
99 |
100 | ax = plt.gca()
101 | ax.xaxis.set_major_locator(MaxNLocator(integer=True))
102 | ax.yaxis.set_major_formatter('{x:3<0.2f}')
103 |
104 | plt.tight_layout()
105 | print('save figure at {}'.format({f'./figs/{dataset}_{method2name[method]}_conformal_generation_risk.jpg'}))
106 | plt.savefig(f'./figs/{dataset}_{method2name[method]}_conformal_generation_risk.jpg',dpi=800)
--------------------------------------------------------------------------------
/conformal_risk_computation/conformal_generation_risk_comparison.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | import numpy as np
6 | from matplotlib.ticker import MaxNLocator
7 | from tqdm import tqdm
8 | from scipy.stats import binom
9 | from typing import List, Union, Optional, Tuple, Mapping, Dict
10 | import os
11 | from numpy import linalg
12 | import numpy as np
13 | import argparse
14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func, get_color_dict
15 |
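# Descriptive note (added): same conformal risk computation as conformal_generation_risk.py,
# but performed for several retrieval models at once so that their certified risks can be
# compared in a single figure per dataset.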
16 | if __name__=='__main__':
17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.')
18 |
19 | parser.add_argument('--llm_retrieval_methods', nargs='+')
20 | parser.add_argument('--datasets', nargs='+')
21 | parser.add_argument('--num_gen', type=int, default=1)
22 | parser.add_argument('--gen_thresh', type=int, default=0)
23 | parser.add_argument('--n_rag_list', nargs='+')
24 |
25 | args = parser.parse_args()
26 |
27 | method2name = get_method2name()
28 | method2method = get_method2method()
29 |
30 | datasets = args.datasets
31 | retrieval_methods = args.llm_retrieval_methods
32 |
33 | num_point_dict, sample_size_dict = get_num_points()
34 |
35 | for dataset in datasets:
36 | num_gen = args.num_gen
37 | gen_thresh = args.gen_thresh
38 | n_rag_list = args.n_rag_list
39 | n_rag_list = [int(item) for item in n_rag_list]
40 |
41 | alphas = {}
42 | min_risk, max_risk = 10e5, -10e5
43 |
44 | for method in retrieval_methods:
45 | alpha_list = []
46 | for n_rag in n_rag_list:
47 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{gen_thresh}_calibration_result.json'
48 | results = json.load(open(path, 'r', encoding='utf-8'))
49 | results = list(results.values())[0][0]
50 | results = np.array(results)
51 |
52 | alpha_list.append(empirical_risk2alpha_func(1. - np.mean(results), N_cal=len(results), delta=0.1))
53 | if alpha_list[-1] < min_risk:
54 | min_risk = alpha_list[-1]
55 | if alpha_list[-1] > max_risk:
56 | max_risk = alpha_list[-1]
57 | alphas[method] = alpha_list
58 |
59 |
60 | data = {
61 | 'N_rag': n_rag_list,
62 | }
63 |
64 | for idx in range(len(retrieval_methods)):
65 | data[retrieval_methods[idx]] = alphas[retrieval_methods[idx]]
66 |
67 |
68 | df = pd.DataFrame(data)
69 |
70 | # Set the style
71 | plt.style.use('seaborn-v0_8-darkgrid')
72 | plt.figure(figsize=(10,8))
73 |
74 | color_dict = get_color_dict(retrieval_methods)
75 |
76 | for metric, color in color_dict.items():
77 | plt.plot(df['N_rag'], df[metric], marker='*', color=color, label=method2name[metric], markersize=25)
78 |
79 | fontsize=45
80 | plt.ylim([min_risk-0.02, max_risk+0.02])
81 | plt.xticks(fontsize=fontsize)
82 | plt.yticks(fontsize=fontsize)
83 |
84 | ax = plt.gca()
85 | ax.xaxis.set_major_locator(MaxNLocator(integer=True))
86 |
87 | plt.tight_layout()
88 | plt.savefig(f'./figs/{dataset}_conformal_risk.jpg',dpi=800)
--------------------------------------------------------------------------------
/conformal_risk_computation/conformal_generation_risk_multi_dim_config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | import numpy as np
6 | from matplotlib.ticker import MaxNLocator
7 | from tqdm import tqdm
8 | from scipy.stats import binom
9 | from typing import List, Union, Optional, Tuple, Mapping, Dict
10 | import os
11 | from numpy import linalg
12 | import numpy as np
13 | import argparse
14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func
15 |
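# Descriptive note (added): this script sweeps a 2-D grid of RAG configurations
# (N_rag in --n_rag_list, lambda_g in --lambda_g_list), subsamples lambda_g generations
# per example from results computed with max(lambda_g_list) generations, computes the
# conformal generation risk of each configuration (delta = 0.1), and renders the certified
# risks together with subsampled empirical risks as a 3-D scatter plot.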
16 | if __name__=='__main__':
17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.')
18 |
19 | parser.add_argument('--llm_retrievalmodel', type=str, choices=['llama-7b_BM25', 'llama-7b_BAAI', 'llama-7b_OpenAI', 'llama-7b_triever-base'])
20 | parser.add_argument('--datasets', nargs='+')
21 | parser.add_argument('--gen_thresh', type=int, default=0)
22 | parser.add_argument('--n_rag_list', nargs='+')
23 | parser.add_argument('--lambda_g_list', nargs='+')
24 |
25 | args = parser.parse_args()
26 |
27 | method2name = get_method2name()
28 | method2method = get_method2method()
29 |
30 | datasets = args.datasets
31 | method = args.llm_retrievalmodel
32 |
33 | n_rag_list = np.array(args.n_rag_list)
34 | lambda_g_list = np.array(args.lambda_g_list)
35 |
36 | num_point_dict, sample_size_dict = get_num_points()
37 |
38 | for dataset in datasets:
39 |
40 | fig = plt.figure(figsize=(10, 8))
41 | ax = plt.axes(projection='3d')
42 |
43 | xdata, ydata, zdata = [], [], []
44 | c = []
45 | random.seed(1)
46 | rand_list = []
47 |
48 | num_list = []
49 | for idx in range(len(lambda_g_list)):
50 | num_list.append(int(lambda_g_list[idx]))
51 | max_num_gen = max(num_list)
52 |
53 |
54 | for n_rag in n_rag_list:
55 | for num_gen in lambda_g_list:
56 | num_gen = int(num_gen)
57 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{max_num_gen}_{args.gen_thresh}_calibration_result.json'
58 |
59 | result = json.load(open(path, 'r', encoding='utf-8'))
60 | result = list(result.values())[0]
61 | result = np.array(result)
62 | rand_indx = random.sample(tuple(list(range(0, max_num_gen))), num_gen)
63 |
64 | result = result[rand_indx, :]
65 | rand_list.append(rand_indx)
66 | result = result.max(axis=0)
67 | results = 1. - result
68 |
69 | alpha_rag = empirical_risk2alpha_func(np.mean(results), N_cal=len(results), delta=0.1)
70 |
71 | xdata.append(n_rag)
72 | ydata.append(num_gen)
73 | zdata.append(alpha_rag)
74 | c.append('red')
75 |
76 | simulation_points_temp = []
77 | num_simulation_point = num_point_dict[dataset]
78 | sample_size_per_point = sample_size_dict[dataset]
79 | for idx_point in range(num_simulation_point):
80 | test_list = random.sample(list(range(len(results))), sample_size_per_point)
81 | new_res = np.mean(results[test_list])
82 | simulation_points_temp.append(new_res)
83 | simulation_points_temp = np.array(simulation_points_temp)
84 |
85 | ax.scatter3D([n_rag] * len(simulation_points_temp), [num_gen] * len(simulation_points_temp), simulation_points_temp, c='gray', alpha=0.3, s=25)
86 |
87 | ax.scatter3D(xdata, ydata, zdata, c=c, marker='^', s=80)
88 |
89 | labelsize = 20
90 | fontsize = 30
91 | ax.tick_params(axis='x', labelsize=labelsize,pad=-3)
92 | ax.tick_params(axis='y', labelsize=labelsize,pad=-3)
93 | ax.tick_params(axis='z', labelsize=labelsize,pad=8)
94 | ax.set_xlabel(r'$N_{rag}$', fontsize=fontsize, labelpad=10.0)
95 | ax.set_ylabel(r'$\lambda_g$', fontsize=fontsize, labelpad=10.0)
96 | ax.set_zlabel(r'Risk', fontsize=fontsize, labelpad=25.0, rotation=90)
97 |
98 | plt.tight_layout()
99 | plt.savefig(f'./figs/{dataset}_{method2method[method]}_nrag_lambdag_conformal_risk.jpg',dpi=800)
--------------------------------------------------------------------------------
/conformal_risk_computation/conformal_generation_risk_num_gen.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | import numpy as np
6 | from matplotlib.ticker import MaxNLocator
7 | from tqdm import tqdm
8 | from scipy.stats import binom
9 | from typing import List, Union, Optional, Tuple, Mapping, Dict
10 | import os
11 | from numpy import linalg
12 | import numpy as np
13 | import argparse
14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func, compute_conformal_risk
15 |
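# Descriptive note (added): calibration results are computed with max(--num_gen_list)
# generations; for each smaller number of generations a random subset is drawn, the best
# (max-similarity) generation per example defines the risk 1 - similarity, and the
# conformal risk is computed via compute_conformal_risk (delta = 0.1) and plotted against
# subsampled empirical risks.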
16 | if __name__=='__main__':
17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.')
18 |
19 | parser.add_argument('--llm_retrievalmodel', type=str, choices=['llama-7b_BM25', 'llama-7b_BAAI', 'llama-7b_OpenAI', 'llama-7b_triever-base'])
20 | parser.add_argument('--datasets', nargs='+')
21 | parser.add_argument('--n_rag', type=int, default=5)
22 | parser.add_argument('--gen_thresh', type=int, default=0)
23 | parser.add_argument('--num_gen_list', nargs='+')
24 |
25 | args = parser.parse_args()
26 |
27 | method2name = get_method2name()
28 | method2method = get_method2method()
29 |
30 | datasets = args.datasets
31 | method = args.llm_retrievalmodel
32 |
33 | num_point_dict, sample_size_dict = get_num_points()
34 |
35 | num_list = []
36 | for idx in range(len(args.num_gen_list)):
37 | num_list.append(int(args.num_gen_list[idx]))
38 | max_num_gen = max(num_list)
39 |
40 | random.seed(1)
41 |
42 | for dataset in datasets:
43 | num_simulation_point = num_point_dict[dataset]
44 | sample_size_per_point = sample_size_dict[dataset]
45 |
46 | num_gen_list = args.num_gen_list
47 | num_gen_list = [int(item) for item in num_gen_list]
48 | gen_thresh = args.gen_thresh
49 | n_rag = args.n_rag
50 |
51 | alphas = {}
52 | min_risk, max_risk = 10e5, -10e5
53 | rand_list = []
54 | alpha_list = []
55 | for num_gen in num_gen_list:
56 | num_gen = int(num_gen)
57 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{max_num_gen}_{gen_thresh}_calibration_result.json'
58 | result = json.load(open(path, 'r', encoding='utf-8'))
59 | result = list(result.values())[0]
60 | result = np.array(result)
61 | rand_indx = random.sample(tuple(list(range(0,max_num_gen))), num_gen)
62 | result = result[rand_indx,:]
63 | rand_list.append(rand_indx)
64 | result = result.max(axis=0)
65 | results = 1. - result
66 |
67 | alpha_list.append(compute_conformal_risk(results))
68 | if alpha_list[-1] < min_risk:
69 | min_risk = alpha_list[-1]
70 | if alpha_list[-1] > max_risk:
71 | max_risk = alpha_list[-1]
72 | alphas[method] = alpha_list
73 |
74 | data = {
75 | 'num_gen': num_gen_list,
76 | method: alphas[method],
77 | }
78 |
79 | df = pd.DataFrame(data)
80 |
81 | plt.style.use('seaborn-v0_8-darkgrid')
82 | plt.figure(figsize=(10, 8))
83 |
84 | color_dict = {method: 'black'}
85 |
86 | for metric, color in color_dict.items():
87 | plt.plot(df['num_gen'], df[metric], marker='^', color=color, label=r'Certified Conformal Risk $\alpha_{\text{rag}}$', markersize=20)
88 |
89 | simulation_points = []
90 |
91 | for num_gen in num_gen_list:
92 | num_gen = int(num_gen)
93 |
94 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{max_num_gen}_{gen_thresh}_calibration_result.json'
95 |
96 | result = json.load(open(path, 'r', encoding='utf-8'))
97 | result = list(result.values())[0]
98 | result = np.array(result)
99 | result = result[rand_list[num_gen-1], :]
100 | result = result.max(axis=0)
101 | results = 1. - result
102 |
103 | simulation_points_temp = []
104 | for idx_point in range(num_simulation_point):
105 | test_list = random.sample(list(range(len(results))), sample_size_per_point)
106 | res = np.mean(results[test_list])
107 | simulation_points_temp.append(res)
108 | simulation_points_temp = np.array(simulation_points_temp)
109 | for x in simulation_points_temp:
110 | if x < min_risk:
111 | min_risk = x
112 |
113 | plt.scatter(x=[num_gen] * len(simulation_points_temp), y=simulation_points_temp, color='gray', alpha=0.7, s=100)
114 |
115 | fontsize = 45
116 |
117 | plt.ylim([min_risk - 0.02, max_risk + 0.02])
118 | plt.xticks(fontsize=fontsize)
119 | plt.yticks(fontsize=fontsize)
120 |
121 | ax = plt.gca()
122 | # ax.xaxis.set_major_locator(MaxNLocator(integer=True))
123 | ax.yaxis.set_major_formatter('{x:3<0.2f}')
124 |
125 | plt.tight_layout()
126 | print('save figure at {}'.format({f'./figs/{dataset}_{method2name[method]}_multi_gen_conformal_bound_simulation.jpg'}))
127 | plt.savefig(f'./figs/{dataset}_{method2name[method]}_multi_gen_conformal_bound_simulation.jpg', dpi=800)
--------------------------------------------------------------------------------
/conformal_risk_computation/conformal_generation_risk_similarity_threshold.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | import numpy as np
6 | from matplotlib.ticker import MaxNLocator
7 | from tqdm import tqdm
8 | from scipy.stats import binom
9 | from typing import List, Union, Optional, Tuple, Mapping, Dict
10 | import os
11 | from numpy import linalg
12 | import numpy as np
13 | import argparse
14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func, compute_conformal_risk
15 |
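# Descriptive note (added): for each similarity threshold in --thresh_list (scaled by 100),
# this script loads the corresponding calibration results, takes the best (max-similarity)
# generation per example, and plots the resulting conformal generation risk
# (compute_conformal_risk, delta = 0.1) together with subsampled empirical risks.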
16 | if __name__=='__main__':
17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.')
18 |
19 | parser.add_argument('--llm_retrievalmodel', type=str, choices=['llama-7b_BM25', 'llama-7b_BAAI', 'llama-7b_OpenAI', 'llama-7b_triever-base'])
20 | parser.add_argument('--datasets', nargs='+')
21 | parser.add_argument('--num_gen', type=int, default=20)
22 | parser.add_argument('--n_rag', type=int, default=5)
23 | parser.add_argument('--thresh_list', nargs='+')
24 |
25 | args = parser.parse_args()
26 |
27 | method2name = get_method2name()
28 | method2method = get_method2method()
29 |
30 | datasets = args.datasets
31 | method = args.llm_retrievalmodel
32 |
33 | num_point_dict, sample_size_dict = get_num_points()
34 |
35 | random.seed(1)
36 |
37 | for dataset in datasets:
38 | num_simulation_point = num_point_dict[dataset]
39 | sample_size_per_point = sample_size_dict[dataset]
40 |
41 | thresh_list = args.thresh_list
42 | num_gen = args.num_gen
43 | n_rag = args.n_rag
44 |
45 | alphas = {}
46 | min_risk, max_risk = 10e5, -10e5
47 | rand_list = []
48 | alpha_list = []
49 | for idx, thresh in enumerate(thresh_list):
50 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{thresh}_calibration_result.json'
51 |
52 | result = json.load(open(path, 'r', encoding='utf-8'))
53 | result = list(result.values())[0]
54 | result = np.array(result)
55 | result = result.max(axis=0)
56 | results = 1. - result
57 |
58 | alpha_list.append(compute_conformal_risk(results))
59 | if alpha_list[-1] < min_risk:
60 | min_risk = alpha_list[-1]
61 | if alpha_list[-1] > max_risk:
62 | max_risk = alpha_list[-1]
63 | alphas[method] = alpha_list
64 |
65 | data = {
66 | 'thresh': [ int(item) / 100. for item in thresh_list],
67 | method: alphas[method],
68 | }
69 |
70 | df = pd.DataFrame(data)
71 | plt.style.use('seaborn-v0_8-darkgrid')
72 | plt.figure(figsize=(10, 8))
73 |
74 | color_dict = {method: 'black'}
75 |
76 | for metric, color in color_dict.items():
77 | plt.plot(df['thresh'], df[metric], marker='^', color=color, label=r'Certified Conformal Risk $\alpha_{\text{rag}}$', markersize=20)
78 |
79 | simulation_points = []
80 | for idx_,thresh in enumerate(thresh_list):
81 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{num_gen}_{thresh}_calibration_result.json'
82 | result = json.load(open(path, 'r', encoding='utf-8'))
83 | result = list(result.values())[0]
84 | result = np.array(result)
85 | result = result.max(axis=0)
86 | results = 1. - result
87 |
88 | simulation_points_temp = []
89 | for idx_point in range(num_simulation_point):
90 | test_list = random.sample(list(range(len(results))), sample_size_per_point)
91 | res = np.mean(results[test_list])
92 | simulation_points_temp.append(res)
93 | simulation_points_temp = np.array(simulation_points_temp)
94 | for x in simulation_points_temp:
95 | if x < min_risk:
96 | min_risk = x
97 | plt.scatter(x=[int(thresh) / 100.0] * len(simulation_points_temp), y=simulation_points_temp, color='gray', alpha=0.7, s=100)
98 |
99 | fontsize = 32
100 | plt.ylim([min_risk - 0.02, max_risk + 0.02])
101 | plt.xticks(fontsize=fontsize)
102 | plt.yticks(fontsize=fontsize)
103 |
104 | ax = plt.gca()
105 | # ax.xaxis.set_major_locator(MaxNLocator(integer=True))
106 | ax.yaxis.set_major_formatter('{x:3<0.2f}')
107 |
108 | plt.gca().invert_xaxis()
109 | plt.tight_layout()
110 | print('save figure at {}'.format({f'./figs/{dataset}_{method2name[method]}_multi_thresh_conformal_bound_simulation.jpg'}))
111 | plt.savefig(f'./figs/{dataset}_{method2name[method]}_multi_thresh_conformal_bound_simulation.jpg', dpi=800)
--------------------------------------------------------------------------------
/conformal_risk_computation/conformal_generation_risk_valid_config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | import numpy as np
6 | from matplotlib.ticker import MaxNLocator
7 | from tqdm import tqdm
8 | from scipy.stats import binom
9 | from typing import List, Union, Optional, Tuple, Mapping, Dict
10 | import os
11 | from numpy import linalg
12 | import numpy as np
13 | import argparse
14 | from utils import save_json_to_file, get_method2name, get_method2method, get_num_points, empirical_risk2alpha_func, compute_p_value, FWER
15 |
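# Descriptive note (added): for every (N_rag, lambda_g) configuration, this script computes
# a p-value for the hypothesis that the generation risk exceeds --alpha_desired
# (compute_p_value), applies Bonferroni-style family-wise error control (FWER, delta = 0.1),
# and marks the certified valid configurations in a 3-D plot alongside subsampled empirical
# risks and the desired-risk surface.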
16 | if __name__=='__main__':
17 | parser = argparse.ArgumentParser(description='Compute and plot the conformal generation risks of one retrieval model.')
18 |
19 | parser.add_argument('--llm_retrievalmodel', type=str, choices=['llama-7b_BM25', 'llama-7b_BAAI', 'llama-7b_OpenAI', 'llama-7b_triever-base'])
20 | parser.add_argument('--datasets', nargs='+')
21 | parser.add_argument('--gen_thresh', type=int, default=0)
22 | parser.add_argument('--n_rag_list', nargs='+')
23 | parser.add_argument('--lambda_g_list', nargs='+')
24 | parser.add_argument('--alpha_desired', type=float, default=0.6)
25 |
26 | args = parser.parse_args()
27 |
28 | method2name = get_method2name()
29 | method2method = get_method2method()
30 |
31 | datasets = args.datasets
32 | method = args.llm_retrievalmodel
33 |
34 | n_rag_list = np.array(args.n_rag_list)
35 | lambda_g_list = np.array(args.lambda_g_list)
36 |
37 | alpha_desired = args.alpha_desired
38 | delta = 0.1
39 |
40 | num_point_dict, sample_size_dict = get_num_points()
41 |
42 | num_list = []
43 | for idx in range(len(lambda_g_list)):
44 | num_list.append(int(lambda_g_list[idx]))
45 | max_num_gen = max(num_list)
46 |
47 | for dataset in datasets:
48 | fig = plt.figure(figsize=(10, 8))
49 | ax = plt.axes(projection='3d')
50 |
51 | xdata = []
52 | ydata = []
53 | zdata = []
54 | c = []
55 | random.seed(1)
56 | rand_list = []
57 | p_values = {}
58 | for n_rag in n_rag_list:
59 | for num_gen in lambda_g_list:
60 | num_gen = int(num_gen)
61 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{max_num_gen}_0_calibration_result.json'
62 |
63 | result = json.load(open(path, 'r', encoding='utf-8'))
64 | result = list(result.values())[0]
65 | result = np.array(result)
66 | rand_indx = random.sample(tuple(list(range(0, max_num_gen))), num_gen)
67 | result = result[rand_indx, :]
68 | rand_list.append(rand_indx)
69 | result = result.max(axis=0)
70 | results = 1. - result
71 |
72 | p_ = compute_p_value(alpha_desired, np.mean(results), len(results))
73 | p_values[(n_rag, num_gen)] = p_
74 |
75 | valid_or_not = FWER(p_values, delta)
76 |
77 | for n_rag in n_rag_list:
78 | for num_gen in lambda_g_list:
79 | num_gen = int(num_gen)
80 | if (not valid_or_not[(n_rag, num_gen)]):
81 | continue
82 |
83 | ax.scatter3D([n_rag], [num_gen], 0.01, c='green', alpha=0.9, s=200)
84 |
85 | path = f'outputs/{method2method[method]}/{method}/{dataset}_{n_rag}_{max_num_gen}_0_calibration_result.json'
86 |
87 | result = json.load(open(path, 'r', encoding='utf-8'))
88 | result = list(result.values())[0]
89 | result = np.array(result)
90 | rand_indx = random.sample(tuple(list(range(0, max_num_gen))), num_gen)
91 | result = result[rand_indx, :]
92 | rand_list.append(rand_indx)
93 | result = result.max(axis=0)
94 | results = 1. - result
95 |
96 | simulation_points_temp = []
97 | num_simulation_point = num_point_dict[dataset]
98 | sample_size_per_point = sample_size_dict[dataset]
99 | for idx_point in range(num_simulation_point):
100 | test_list = random.sample(list(range(len(results))), sample_size_per_point)
101 | new_res = np.mean(results[test_list])
102 | simulation_points_temp.append(new_res)
103 | simulation_points_temp = np.array(simulation_points_temp)
104 | ax.scatter3D([n_rag] * len(simulation_points_temp), [num_gen] * len(simulation_points_temp), simulation_points_temp, c='gray', alpha=0.3, s=25)
105 |
106 |
107 | n_rag_list = [int(item) for item in n_rag_list]
108 | lambda_g_list = [int(item) for item in lambda_g_list]
109 |
110 | N_rag_grid, lambda_g_grid = np.meshgrid(n_rag_list, lambda_g_list)
111 | risk_values = np.ones_like(N_rag_grid, dtype=float) * alpha_desired
112 |
113 |
114 |
115 | surf = ax.plot_surface(N_rag_grid, lambda_g_grid, risk_values, color='red', edgecolor='none', alpha=0.2)
116 |
117 | labelsize = 20
118 | fontsize = 30
119 | ax.tick_params(axis='x', labelsize=labelsize,pad=-3)
120 | ax.tick_params(axis='y', labelsize=labelsize,pad=-3)
121 | ax.tick_params(axis='z', labelsize=labelsize,pad=8)
122 | ax.set_xlabel(r'$N_{rag}$', fontsize=fontsize, labelpad=10.0)
123 | ax.set_ylabel(r'$\lambda_g$', fontsize=fontsize, labelpad=10.0)
124 | ax.set_zlabel(r'Risk', fontsize=fontsize, labelpad=25.0, rotation=90)
125 |
126 | plt.tight_layout()
127 | print(f'save fig at ./figs/{dataset}_{method2name[method]}_valid_config.jpg')
128 | plt.savefig(f'./figs/{dataset}_{method2name[method]}_valid_config.jpg',dpi=800)
--------------------------------------------------------------------------------
/conformal_risk_computation/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import pandas as pd
3 | import numpy as np
4 | from matplotlib.ticker import MaxNLocator
5 | from tqdm import tqdm
6 | from scipy.stats import binom
7 | from typing import List, Union, Optional, Tuple, Mapping, Dict
8 | import os
9 | import json
10 | from scipy.optimize import fsolve
11 |
12 | def save_json_to_file(objects: Union[List, dict], path: str, line_by_line: bool = False):
13 | if line_by_line:
14 | assert isinstance(objects, list), 'Only list can be saved in line by line format'
15 |
16 | os.makedirs(os.path.dirname(path), exist_ok=True)
17 | with open(path, 'w', encoding='utf-8') as writer:
18 | if not line_by_line:
19 | json.dump(objects, writer, ensure_ascii=False, indent=4, separators=(',', ':'))
20 | else:
21 | for obj in objects:
22 | writer.write(json.dumps(obj, ensure_ascii=False, separators=(',', ':')))
23 | writer.write('\n')
24 |
25 | def get_method2name():
26 | method2name = {}
27 | method2name['llama-7b_BM25'] = 'BM25'
28 | method2name['llama-7b_triever-base'] = 'LLM-R'
29 | method2name['llama-7b_OpenAI'] = 'openai'
30 | method2name['llama-7b_BAAI'] = 'baai'
31 | return method2name
32 |
33 |
34 | def get_method2method():
35 | method2method = {}
36 | method2method['llama-7b_triever-base'] = 'llm-retriever-base'
37 | method2method['llama-7b_BM25'] = 'bm25'
38 | method2method['llama-7b_OpenAI'] = 'openai'
39 | method2method['llama-7b_BAAI'] = 'baai'
40 | return method2method
41 |
42 | def get_num_points():
43 | num_point_dict = {'aeslc': 100, 'common_gen': 100, 'dart': 100, 'e2e_nlg': 100}
44 | sample_size_dict = {'aeslc': 400, 'common_gen': 400, 'dart': 400, 'e2e_nlg': 400}
45 | return num_point_dict, sample_size_dict
46 |
47 | def get_color_dict(retrieval_methods):
48 | colors = {'llama-7b_BM25': 'royalblue', 'llama-7b_triever-base': 'lightcoral', 'llama-7b_OpenAI': 'darkviolet', 'llama-7b_BAAI': 'olivedrab'}
49 | color_dict = {}
50 | for method in retrieval_methods:
51 | color_dict[method] = colors[method]
52 | return color_dict
53 |
54 | def get_max_x_all():
55 | max_x_all = {'aeslc': 11, 'common_gen': 21, 'dart': 21, 'e2e_nlg': 21}
56 | return max_x_all
57 |
58 | def get_num_points_distribution_shift():
59 | num_point_dict = {'aeslc': 100, 'common_gen': 100, 'dart': 100, 'e2e_nlg': 100}
60 | sample_size_dict = {'aeslc': 30, 'common_gen': 30, 'dart': 30, 'e2e_nlg': 30}
61 | simulate_num = 30000
62 | return num_point_dict, sample_size_dict, simulate_num
63 |
64 | def h1(a,b):
65 | return a * np.log(a/b) + (1-a) * np.log((1-a)/(1-b))
66 |
67 | def solve_inverse_1_func(r_hat, N_cal, delta):
68 | def func(a):
69 | ret = np.exp(-N_cal * h1(min(a,r_hat), a)) - delta
70 | return ret
71 | root = fsolve(func, [r_hat+0.02])[0]
72 | return root
73 |
74 | def solve_inverse_2_func(r_hat, N_cal, delta):
75 | def func(a):
76 | ret = binom.cdf(np.ceil(N_cal * r_hat), N_cal, a) - delta / np.exp(1)
77 | return ret
78 | root = fsolve(func, [r_hat+0.02])[0]
79 | return root
80 |
81 | def empirical_risk2alpha_func(r_hat, N_cal, delta):
82 | delta = delta
83 | alpha_term_1 = solve_inverse_1_func(r_hat, N_cal, delta)
84 | alpha_term_2 = solve_inverse_2_func(r_hat, N_cal, delta)
85 | alpha = min(alpha_term_1, alpha_term_2)
86 | if alpha>1.0:
87 | alpha=1.0
88 | return alpha
89 |
90 | def cal_dist_shift_bound(r_hat, dist, N, var, delta):
91 | r_hat_overline = r_hat + dist**2 * (2-dist**2) * (1-r_hat) + 2 * dist * (1-dist**2) * np.sqrt(2-dist**2) * np.sqrt(var)
92 | # print(f'r_hat_overline 1: {r_hat_overline}')
93 |
94 | # finite-sample error
95 | r_hat_overline += (1-dist**2) * ((1-dist**2) / np.sqrt(2*N) + 2 * dist * np.sqrt(2 * (2 - dist**2)) / np.sqrt(N-1)) * np.sqrt(np.log(4/delta)) + np.sqrt(np.log(8/delta) / 2 / N)
96 |
97 | # print(f'r_hat_overline 2: {r_hat_overline}')
98 | # risk_shift = empirical_risk2alpha_newton(r_hat_overline, N, delta)
99 | risk_shift = empirical_risk2alpha_func(r_hat_overline, N, delta)
100 | return risk_shift
101 |
102 | def compute_hellinger_distance(vec1, vec2):
103 | dist = 1.
104 | for idx in range(len(vec1)):
105 | dist -= np.sqrt(vec1[idx]) * np.sqrt(vec2[idx])
106 | dist = np.sqrt(dist)
107 | return dist
108 |
109 | def compute_p_value(alpha_desired, r_hat, N_cal):
110 | term1 = np.exp(-N_cal * h1(min(alpha_desired,r_hat), alpha_desired))
111 | term2 = binom.cdf(np.ceil(N_cal * r_hat), N_cal, alpha_desired) * np.exp(1)
112 | p = min(term1, term2)
113 | if r_hat > alpha_desired:
114 | return 1.0
115 | return p
116 |
117 | def FWER(p_values, delta):
118 | keys = p_values.keys()
119 | n_tol = len(keys)
120 | valid_or_not = {}
121 | # print(p_values)
122 | for key in p_values.keys():
123 | if p_values[key] < delta / n_tol:
124 | valid_or_not[key] = True
125 | else:
126 | valid_or_not[key] = False
127 | return valid_or_not
128 |
129 | def compute_conformal_risk(results):
130 | return empirical_risk2alpha_func(np.mean(results), len(results), delta=0.1)
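
A minimal usage sketch of the helpers above (an editor's addition, not part of the repository; the risk scores below are synthetic):

```python
import numpy as np
from utils import compute_conformal_risk, cal_dist_shift_bound, compute_p_value, FWER

rng = np.random.default_rng(0)
risks = rng.binomial(1, 0.3, size=500).astype(float)         # hypothetical 0/1 risk scores

alpha = compute_conformal_risk(risks)                         # conformal risk bound at delta = 0.1
alpha_shift = cal_dist_shift_bound(float(np.mean(risks)),     # bound under Hellinger distance 0.1
                                   0.1, len(risks), float(np.var(risks)), 0.1)
p = compute_p_value(alpha_desired=0.6, r_hat=float(np.mean(risks)), N_cal=len(risks))
print(alpha, alpha_shift, FWER({(5, 1): p}, delta=0.1))       # Bonferroni check for one config
```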
--------------------------------------------------------------------------------
/ds_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 10,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 |
11 | "optimizer": {
12 | "type": "AdamW",
13 | "params": {
14 | "lr": "auto",
15 | "betas": "auto",
16 | "eps": "auto",
17 | "weight_decay": "auto"
18 | }
19 | },
20 |
21 | "scheduler": {
22 | "type": "WarmupDecayLR",
23 | "params": {
24 | "warmup_min_lr": "auto",
25 | "warmup_max_lr": "auto",
26 | "warmup_num_steps": "auto",
27 | "total_num_steps": "auto"
28 | }
29 | },
30 |
31 | "zero_optimization": {
32 | "stage": 2,
33 | "allgather_partitions": true,
34 | "allgather_bucket_size": 2e8,
35 | "overlap_comm": true,
36 | "reduce_scatter": true,
37 | "reduce_bucket_size": 2e8,
38 | "contiguous_gradients": true
39 | },
40 |
41 | "gradient_accumulation_steps": "auto",
42 | "gradient_clipping": "auto",
43 | "steps_per_print": 2000,
44 | "train_batch_size": "auto",
45 | "train_micro_batch_size_per_gpu": "auto",
46 | "wall_clock_breakdown": false
47 | }
48 |
--------------------------------------------------------------------------------
/figs/README.md:
--------------------------------------------------------------------------------
1 | The generated images will be stored in this directory.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.28
2 | datasets==2.3.0
3 | deepspeed==0.8.3
4 | tqdm
5 | rouge
6 | faiss-cpu
7 | matplotlib
8 | pandas
9 | scipy
10 | gdown
11 | rank_bm25
12 | tiktoken
13 | openai
14 | scikit-learn
15 | absl-py
--------------------------------------------------------------------------------
/run/conformal_parallel.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 | set -e
5 |
6 | DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )"
7 | echo "working directory: ${DIR}"
8 |
9 | MODEL_NAME_OR_PATH="random"
10 | if [[ $# -ge 1 && ! "$1" == "--"* ]]; then
11 | MODEL_NAME_OR_PATH=$1
12 | shift
13 | fi
14 |
15 | LLM_MODEL_NAME_OR_PATH="huggyllama/llama-7b"
16 | if [[ $# -ge 1 && ! "$1" == "--"* ]]; then
17 | LLM_MODEL_NAME_OR_PATH=$1
18 | shift
19 | fi
20 |
21 | if [ -z "$OUTPUT_DIR" ]; then
22 | OUTPUT_DIR="${MODEL_NAME_OR_PATH}"
23 | fi
24 | if [ -z "$DATA_DIR" ]; then
25 | DATA_DIR="${DIR}/data/tasks/"
26 | fi
27 |
28 |
29 | EVAL_TASKS=('aeslc' 'common_gen' 'dart' 'e2e_nlg')
30 |
31 | # generate data: input prompt (retrieved examples) and query for the specified task
32 | PYTHONPATH=src/ python -u src/inference/generate_few_shot_prompt.py \
33 | --model_name_or_path "${MODEL_NAME_OR_PATH}" \
34 | --seed 1234 \
35 | --fp16 \
36 | --llm_eval_tasks "${EVAL_TASKS[@]}" \
37 | --llm_eval_split test \
38 | --llm_k_shot "${N_rag}" \
39 | --output_dir "${OUTPUT_DIR}" \
40 | --data_dir "${DATA_DIR}"
41 |
42 |
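# the number of processes equals the number of GPUs in the comma-separated $GPUs list: (string length + 1) / 2, e.g. "0,1,2,3" -> 4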
43 | nproc=$(echo -n "$GPUs" | wc -m)
44 | nproc=`expr $nproc + 1`
45 | nproc=`expr $nproc / 2`
46 |
47 | # evaluate the empirical risks
48 | CUDA_VISIBLE_DEVICES="${GPUs}" torchrun --nproc_per_node "${nproc}" --master_port=38765 src/conformal_calibration_empirical_risk.py \
49 | --llm_batch_size_per_device 4 \
50 | --llm_k_shot "${N_rag}" \
51 | --llm_num_gen "${llm_num_gen}" \
52 | --llm_gen_threshold "${llm_gen_threshold}" \
53 | --cali_ratio 0.5 \
54 | --model_name_or_path "${MODEL_NAME_OR_PATH}" \
55 | --seed 1234 \
56 | --fp16 \
57 | --do_llm_eval \
58 | --llm_model_name_or_path "${LLM_MODEL_NAME_OR_PATH}" \
59 | --llm_eval_tasks "${EVAL_TASKS[@]}" \
60 | --llm_eval_split test \
61 | --llm_max_input_length 1024 \
62 | --llm_max_decode_length 64 \
63 | --output_dir "${OUTPUT_DIR}" \
64 | --data_dir "${DATA_DIR}" \
65 | --overwrite_output_dir \
66 | --disable_tqdm True \
67 | --report_to none "$@"
--------------------------------------------------------------------------------
/scripts_conformal_generation_risk/run_conformal_distribution_shift_generation_risk.sh:
--------------------------------------------------------------------------------
1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base
2 | do
3 | python conformal_risk_computation/conformal_distribution_shift_generation_risk.py --llm_retrievalmodel $method --datasets aeslc common_gen dart e2e_nlg --n_rag 15 --num_gen 1 --gen_thresh 0
4 | done
--------------------------------------------------------------------------------
/scripts_conformal_generation_risk/run_conformal_distribution_shift_generation_risk_comparisons.sh:
--------------------------------------------------------------------------------
1 | python conformal_risk_computation/conformal_distribution_shift_generation_risk_comparison.py \
2 | --llm_retrieval_methods llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base \
3 | --n_rag 15 --num_gen 1 --gen_thresh 0 --datasets aeslc common_gen dart e2e_nlg
--------------------------------------------------------------------------------
/scripts_conformal_generation_risk/run_conformal_generation_risk.sh:
--------------------------------------------------------------------------------
1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base
2 | do
3 | python conformal_risk_computation/conformal_generation_risk.py --llm_retrievalmodel $method \
4 | --datasets aeslc common_gen dart e2e_nlg \
5 | --num_gen 1 --gen_thresh 0 --n_rag_list 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
6 | done
--------------------------------------------------------------------------------
/scripts_conformal_generation_risk/run_conformal_generation_risk_comparisons.sh:
--------------------------------------------------------------------------------
1 | python conformal_risk_computation/conformal_generation_risk_comparison.py \
2 | --llm_retrieval_methods llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base --datasets aeslc common_gen dart e2e_nlg \
3 | --num_gen 1 --gen_thresh 0 --n_rag_list 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
--------------------------------------------------------------------------------
/scripts_conformal_generation_risk/run_conformal_generation_risk_multi_dim_config.sh:
--------------------------------------------------------------------------------
1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base
2 | do
3 | python conformal_risk_computation/conformal_generation_risk_multi_dim_config.py \
4 | --llm_retrievalmodel $method --datasets aeslc common_gen dart e2e_nlg --n_rag_list 0 1 2 3 4 5 6 7 8 9 \
5 | --lambda_g_list 1 2 3 4 5 6 7 8 9 10 --gen_thresh 0
6 | done
--------------------------------------------------------------------------------
/scripts_conformal_generation_risk/run_conformal_generation_risk_num_gen.sh:
--------------------------------------------------------------------------------
1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base
2 | do
3 | python conformal_risk_computation/conformal_generation_risk_num_gen.py \
4 | --llm_retrievalmodel $method --datasets aeslc common_gen dart e2e_nlg \
5 | --n_rag 5 --gen_thresh 0 --num_gen_list 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
6 | done
--------------------------------------------------------------------------------
/scripts_conformal_generation_risk/run_conformal_generation_risk_similarity_threshold.sh:
--------------------------------------------------------------------------------
1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base
2 | do
3 | python conformal_risk_computation/conformal_generation_risk_similarity_threshold.py \
4 | --llm_retrievalmodel $method --datasets aeslc common_gen dart e2e_nlg --n_rag 5 --num_gen 20 --thresh_list 10 9 8 7 6
5 | done
--------------------------------------------------------------------------------
/scripts_conformal_generation_risk/run_conformal_generation_risk_valid_config.sh:
--------------------------------------------------------------------------------
1 | for method in llama-7b_BM25 llama-7b_BAAI llama-7b_OpenAI llama-7b_triever-base
2 | do
3 | python conformal_risk_computation/conformal_generation_risk_valid_config.py \
4 | --llm_retrievalmodel $method --datasets aeslc common_gen dart e2e_nlg --n_rag_list 0 1 2 3 4 5 6 7 8 9 --lambda_g_list 1 2 3 4 5 6 7 8 9 10
5 | done
--------------------------------------------------------------------------------
/scripts_raw_risk_scores/baai.sh:
--------------------------------------------------------------------------------
1 | # OUTPUT_DIR: results stored at the dir
2 | # N_rag: number of retrieved examples
3 | # llm_num_gen: number of generations
4 | # llm_gen_threshold: acceptance similarity threshold during generation (note: it is scaled by 100 and represented as an integer)
5 |
6 | for n in 0 1 2 3 4 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
7 | do
8 | for ngen in 1
9 | do
10 | GPUs=0,1,2,3 OUTPUT_DIR=outputs/baai N_rag=$n llm_num_gen=$ngen llm_gen_threshold=0 bash run/conformal_parallel.sh BAAI
11 | done
12 | done
--------------------------------------------------------------------------------
/scripts_raw_risk_scores/bm25.sh:
--------------------------------------------------------------------------------
1 | # OUTPUT_DIR: results stored at the dir
2 | # N_rag: number of retrieved examples
3 | # llm_num_gen: number of generations
4 | # llm_gen_threshold: acceptance similarity threshold during generation (note: it is scaled by 100 and represented as an integer)
5 |
6 | for n in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
7 | do
8 | for ngen in 1
9 | do
10 | GPUs=0,1,2,3 OUTPUT_DIR=outputs/bm25 N_rag=$n llm_num_gen=$ngen llm_gen_threshold=0 bash run/conformal_parallel.sh BM25
11 | done
12 | done
--------------------------------------------------------------------------------
/scripts_raw_risk_scores/llm-r.sh:
--------------------------------------------------------------------------------
1 | # OUTPUT_DIR: results stored at the dir
2 | # N_rag: number of retrieved examples
3 | # llm_num_gen: number of generations
4 | # llm_gen_threshold: acceptance similarity threshold during generation (note: it is scaled by 100 and represented as an integer)
5 |
6 | # path_retrieval_model specifies the path of the locally trained retrieval model
7 | # The Hugging Face transformers library downloads the LLM or loads it from the cache at ~/.cache/huggingface/hub/models--llm_dir--llm_name (including blobs, refs, and snapshots)
8 |
9 | for n in 0 1 2 3 # 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
10 | do
11 | for ngen in 1
12 | do
13 | GPUs=0,1,2,3 OUTPUT_DIR=outputs/llm-retriever-base N_rag=$n llm_num_gen=$ngen llm_gen_threshold=0 bash run/conformal_parallel.sh trained_retrieval_models/LLM-R/llm-retriever-base
14 | done
15 | done
--------------------------------------------------------------------------------
/scripts_raw_risk_scores/openai.sh:
--------------------------------------------------------------------------------
1 | # OUTPUT_DIR: results stored at the dir
2 | # N_rag: number of retrieved examples
3 | # llm_num_gen: number of generations
4 | # llm_gen_threshold: acceptance similarity threshold during generation (note: it is scaled by 100 and represented as an integer)
5 |
6 | for n in 0 1 2 3 4 6 7 8 9 # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
7 | do
8 | for ngen in 1
9 | do
10 | GPUs=0,1,2,3 OUTPUT_DIR=outputs/openai N_rag=$n llm_num_gen=$ngen llm_gen_threshold=0 bash run/conformal_parallel.sh OpenAI
11 | done
12 | done
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kangmintong/C-RAG/5bda4382bebf6453b3b8dccb8a52f4d015e95fee/src/.DS_Store
--------------------------------------------------------------------------------
/src/collators/__init__.py:
--------------------------------------------------------------------------------
1 | from .biencoder_collator import BiencoderCollator
2 | from .cross_encoder_collator import CrossEncoderCollator
3 |
--------------------------------------------------------------------------------
/src/collators/biencoder_collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from dataclasses import dataclass
4 | from typing import List, Dict, Any, Union, Optional
5 | from transformers import BatchEncoding, PreTrainedTokenizerBase
6 | from transformers.file_utils import PaddingStrategy
7 |
8 | from config import Arguments
9 |
10 |
11 | @dataclass
12 | class BiencoderCollator:
13 |
14 | args: Arguments
15 | tokenizer: PreTrainedTokenizerBase
16 | padding: Union[bool, str, PaddingStrategy] = True
17 | max_length: Optional[int] = None
18 | pad_to_multiple_of: Optional[int] = None
19 | return_tensors: str = "pt"
20 |
21 | def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
22 | queries: List[str] = [f['query'] for f in features]
23 | passages: List[str] = sum([f['passages'] for f in features], [])
24 |
25 | input_texts = queries + passages
26 |
27 | merged_batch_dict = self.tokenizer(
28 | input_texts,
29 | max_length=self.args.max_len,
30 | truncation=True,
31 | padding=self.padding,
32 | return_token_type_ids=False,
33 | pad_to_multiple_of=self.pad_to_multiple_of,
34 | return_tensors=self.return_tensors)
35 |
36 | # dummy placeholder for field "labels", won't use it to compute loss
37 | labels = torch.zeros(len(queries), dtype=torch.long)
38 | merged_batch_dict['labels'] = labels
39 |
40 | if 'kd_labels' in features[0]:
41 | kd_labels = torch.stack([torch.tensor(f['kd_labels']) for f in features], dim=0).float()
42 | merged_batch_dict['kd_labels'] = kd_labels
43 | return merged_batch_dict
44 |
--------------------------------------------------------------------------------
/src/collators/cross_encoder_collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from dataclasses import dataclass
4 | from typing import List, Dict, Any
5 | from transformers import BatchEncoding, DataCollatorWithPadding
6 |
7 |
8 | @dataclass
9 | class CrossEncoderCollator(DataCollatorWithPadding):
10 |
11 | def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
12 | unpack_features = []
13 | for ex in features:
14 | keys = list(ex.keys())
15 | for idx in range(len(ex[keys[0]])):
16 | unpack_features.append({k: ex[k][idx] for k in keys})
17 |
18 | collated_batch_dict = self.tokenizer.pad(
19 | unpack_features,
20 | padding=self.padding,
21 | pad_to_multiple_of=self.pad_to_multiple_of,
22 | return_tensors=self.return_tensors)
23 |
24 | collated_batch_dict['labels'] = torch.zeros(len(features), dtype=torch.long)
25 | return collated_batch_dict
26 |
--------------------------------------------------------------------------------
/src/collators/gpt2_collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from dataclasses import dataclass
4 | from typing import List, Dict, Any, Union, Optional
5 | from transformers import BatchEncoding, PreTrainedTokenizerBase
6 | from transformers.file_utils import PaddingStrategy
7 |
8 | from logger_config import logger
9 |
10 |
11 | @dataclass
12 | class ScoreCollator:
13 | tokenizer: PreTrainedTokenizerBase
14 | padding: Union[bool, str, PaddingStrategy] = True
15 | max_length: Optional[int] = None
16 | pad_to_multiple_of: Optional[int] = None
17 | return_tensors: str = "pt"
18 | delimiter: str = '\n'
19 |
20 | def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
21 | self.tokenizer.padding_side = 'right'
22 | input_texts = [f['input_texts'] for f in features]
23 | output_texts = [f['output_texts'] for f in features]
24 | assert all(not text.endswith(self.delimiter) for text in input_texts)
25 | assert all(not text.startswith(self.delimiter) for text in output_texts)
26 | concat_texts: List[str] = [self.delimiter.join([inp, out]) for inp, out in zip(input_texts, output_texts)]
27 |
28 | batch_dict = self.tokenizer(
29 | concat_texts,
30 | max_length=self.max_length,
31 | truncation=True,
32 | padding=self.padding,
33 | pad_to_multiple_of=self.pad_to_multiple_of,
34 | return_tensors=self.return_tensors)
35 |
36 | labels = batch_dict['input_ids'].clone()
37 | if self.tokenizer.pad_token_id is not None:
38 | labels[labels == self.tokenizer.pad_token_id] = -100
39 | num_valid_tokens = torch.cumsum(batch_dict['attention_mask'], dim=1)
40 | output_lengths: torch.LongTensor = torch.LongTensor(self._get_output_lengths(output_texts))
41 | logger.debug('output lengths: {}'.format(output_lengths))
42 | input_lengths: torch.LongTensor = torch.sum(batch_dict['attention_mask'], dim=1) - output_lengths
43 | labels[num_valid_tokens <= input_lengths[:, None]] = -100
44 | batch_dict['labels'] = labels
45 |
46 | return batch_dict
47 |
48 | def _get_output_lengths(self, output_texts: List[str]) -> List[int]:
49 | output_ids: List[List[int]] = self.tokenizer(
50 | output_texts, max_length=self.max_length, truncation=True, padding=False
51 | )['input_ids']
52 |
53 | for idx in range(len(output_ids)):
54 |             # the llama tokenizer prepends a bos token
55 | if output_ids[idx][0] == self.tokenizer.bos_token_id:
56 | output_ids[idx] = output_ids[idx][1:]
57 |
58 | lengths: List[int] = [len(output_id) for output_id in output_ids]
59 | assert all(length > 0 for length in lengths), lengths
60 |
61 | return lengths
62 |
63 |
64 | @dataclass
65 | class DecodeCollator:
66 | tokenizer: PreTrainedTokenizerBase
67 | padding: Union[bool, str, PaddingStrategy] = True
68 | max_length: Optional[int] = None
69 | pad_to_multiple_of: Optional[int] = None
70 | return_tensors: str = "pt"
71 |
72 | def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
73 | # batch_score requires right padding, but generate requires left padding
74 | self.tokenizer.padding_side = 'left'
75 | input_texts = [f['input_texts'] for f in features]
76 |
77 | batch_dict = self.tokenizer(
78 | input_texts,
79 | max_length=self.max_length,
80 | truncation=True,
81 | padding=self.padding,
82 | pad_to_multiple_of=self.pad_to_multiple_of,
83 | return_tensors=self.return_tensors)
84 | return batch_dict
85 |
--------------------------------------------------------------------------------
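ScoreCollator's key step is the label masking: every position that belongs to the input prefix (or to padding) gets label -100, so a causal LM loss computed on the batch only scores the output continuation. A self-contained sketch of that masking pattern with a plain GPT-2 tokenizer (toy strings, single example, no padding; a simplified illustration rather than the repo's exact pipeline):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('gpt2')
    inp, out = 'Question: 2+2=?', 'Answer: 4'        # toy input/output pair
    enc = tok('\n'.join([inp, out]), return_tensors='pt')
    labels = enc['input_ids'].clone()
    out_len = len(tok(out)['input_ids'])             # number of output tokens
    labels[:, :-out_len] = -100                      # mask the input prefix
    # feeding input_ids, attention_mask and labels to a causal LM now scores
    # only the 'Answer: 4' tokens, mirroring ScoreCollator above
    print(labels)

ScoreCollator does the same thing batch-wise, using the attention mask and per-example output lengths to find the boundary under right padding, while DecodeCollator switches to left padding because generation continues from the end of the prompt.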
/src/compute_controlled_risk.py:
--------------------------------------------------------------------------------
1 | import scipy.stats
2 | import numpy as np
3 |
4 | def binomcdf(p,n):
5 | p = 0.5
6 | n = 100
7 | x = 0
8 | for a in range(10):
9 | print(scipy.stats.binom.cdf(x, n, p))
10 | x += 10
11 |
12 | delta = 0.1
13 | e = 2.718
14 |
15 | target_prob = delta / e
16 |
17 | N= 1750
18 | Score_list = [0.0576 ,0.1236, 0.2412, 0.254, 0.2643, 0.2679, 0.2596, 0.272, 0.2721, 0.2729, 0.2702, 0.2732, 0.272]
19 | R_hat_list = np.array([1.]*len(Score_list)) - np.array(Score_list)
20 | alpha_list = []
21 | for R_hat in R_hat_list:
22 | t = np.ceil(N*R_hat)
23 | res = 0.
24 | res_gap = 1e10
25 | for ind in range(1000):
26 | alpha = ind * 0.001
27 | cdf = scipy.stats.binom.cdf(t, N, alpha)
28 |         if abs(cdf-target_prob) < res_gap:
--------------------------------------------------------------------------------
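The loop above is cut off mid-condition. Judging from the surviving variables (target_prob = delta / e, t = np.ceil(N * R_hat), and the running minimum res_gap), it appears to grid-search the alpha whose binomial CDF at t is closest to target_prob, i.e. to invert the binomial tail bound and turn each empirical risk R_hat into a conformal risk bound. A minimal sketch of that search under this assumption (not the original implementation):

    import numpy as np
    import scipy.stats

    def closest_alpha(R_hat: float, N: int = 1750, target_prob: float = 0.1 / 2.718) -> float:
        # Grid-search the alpha whose binomial CDF at t = ceil(N * R_hat)
        # is closest to delta / e. Assumed reconstruction of the truncated loop.
        t = np.ceil(N * R_hat)
        best_alpha, best_gap = 0.0, float('inf')
        for ind in range(1000):
            alpha = ind * 0.001
            gap = abs(scipy.stats.binom.cdf(t, N, alpha) - target_prob)
            if gap < best_gap:
                best_gap, best_alpha = gap, alpha
        return best_alpha

    # e.g. the bound for the first entry of R_hat_list:
    # closest_alpha(1.0 - 0.0576)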
/src/data_utils.py:
--------------------------------------------------------------------------------
13 | def load_corpus(path: str) -> Dataset:
14 |     assert path.endswith('.jsonl') or path.endswith('.jsonl.gz')
15 |
16 | # two fields: id, contents
17 | corpus = load_dataset('json', data_files=path)['train']
18 | logger.info('Load {} documents from {} with columns {}'.format(len(corpus), path, corpus.column_names))
19 | logger.info('A random document: {}'.format(random.choice(corpus)))
20 | return corpus
21 |
22 |
23 | def to_positive_negative_format(example: Dict, topk_as_positive: int = 1, bottomk_as_negative: int = -1) -> Dict:
24 | # query_id / query / answers / doc_ids / doc_scores
25 | assert len(example['doc_ids']) == len(example['doc_scores'])
26 | sorted_indices: List[int] = np.argsort(example['doc_scores'])[::-1]
27 | positive_indices: List[int] = sorted_indices[:topk_as_positive]
28 | negative_indices: List[int] = sorted_indices[topk_as_positive:] if bottomk_as_negative <= 0 else sorted_indices[-bottomk_as_negative:]
29 | negative_indices = [idx for idx in negative_indices if idx not in positive_indices]
30 | np.random.shuffle(positive_indices)
31 | np.random.shuffle(negative_indices)
32 | return {
33 | 'positives': {
34 | 'doc_id': [example['doc_ids'][idx] for idx in positive_indices],
35 | 'score': [example['doc_scores'][idx] for idx in positive_indices],
36 | },
37 | 'negatives': {
38 | 'doc_id': [example['doc_ids'][idx] for idx in negative_indices],
39 | 'score': [example['doc_scores'][idx] for idx in negative_indices],
40 | },
41 | }
42 |
43 |
44 | def save_to_readable_format(in_path: str, corpus: Dataset, shuffle: bool = False, max_num_samples: int = 10000):
45 | out_path = '{}/readable_{}'.format(os.path.dirname(in_path), os.path.basename(in_path))
46 | out_path = out_path.replace('.jsonl.gz', '.json')
47 | dataset: Dataset = load_dataset('json', data_files=in_path, split='train', download_mode=DownloadMode.FORCE_REDOWNLOAD)
48 | if shuffle:
49 | dataset = dataset.shuffle()
50 | if len(dataset) > max_num_samples:
51 | dataset = dataset.select(range(max_num_samples))
52 | dataset = dataset.map(
53 | to_positive_negative_format,
54 | remove_columns=['doc_ids', 'doc_scores'],
55 | desc='to positive negative format'
56 | )
57 |
58 | max_to_keep = 5
59 |
60 | def _create_readable_field(samples: Dict[str, List]) -> List:
61 | readable_ex = []
62 | for idx in range(min(len(samples['doc_id']), max_to_keep)):
63 | doc_id = samples['doc_id'][idx]
64 | readable_ex.append({'doc_id': doc_id,
65 | 'contents': corpus[int(doc_id)]['contents'],
66 | 'score': samples['score'][idx],
67 | 'task_name': corpus[int(doc_id)]['task_name'],
68 | })
69 | return readable_ex
70 |
71 | def _mp_func(ex: Dict) -> Dict:
72 | ex['positives'] = _create_readable_field(ex['positives'])
73 | ex['negatives'] = _create_readable_field(ex['negatives'])
74 | return ex
75 | dataset = dataset.map(_mp_func, desc='to readable format')
76 |
77 | dataset.to_json(out_path, force_ascii=False, lines=False, indent=4)
78 | logger.info('Done convert {} to readable format in {}'.format(in_path, out_path))
79 |
80 |
81 | def save_llm_decoding_results(
82 | out_path: str,
83 | input_texts: List[str],
84 | decoded_texts: List[str],
85 | parsed_decoded_texts: List[str],
86 | options_list: List[List[str]],
87 | answer_texts: List[str]):
88 | assert len(input_texts) == len(decoded_texts)
89 | dataset = Dataset.from_dict({
90 | 'input_text': input_texts,
91 | 'decoded_text': decoded_texts,
92 | 'parsed_decoded_text': parsed_decoded_texts,
93 | 'options': options_list,
94 | 'answer_text': answer_texts
95 | })
96 | save_dataset(dataset, out_path, shuffle=True)
97 | logger.info('Successfully save decoding results to {}'.format(out_path))
98 |
99 |
100 | def log_task_statistics(ds: Dataset, split: str = 'train'):
101 | task_name_counter = Counter()
102 | for task_name in ds['task_name']:
103 | task_name_counter[task_name] += 1
104 | # log the number of examples per task
105 | for task_name, count in task_name_counter.most_common():
106 | logger.info('{} ({}): {}'.format(task_name, split, count))
107 | logger.info('{}: {} tasks, {} examples in total'.format(split, len(task_name_counter), len(ds)))
108 |
--------------------------------------------------------------------------------
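A toy illustration of to_positive_negative_format, which converts a retrieval result into the positives/negatives structure used when building reranker training data (hypothetical ids and scores; assumes src/ is on the Python path and the repo requirements are installed):

    import sys
    sys.path.insert(0, 'src/')
    from data_utils import to_positive_negative_format

    example = {'doc_ids': ['11', '22', '33', '44'],
               'doc_scores': [0.2, 0.9, 0.5, 0.1]}
    out = to_positive_negative_format(example, topk_as_positive=1, bottomk_as_negative=2)
    # positives: doc '22' (score 0.9); negatives: the bottom-2 docs '11' and '44', shuffled
    print(out)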
/src/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_eval import BaseEval
2 | from .random_eval import RandomEval
3 | from .dense_eval import DenseEval
4 | from .bm25_eval import BM25Eval
5 | from .openai_eval import OpenaiEval
6 | from .baai_eval import BAAIEval
7 |
--------------------------------------------------------------------------------
/src/evaluation/baai_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Dict, Tuple
3 | from datasets import Dataset
4 | from evaluation.base_eval import BaseEval
5 | from config import Arguments
6 | from logger_config import logger
7 | import tiktoken
8 | from tqdm import tqdm
9 | import numpy as np
10 | import functools
11 | import signal
12 | import time
13 | import numpy as np
14 | from tqdm import tqdm
15 | import os
16 | import logging
17 | from numpy.linalg import norm
18 | from transformers import AutoTokenizer, AutoModel
19 | import torch
20 |
21 | class BAAIEval(BaseEval):
22 |
23 | def __init__(self, args: Arguments, corpus: Dataset, **kwargs):
24 | super().__init__(args, corpus, **kwargs)
25 |
26 | self.corpus = corpus
27 | self.cache_file_dir = 'outputs/baai/'
28 |
29 | def get_baai_embeddings(self, cur_corpus, task_name):
30 | out_path = os.path.join(self.cache_file_dir,f'{task_name}_baai_embedding.npy')
31 | if os.path.exists(out_path):
32 | logger.info(f'{out_path} exists, directly loading it')
33 | embeddings = np.load(out_path)
34 | else:
35 | tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-en-v1.5')
36 | model = AutoModel.from_pretrained('BAAI/bge-large-en-v1.5')
37 | model.eval()
38 |
39 | embeddings = []
40 | with torch.no_grad():
41 | for entry in tqdm(cur_corpus):
42 | doc = [entry['contents']]
43 | encoded_input = tokenizer(doc, padding=True, truncation=True, return_tensors='pt')
44 | model_output = model(**encoded_input)
45 | embedding = model_output[0][:, 0]
46 | embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
47 | embeddings.append(embedding[0])
48 | embeddings = np.array(embeddings)
49 | np.save(out_path,embeddings)
50 | logger.info(f'save embedding to {out_path}')
51 | logger.info(f'embedding {out_path} dimension: {embeddings.shape}')
52 | return embeddings
53 |
54 | def compute_similarity(self, query_embedding, embeddings):
55 | similarity_scores = np.matmul(embeddings, np.transpose(query_embedding))
56 | return similarity_scores
57 |
58 |
59 | def get_topk_score_doc_ids(self, queries: List[str], k: int, task_names: List[str]) -> List[List[Tuple[float, str]]]:
60 |
61 | cur_corpus = self.corpus.filter(lambda x: x['task_name'] == task_names[0])
62 | embeddings = self.get_baai_embeddings(cur_corpus, task_names[0])
63 |
64 | tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-en-v1.5')
65 | model = AutoModel.from_pretrained('BAAI/bge-large-en-v1.5')
66 | model.eval()
67 |
68 | scores_all = []
69 |
70 | logger.info(f'len(queries): {len(queries)}')
71 | # logger.info(f'query_embedding.shape: {query_embeddings.shape}')
72 | logger.info(f'embeddings.shape: {embeddings.shape}')
73 |
74 | start_time = time.time()
75 | query_embeddings = []
76 | with torch.no_grad():
77 | for query in tqdm(queries):
78 | encoded_input = tokenizer([query], padding=True, truncation=True, return_tensors='pt')
79 | model_output = model(**encoded_input)
80 | sentence_embeddings = model_output[0][:, 0][0]
81 | query_embeddings.append(sentence_embeddings)
82 | end_time = time.time()
83 |
84 | logger.info(f'Wall clock time of query embeddings computation: {end_time-start_time}')
85 |
86 | scores_all = np.matmul(query_embeddings, embeddings.transpose())
87 |
88 | # for idx, query in tqdm(enumerate(queries)):
89 | # query_embedding = query_embeddings[idx]
90 | # scores = self.compute_similarity(query_embedding, embeddings)
91 | # scores_all.append(scores)
92 | # scores_all = np.array(scores_all)
93 |
94 | doc_ids = []
95 | for entry in cur_corpus:
96 | doc_ids.append(entry['id'])
97 |
98 | topk_index = scores_all.argsort(axis=1)[:, -k:]
99 | result = []
100 | for idx in range(len(topk_index)):
101 | cur_list = []
102 | for j in range(len(topk_index[idx])):
103 | doc_id_ = doc_ids[topk_index[idx][j]]
104 | doc_id_score = scores_all[idx][topk_index[idx][j]]
105 | cur_list.append((doc_id_score, doc_id_))
106 | result.append(cur_list)
107 | return result
--------------------------------------------------------------------------------
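Because the BGE embeddings are L2-normalized, the matrix product above is a cosine-similarity matrix, and argsort(axis=1)[:, -k:] keeps each query's k highest-scoring documents in ascending score order (best document last). A small numpy-only sketch of that selection pattern (random toy embeddings, hypothetical shapes):

    import numpy as np

    rng = np.random.default_rng(0)
    q = rng.normal(size=(2, 8)); q /= np.linalg.norm(q, axis=1, keepdims=True)  # 2 query embeddings
    d = rng.normal(size=(5, 8)); d /= np.linalg.norm(d, axis=1, keepdims=True)  # 5 document embeddings
    scores = q @ d.T                          # (2, 5) cosine similarities
    topk = scores.argsort(axis=1)[:, -3:]     # per query: indices of the 3 best documents, best last
    print(topk)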
/src/evaluation/base_eval.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import List, Dict, Set, Tuple
3 | from collections import defaultdict
4 | from datasets import Dataset
5 |
6 | from config import Arguments
7 |
8 |
9 | class BaseEval:
10 |
11 | def __init__(self, args: Arguments, corpus: Dataset, **kwargs):
12 | self.args: Arguments = args
13 | # id / contents / task_name
14 | self.corpus: Dataset = corpus
15 |
16 | self.task_name_to_doc_ids: Dict[str, Set[str]] = defaultdict(set)
17 | for doc_id, task_name in zip(self.corpus['id'], self.corpus['task_name']):
18 | self.task_name_to_doc_ids[task_name].add(doc_id)
19 |
20 | @abstractmethod
21 | def get_topk_score_doc_ids(
22 | self, queries: List[str], k: int, task_names: List[str]
23 | ) -> List[List[Tuple[float, str]]]:
24 | raise NotImplementedError
25 |
26 | def get_doc_ids_by_task_name(self, task_name: str) -> List[str]:
27 | return list(self.task_name_to_doc_ids[task_name])
28 |
29 | def get_prompt_by_doc_ids(self, doc_ids: List[str]) -> str:
30 | return '\n\n'.join([self.corpus[int(doc_id)]['contents'] for doc_id in doc_ids])
31 |
--------------------------------------------------------------------------------
/src/evaluation/bm25_eval.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Tuple
2 | from datasets import Dataset
3 |
4 | from evaluation.base_eval import BaseEval
5 | from config import Arguments
6 | from logger_config import logger
7 | from rank_bm25 import BM25Okapi
8 | from tqdm import tqdm
9 | import os
10 | import numpy as np
11 |
12 | class BM25Eval(BaseEval):
13 |
14 | def __init__(self, args: Arguments, corpus: Dataset, **kwargs):
15 | super().__init__(args, corpus, **kwargs)
16 |
17 | self.corpus = corpus
18 | # initializing punctuations string
19 |         self.punc = r'''!()-[]{};:'"\,<>./?@#$%^&*_~'''
20 |
21 | self.cache_file_dir = 'outputs/bm25/'
22 |
23 |     def tokenize(self, text):
24 |         text = text.lower()
25 |         for ele in text:
26 |             if ele in self.punc:
27 |                 text = text.replace(ele, "")
28 |         doc_tokens = text.split(" ")
29 |         return doc_tokens
30 |
31 |
32 |
33 | def get_topk_score_doc_ids(self, queries: List[str], k: int, task_names: List[str]) -> List[List[Tuple[float, str]]]:
34 | tokenized_corpus = []
35 | doc_ids = []
36 | logger.info('Start tokenizing corpus for BM25')
37 | cur_corpus = self.corpus.filter(lambda x: x['task_name'] == task_names[0])
38 |
39 | outpath_scores = os.path.join(self.cache_file_dir,f'{task_names[0]}_bm25_scores.npy')
40 | outpath_ids = os.path.join(self.cache_file_dir, f'{task_names[0]}_bm25_docids.npy')
41 |
42 | if os.path.exists(outpath_scores):
43 | logger.info(f'score file {outpath_scores} already exists, skip calculating')
44 | doc_scores_all = np.load(outpath_scores)
45 | doc_ids = np.load(outpath_ids)
46 | else:
47 | for entry in tqdm(cur_corpus):
48 | doc = entry['contents']
49 | doc_tokens = self.tokenize(doc)
50 | tokenized_corpus.append(doc_tokens)
51 | doc_ids.append(entry['id'])
52 |
53 | bm25 = BM25Okapi(tokenized_corpus)
54 | doc_scores_all = []
55 | for query in tqdm(queries):
56 | tokenized_query = self.tokenize(query)
57 | doc_scores = bm25.get_scores(tokenized_query)
58 | doc_scores_all.append(doc_scores)
59 |
60 | doc_scores_all = np.array(doc_scores_all)
61 | doc_ids = np.array(doc_ids)
62 |
63 | np.save(outpath_scores, doc_scores_all)
64 | np.save(outpath_ids, doc_ids)
65 |
66 | topk_index = doc_scores_all.argsort(axis=1)[:,-k:]
67 | result = []
68 | for idx in range(len(topk_index)):
69 | cur_list = []
70 | for j in range(len(topk_index[idx])):
71 | doc_id_ = doc_ids[topk_index[idx][j]]
72 | doc_id_score = doc_scores_all[idx][topk_index[idx][j]]
73 | cur_list.append((doc_id_score, doc_id_))
74 | result.append(cur_list)
75 | return result
76 |
--------------------------------------------------------------------------------
/src/evaluation/dense_eval.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Tuple
2 | from datasets import Dataset
3 |
4 | from evaluation.base_eval import BaseEval
5 | from config import Arguments
6 | from logger_config import logger
7 |
8 |
9 | class DenseEval(BaseEval):
10 |
11 | def __init__(self, args: Arguments, corpus: Dataset, **kwargs):
12 | super().__init__(args, corpus, **kwargs)
13 |
14 | input_prefix = 'query: ' if args.add_qd_prompt else ''
15 | # TODO: Hack
16 | is_e5_model = any(e5_name in args.model_name_or_path for e5_name in ['intfloat/e5', 'intfloat/multilingual-e5'])
17 | if is_e5_model and not input_prefix:
18 | logger.warning('E5 models need input prefix, set input_prefix = "query: "')
19 | input_prefix = 'query: '
20 |
21 | from models import SimpleEncoder, SimpleRetriever
22 | encoder: SimpleEncoder = SimpleEncoder(
23 | model_name_or_path=args.model_name_or_path,
24 | l2_normalize=args.l2_normalize,
25 | prompt=input_prefix,
26 | )
27 | cache_dir = '{}/embeddings/'.format(args.output_dir)
28 |
29 | self.retriever: SimpleRetriever = SimpleRetriever(
30 | encoder=encoder,
31 | corpus=corpus,
32 | cache_dir=cache_dir,
33 | )
34 |
35 | def get_topk_score_doc_ids(self, queries: List[str], k: int, task_names: List[str]) -> List[List[Tuple[float, str]]]:
36 | assert len(queries) == len(task_names)
37 |
38 | query_idx_to_topk: Dict[int, List[Tuple]] = self.retriever.search_topk(queries=queries, top_k=k)
39 | for idx in range(len(queries)):
40 | q_task_name = task_names[idx]
41 | for j, (score, doc_id) in enumerate(query_idx_to_topk[idx]):
42 | if str(doc_id) not in self.task_name_to_doc_ids[q_task_name]:
43 | query_idx_to_topk[idx][j] = (score - 100., doc_id)
44 | query_idx_to_topk[idx] = sorted(query_idx_to_topk[idx], key=lambda x: x[0], reverse=True)
45 |
46 | topk_score_doc_ids: List[List[Tuple[float, str]]] = []
47 | for idx in range(len(queries)):
48 | score_doc_ids: List[Tuple[float, str]] = query_idx_to_topk[idx][:k]
49 | score_doc_ids = [(score, str(doc_id)) for score, doc_id in score_doc_ids]
50 | topk_score_doc_ids.append(score_doc_ids)
51 |
52 | return topk_score_doc_ids
53 |
--------------------------------------------------------------------------------
/src/evaluation/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from typing import List
4 | from rouge import Rouge
5 | from sklearn.metrics import f1_score
6 |
7 | from evaluation import qa_utils
8 | from logger_config import logger
9 |
10 |
11 | @torch.no_grad()
12 | def accuracy(output: torch.tensor, target: torch.tensor, topk=(1,)) -> List[float]:
13 | """Computes the accuracy over the k top predictions for the specified values of k"""
14 | maxk = max(topk)
15 | batch_size = target.size(0)
16 |
17 | _, pred = output.topk(maxk, 1, True, True)
18 | pred = pred.t()
19 | correct = pred.eq(target.view(1, -1).expand_as(pred))
20 |
21 | res = []
22 | for k in topk:
23 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
24 | res.append(correct_k.mul_(100.0 / batch_size).item())
25 | return res
26 |
27 |
28 | @torch.no_grad()
29 | def batch_mrr(output: torch.tensor, target: torch.tensor) -> float:
30 | assert len(output.shape) == 2
31 | assert len(target.shape) == 1
32 | sorted_score, sorted_indices = torch.sort(output, dim=-1, descending=True)
33 | _, rank = torch.nonzero(sorted_indices.eq(target.unsqueeze(-1)).long(), as_tuple=True)
34 | assert rank.shape[0] == output.shape[0]
35 |
36 | rank = rank + 1
37 | mrr = torch.sum(100 / rank.float()) / rank.shape[0]
38 | return mrr.item()
39 |
40 |
41 | ## =========================================================================== ##
42 |
43 | # Copy from https://github.com/microsoft/LMOps/tree/main/uprise/src/utils
44 |
45 |
46 | def rouge(preds, labels):
47 | # https://github.com/pltrdy/rouge
48 | r1s, r2s, rls = [], [], []
49 | r = Rouge()
50 | for i in range(len(labels)):
51 | if '\n' not in preds[i]: preds[i] += '\n'
52 | if '\n' not in labels[i]: labels[i] += '\n' # avoid empty string
53 | scores = r.get_scores(preds[i], labels[i])[0]
54 | r1s.append(scores["rouge-1"]['f'])
55 | r2s.append(scores["rouge-2"]['f'])
56 | rls.append(scores["rouge-l"]['f'])
57 | r1 = sum(r1s) / len(r1s)
58 | r2 = sum(r2s) / len(r2s)
59 | rl = sum(rls) / len(rls)
60 | return r1, r2, rl
61 |
62 |
63 | def squad(labels, preds):
64 | """Computes SQuAD metrics, maximizing over answers per question.
65 | Args:
66 | labels: list of lists of strings
67 | preds: list of strings
68 | Returns:
69 | dict with score_key: squad score across all labels and predictions
70 | """
71 | labels = [[qa_utils.normalize_squad(t) for t in u] for u in labels]
72 | preds = [qa_utils.normalize_squad(p) for p in preds]
73 | em, f1 = qa_utils.qa_metrics(labels, preds) # em,f1
74 | return em, f1
75 |
76 |
77 | def trivia_qa(labels, preds):
78 | """Computes TriviaQA metrics, maximizing over answers per question.
79 | Args:
80 | labels: list of lists of strings
81 | preds: list of strings
82 | Returns:
83 | dict with score_key: squad score across all labels and preds
84 | """
85 | labels = [[qa_utils.normalize_trivia_qa(t) for t in u] for u in labels]
86 | preds = [qa_utils.normalize_trivia_qa(p) for p in preds]
87 | em, f1 = qa_utils.qa_metrics(labels, preds) # em,f1
88 | return em, f1
89 |
90 |
91 | def simple_accuracy(preds, labels):
92 | if isinstance(preds[0], str):
93 | labels = [label.lower().strip() for label in labels]
94 | preds = [pred.lower().strip() for pred in preds]
95 | res = [int(preds[i] == labels[i]) for i in range(len(preds))]
96 | acc = sum(res) / len(res)
97 | return acc
98 |
99 |
100 | def acc_and_f1(preds, labels):
101 | acc = simple_accuracy(preds, labels)
102 | # Currently only MRPC & QQP use this metric
103 | f1 = f1_score(y_true=labels, y_pred=preds, pos_label='Yes')
104 | return acc, f1, (acc + f1) / 2
105 |
106 |
107 | def compute_metrics(metric, labels, preds):
108 | assert len(preds) == len(labels)
109 | if metric == 'simple_accuracy':
110 | return {'acc': simple_accuracy(preds, labels) * 100}
111 | elif metric == 'rouge':
112 | r1, r2, rl = rouge(preds, labels)
113 | return {'r1': r1 * 100, 'r2': r2 * 100, 'rl': rl * 100}
114 | elif metric == 'acc_and_f1':
115 | acc, f1, acc_f1 = acc_and_f1(preds, labels)
116 | return {'acc': acc * 100, 'f1': f1 * 100, 'acc_and_f1': acc_f1 * 100}
117 | elif metric == 'f1':
118 | f1 = f1_score(y_true=labels, y_pred=preds, pos_label='Yes')
119 | return {'f1': f1 * 100}
120 | elif metric == 'squad':
121 | em, f1 = squad(labels=labels, preds=preds)
122 | return {'em': em, 'f1': f1}
123 | elif metric == 'trivia_qa':
124 | em, f1 = trivia_qa(labels=labels, preds=preds)
125 | return {'em': em, 'f1': f1}
126 | else:
127 | raise ValueError(metric)
128 |
--------------------------------------------------------------------------------
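For reference, a toy call into compute_metrics with hypothetical predictions and labels (assumes src/ is on the Python path and the repo requirements are installed):

    import sys
    sys.path.insert(0, 'src/')
    from evaluation.metrics import compute_metrics

    preds = ['Yes', 'No']
    labels = ['Yes', 'Yes']
    print(compute_metrics('simple_accuracy', labels=labels, preds=preds))  # {'acc': 50.0}
    print(compute_metrics('f1', labels=labels, preds=preds))               # {'f1': ~66.7}, with 'Yes' as the positive label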
/src/evaluation/qa_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The T5 Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for Question Answering (QA) evaluation.
16 |
17 | Matches results on the SQuAD (v1.1) and TriviaQA (v1.0) evaluation scripts.
18 | """
19 |
20 | import collections
21 | import re
22 | import string
23 |
24 | from absl import logging
25 | import numpy as np
26 |
27 |
28 | def _normalize_answer(text, punc_chars, punc_repl):
29 | """Lower text and remove punctuation, articles and extra whitespace."""
30 |
31 | def remove_articles(s):
32 | return re.sub(r"\b(a|an|the)\b", " ", s)
33 |
34 | def replace_punctuation(s):
35 | to_replace = set(punc_chars)
36 | return "".join(punc_repl if ch in to_replace else ch for ch in s)
37 |
38 | def white_space_fix(s):
39 | return " ".join(s.split())
40 |
41 | text = text.lower()
42 | text = replace_punctuation(text)
43 | text = remove_articles(text)
44 | text = white_space_fix(text)
45 |
46 | return text
47 |
48 |
49 | def normalize_trivia_qa(answer):
50 | """Normalization used in official TriviaQA evaluation script."""
51 | return _normalize_answer(
52 | answer, punc_chars=string.punctuation + "‘’´`_", punc_repl=" ").strip()
53 |
54 |
55 | def normalize_squad(answer):
56 | """Normalization used in official SQuAD evaluation script."""
57 | return _normalize_answer(answer, punc_chars=string.punctuation, punc_repl="")
58 |
59 |
60 | def _metric_max_over_ground_truths(metric_fn, ground_truths, prediction):
61 | """Computes the maximum of the metric over all ground truths."""
62 | return max(
63 | metric_fn(ground_truth, prediction) for ground_truth in ground_truths
64 | )
65 |
66 |
67 | def _exact_match_score(target, prediction):
68 | return target == prediction
69 |
70 |
71 | def _f1_score(target, prediction):
72 | """Computes token f1 score for a single target and prediction."""
73 | prediction_tokens = prediction.split()
74 | target_tokens = target.split()
75 | common = (collections.Counter(prediction_tokens) &
76 | collections.Counter(target_tokens))
77 | num_same = sum(common.values())
78 | if num_same == 0:
79 | return 0
80 | precision = 1.0 * num_same / len(prediction_tokens)
81 | recall = 1.0 * num_same / len(target_tokens)
82 | f1 = (2 * precision * recall) / (precision + recall)
83 | return f1
84 |
85 | def qa_metrics(targets, predictions):
86 | """Computes exact match and f1 QA scores, expecting pre-normalized text."""
87 | if len(targets) != len(predictions):
88 | raise ValueError("Number of targets and predictions must match.")
89 | em = np.mean([
90 | _metric_max_over_ground_truths(_exact_match_score, t, p)
91 | for p, t in zip(predictions, targets)
92 | ])
93 | f1 = np.mean([
94 | _metric_max_over_ground_truths(_f1_score, t, p)
95 | for p, t in zip(predictions, targets)
96 | ])
97 | em *= 100
98 | f1 *= 100
99 | logging.info("EM = %.2f, F1 = %.2f", em, f1)
100 | return em, f1
--------------------------------------------------------------------------------
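A worked toy example of the normalization plus metric computation (hypothetical strings; assumes src/ is on the Python path and the repo requirements are installed):

    import sys
    sys.path.insert(0, 'src/')
    from evaluation import qa_utils

    targets = [[qa_utils.normalize_squad(t) for t in ['The Eiffel Tower', 'Eiffel Tower']]]
    preds = [qa_utils.normalize_squad('the eiffel tower!')]
    print(qa_utils.qa_metrics(targets, preds))  # (100.0, 100.0): articles and punctuation are stripped before matching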
/src/evaluation/random_eval.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from typing import List, Tuple, Dict
4 | from datasets import Dataset
5 |
6 | from evaluation.base_eval import BaseEval
7 | from config import Arguments
8 | from logger_config import logger
9 |
10 |
11 | # This class randomly selects k-shot examples from the training set
12 | class RandomEval(BaseEval):
13 |
14 | def __init__(self, args: Arguments, corpus: Dataset, **kwargs):
15 | super().__init__(args, corpus, **kwargs)
16 | self.all_doc_ids: List[str] = self.corpus['id']
17 | self.cached_task_name_to_doc_ids: Dict[str, List[str]] = {}
18 |
19 | def get_topk_score_doc_ids(self, queries: List[str], k: int, task_names: List[str]) -> List[List[Tuple[float, str]]]:
20 | assert len(queries) == len(task_names)
21 |
22 | topk_score_doc_ids: List[List[Tuple[float, str]]] = []
23 | for query, task_name in zip(queries, task_names):
24 | random_score_doc_ids: List[str] = self._single_get_topk_doc_ids(query, k, task_name)
25 | topk_score_doc_ids.append([(-1, doc_id) for doc_id in random_score_doc_ids])
26 |
27 | return topk_score_doc_ids
28 |
29 | def _single_get_topk_doc_ids(self, query: str, k: int, task_name: str) -> List[str]:
30 | if task_name not in self.cached_task_name_to_doc_ids:
31 | self.cached_task_name_to_doc_ids[task_name] = self.get_doc_ids_by_task_name(task_name)
32 | doc_ids: List[str] = self.cached_task_name_to_doc_ids[task_name]
33 | # mnli_m & mnli_mm should retrieve from mnli training set
34 | if len(doc_ids) == 0 and task_name.startswith('mnli_'):
35 | if 'mnli' not in self.cached_task_name_to_doc_ids:
36 | self.cached_task_name_to_doc_ids['mnli'] = self.get_doc_ids_by_task_name('mnli')
37 | doc_ids = self.cached_task_name_to_doc_ids['mnli']
38 |
39 | if len(doc_ids) == 0:
40 | logger.warning('Use the whole training set for task: {}'.format(task_name))
41 | doc_ids = self.all_doc_ids
42 |
43 | if k >= len(doc_ids):
44 | logger.warning('k ({}) is larger than the number of examples ({})'.format(k, len(doc_ids)))
45 | k = min(k, len(doc_ids))
46 |
47 | return random.sample(doc_ids, k)
48 |
--------------------------------------------------------------------------------
/src/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kangmintong/C-RAG/5bda4382bebf6453b3b8dccb8a52f4d015e95fee/src/inference/__init__.py
--------------------------------------------------------------------------------
/src/inference/gen_reward_scores.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import tqdm
4 | import torch
5 |
6 | from contextlib import nullcontext
7 | from torch.utils.data import DataLoader
8 | from functools import partial
9 | from datasets import Dataset, load_dataset
10 | from typing import Dict, List, Tuple
11 | from transformers.modeling_outputs import SequenceClassifierOutput
12 | from transformers import (
13 | AutoTokenizer,
14 | PreTrainedTokenizerFast,
15 | DataCollatorWithPadding,
16 | HfArgumentParser,
17 | )
18 |
19 | from config import Arguments
20 | from logger_config import logger
21 | from utils import move_to_device, save_dataset, wait_until_all_files_show_up
22 | from models import RerankerForInference
23 | from data_utils import load_corpus, save_to_readable_format
24 | from inference.inference_utils import reward_transform_func
25 |
26 | parser = HfArgumentParser((Arguments,))
27 | args: Arguments = parser.parse_args_into_dataclasses()[0]
28 | kd_gen_score_in_path = os.path.join(args.data_dir, '{}.jsonl.gz'.format(args.kd_gen_score_split))
29 | kd_gen_score_out_path = os.path.join(args.output_dir, 'kd_{}.jsonl.gz'.format(args.kd_gen_score_split))
30 |
31 |
32 | def _get_shard_path(worker_idx: int) -> str:
33 | basename = os.path.basename(kd_gen_score_in_path)
34 | return '{}/shard_{}_{}'.format(args.output_dir, worker_idx, basename)
35 |
36 |
37 | @torch.no_grad()
38 | def _worker_gen_teacher_score():
39 | gpu_idx: int = args.process_index
40 | dataset = load_dataset('json', data_files=kd_gen_score_in_path)['train']
41 | if args.dry_run:
42 | dataset = dataset.select(range(100))
43 | dataset = dataset.shard(num_shards=args.world_size,
44 | index=gpu_idx,
45 | contiguous=True)
46 |
47 | qid_pids = []
48 | for ex in tqdm.tqdm(dataset, desc='get qid-pid pairs', mininterval=3):
49 | for doc_id in ex['doc_ids']:
50 | qid_pids.append((ex['query_id'], doc_id, ex['query'], ex['answers'], ex['options']))
51 |
52 | inference_dataset = Dataset.from_dict({
53 | 'query_id': [t[0] for t in qid_pids],
54 | 'doc_id': [t[1] for t in qid_pids],
55 | 'query': [t[2] for t in qid_pids],
56 | 'answers': [t[3] for t in qid_pids],
57 | 'options': [t[4] for t in qid_pids],
58 | })
59 |
60 | query_ids, doc_ids = inference_dataset['query_id'], inference_dataset['doc_id']
61 |
62 | logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(inference_dataset)))
63 |
64 | tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
65 | model: RerankerForInference = RerankerForInference.from_pretrained(args.model_name_or_path)
66 | model.eval()
67 | model.to(gpu_idx)
68 |
69 | corpus: Dataset = load_corpus(path=os.path.join(args.data_dir, 'passages.jsonl.gz'))
70 | inference_dataset.set_transform(partial(reward_transform_func, tokenizer, args.reward_max_length, corpus))
71 |
72 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
73 | data_loader = DataLoader(
74 | inference_dataset,
75 | batch_size=args.kd_gen_score_batch_size,
76 | shuffle=False,
77 | drop_last=False,
78 | num_workers=args.dataloader_num_workers,
79 | collate_fn=data_collator,
80 | pin_memory=True)
81 |
82 | scores = []
83 | for batch_dict in tqdm.tqdm(data_loader, desc='generate teacher score', mininterval=5):
84 | batch_dict = move_to_device(batch_dict, device=gpu_idx)
85 |
86 | with torch.cuda.amp.autocast() if args.fp16 else nullcontext():
87 | outputs: SequenceClassifierOutput = model(batch_dict)
88 | scores.append(outputs.logits.squeeze(dim=-1).cpu())
89 | assert len(scores[-1].shape) == 1
90 |
91 | all_scores = torch.cat(scores, dim=-1)
92 | assert all_scores.shape[0] == len(inference_dataset), '{} != {}'.format(all_scores.shape[0], len(inference_dataset))
93 | all_scores = all_scores.tolist()
94 |
95 | query_id_to_doc_id_scores: Dict[str, List[Tuple[str, float]]] = collections.defaultdict(list)
96 | for idx in range(len(query_ids)):
97 | query_id_to_doc_id_scores[query_ids[idx]].append((doc_ids[idx], round(all_scores[idx], 5)))
98 |
99 | def _update_score(ex: Dict) -> Dict:
100 | query_id = ex['query_id']
101 | ex['doc_ids'] = [t[0] for t in query_id_to_doc_id_scores[query_id]]
102 | ex['doc_scores'] = [t[1] for t in query_id_to_doc_id_scores[query_id]]
103 | return ex
104 |
105 | dataset = dataset.map(_update_score, batched=False)
106 | save_dataset(dataset, _get_shard_path(gpu_idx))
107 |
108 | logger.info('Done computing teacher score for worker {}'.format(gpu_idx))
109 |
110 |
111 | def _merge_teacher_scores():
112 | wait_until_all_files_show_up([_get_shard_path(worker_idx) for worker_idx in range(args.world_size)])
113 |
114 | dataset = load_dataset(
115 | 'json', data_files=[_get_shard_path(worker_idx) for worker_idx in range(args.world_size)], split='train'
116 | )
117 |
118 | save_dataset(dataset, kd_gen_score_out_path)
119 | logger.info('Writing teacher score to {}'.format(kd_gen_score_out_path))
120 |
121 | logger.info('Done merge results')
122 |
123 | corpus = load_corpus(path=os.path.join(args.data_dir, 'passages.jsonl.gz'))
124 | save_to_readable_format(in_path=kd_gen_score_out_path, corpus=corpus, shuffle=True)
125 |
126 | for worker_idx in range(args.world_size):
127 | os.remove(_get_shard_path(worker_idx))
128 |
129 |
130 | def main():
131 | logger.info('Args={}'.format(str(args)))
132 | if os.path.exists(kd_gen_score_out_path):
133 | logger.info('Found {}, skip'.format(kd_gen_score_out_path))
134 | return
135 |
136 | logger.info('Use {} workers'.format(args.world_size))
137 | _worker_gen_teacher_score()
138 | logger.info('Done batch generate teacher score')
139 |
140 | if args.process_index <= 0:
141 | _merge_teacher_scores()
142 |
143 |
144 | if __name__ == '__main__':
145 | main()
146 |
--------------------------------------------------------------------------------
/src/inference/generate_few_shot_prompt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, 'src/')
4 |
5 | from datasets import Dataset, load_dataset, DownloadMode, concatenate_datasets
6 | from typing import List, Tuple
7 | from transformers import HfArgumentParser
8 |
9 | from config import Arguments
10 | from logger_config import logger
11 | from utils import save_dataset
12 | from evaluation import BaseEval
13 | from model_utils import build_eval_model
14 | from inference.inference_utils import get_prompt_save_path
15 |
16 | parser = HfArgumentParser((Arguments,))
17 | args: Arguments = parser.parse_args_into_dataclasses()[0]
18 |
19 |
20 | def main():
21 | out_path: str = get_prompt_save_path(args=args)
22 |
23 | if args.llm_k_shot==0:
24 | args.llm_k_shot=1
25 |
26 | if os.path.exists(out_path):
27 | logger.info('Prompt file {} exists. Skip.'.format(out_path))
28 | return
29 |
30 | corpus: Dataset = load_dataset(
31 | 'json', data_files='{}/passages.jsonl.gz'.format(args.data_dir), split='train',
32 | download_mode=DownloadMode.FORCE_REDOWNLOAD
33 | )
34 | # columns: query_id / query / answers / task_name
35 | eval_dataset: Dataset = load_dataset(
36 | 'json', data_files='{}/{}.jsonl.gz'.format(args.data_dir, args.llm_eval_split), split='train',
37 | download_mode=DownloadMode.FORCE_REDOWNLOAD
38 | )
39 |
40 |
41 | if not args.llm_eval_tasks or args.llm_eval_tasks[0] == 'all':
42 | args.llm_eval_tasks = sorted(eval_dataset.unique('task_name'))
43 | logger.info('Eval all {} tasks'.format(len(args.llm_eval_tasks)))
44 |
45 | logger.info(args.llm_eval_tasks)
46 |
47 | model: BaseEval = build_eval_model(args=args, corpus=corpus)
48 |
49 | task_ds_list: List[Dataset] = []
50 | for task_name in args.llm_eval_tasks:
51 | task_ds: Dataset = eval_dataset.filter(lambda x: x['task_name'] == task_name)
52 |
53 | if len(task_ds) > args.max_test_samples:
54 | logger.info('Task: {}, random sample {}/{} for evaluation'.format(
55 | task_name, args.max_test_samples, len(task_ds))
56 | )
57 | task_ds = task_ds.shuffle(seed=args.seed).select(range(args.max_test_samples))
58 | logger.info('Task: {}, {} samples for evaluation'.format(task_name, len(task_ds)))
59 |
60 | if args.llm_k_shot <= 0:
61 | task_ds = task_ds.add_column('input_prompt', ['' for _ in range(len(task_ds))])
62 | task_ds_list.append(task_ds)
63 | continue
64 |
65 | queries: List[str] = task_ds['query']
66 | # Use a larger k in case we retrieve the docs from other tasks
67 | logger.info(f'args.llm_k_shot: {args.llm_k_shot}')
68 | topk_score_doc_ids: List[List[Tuple[float, str]]] = model.get_topk_score_doc_ids(
69 | queries, k=5 * args.llm_k_shot, task_names=task_ds['task_name']
70 | )
71 | # The most relevant doc should be close to the test example
72 | topk_score_doc_ids = [score_doc_ids[:args.llm_k_shot][::-1] for score_doc_ids in topk_score_doc_ids]
73 | topk_doc_ids: List[List[str]] = [
74 | [doc_id for _, doc_id in score_doc_ids] for score_doc_ids in topk_score_doc_ids
75 | ]
76 | topk_scores: List[List[float]] = [
77 | [round(score, 4) for score, _ in score_doc_ids] for score_doc_ids in topk_score_doc_ids
78 | ]
79 | input_prompts: List[str] = [model.get_prompt_by_doc_ids(doc_ids) for doc_ids in topk_doc_ids]
80 | task_ds = task_ds.add_column('input_prompt', input_prompts)
81 | task_ds = task_ds.add_column('topk_doc_ids', topk_doc_ids)
82 | task_ds = task_ds.add_column('topk_scores', topk_scores)
83 | task_ds_list.append(task_ds)
84 |
85 | few_shot_ds: Dataset = concatenate_datasets(task_ds_list)
86 | save_dataset(few_shot_ds, out_path)
87 | logger.info('Save {} examples to {}'.format(len(few_shot_ds), out_path))
88 |
89 |
90 | if __name__ == '__main__':
91 | main()
92 |
--------------------------------------------------------------------------------
/src/inference/inference_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from typing import Dict, List
4 | from datasets import Dataset
5 | from transformers.file_utils import PaddingStrategy
6 | from transformers import PreTrainedTokenizerFast, BatchEncoding
7 |
8 | from config import Arguments
9 |
10 |
11 | def get_prompt_save_path(args: Arguments) -> str:
12 | from model_utils import parse_model_id
13 | model_id: str = parse_model_id(args.model_name_or_path)
14 | out_path: str = '{}/{}_{}_k{}.jsonl.gz'.format(
15 | args.output_dir, model_id, args.llm_eval_split, args.llm_k_shot
16 | )
17 |
18 | return out_path
19 |
20 |
21 | def reward_transform_func(
22 | tokenizer: PreTrainedTokenizerFast,
23 | reward_max_length: int,
24 | corpus: Dataset,
25 | examples: Dict[str, List]) -> BatchEncoding:
26 | input_docs: List[str] = []
27 |
28 | # ATTENTION: this code should be consistent with RerankDataLoader
29 | for doc_id in examples['doc_id']:
30 | doc_id = int(doc_id)
31 | input_docs.append(corpus[doc_id]['contents'].strip())
32 |
33 | input_queries = []
34 | for query, answers, options in zip(examples['query'], examples['answers'], examples['options']):
35 | current_query = query
36 | if len(options) > 1:
37 | current_query += '\n' + options[ord(answers[0]) - ord('A')]
38 | else:
39 | current_query += '\n' + random.choice(answers)
40 | input_queries.append(current_query)
41 |
42 | batch_dict = tokenizer(input_queries,
43 | text_pair=input_docs,
44 | max_length=reward_max_length,
45 | padding=PaddingStrategy.DO_NOT_PAD,
46 | return_token_type_ids=False,
47 | truncation=True)
48 |
49 | return batch_dict
50 |
--------------------------------------------------------------------------------
/src/inference/search_topk.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | sys.path.insert(0, 'src/')
5 |
6 | from datasets import Dataset, load_dataset, DownloadMode
7 | from typing import Dict, List, Tuple
8 | from transformers import HfArgumentParser
9 |
10 | from config import Arguments
11 | from logger_config import logger
12 | from utils import save_dataset
13 | from data_utils import save_to_readable_format
14 | from evaluation import BaseEval
15 | from model_utils import build_eval_model, parse_model_id
16 |
17 | parser = HfArgumentParser((Arguments,))
18 | args: Arguments = parser.parse_args_into_dataclasses()[0]
19 | assert args.do_search, 'This script is only for search mode.'
20 |
21 |
22 | def main():
23 | model_id = parse_model_id(args.model_name_or_path)
24 | out_path: str = '{}/{}_{}.jsonl.gz'.format(args.output_dir, model_id, args.search_split)
25 | if os.path.exists(out_path):
26 | logger.info('Output file {} already exists. Skip.'.format(out_path))
27 | return
28 |
29 | data_path: str = '{}/{}.jsonl.gz'.format(args.data_dir, args.search_split)
30 | assert os.path.exists(data_path), 'Data file {} does not exist.'.format(data_path)
31 | dataset: Dataset = load_dataset(
32 | 'json', data_files=data_path, split='train', download_mode=DownloadMode.FORCE_REDOWNLOAD
33 | )
34 | if args.dry_run:
35 | dataset = dataset.shuffle(seed=args.seed).select(range(100))
36 | logger.info('Load {} examples from {}'.format(len(dataset), data_path))
37 |
38 | corpus_path: str = '{}/passages.jsonl.gz'.format(args.data_dir)
39 | corpus: Dataset = load_dataset(
40 | 'json', data_files=corpus_path, split='train', download_mode=DownloadMode.FORCE_REDOWNLOAD
41 | )
42 | if args.dry_run:
43 | corpus = corpus.select(range(4096))
44 | logger.info('Load {} candidates from {}'.format(len(corpus), corpus_path))
45 |
46 | retriever: BaseEval = build_eval_model(args=args, corpus=corpus)
47 |
48 | logger.info('Search top {} candidates for each example.'.format(args.search_topk))
49 | topk_score_doc_ids: List[List[Tuple[float, str]]] = retriever.get_topk_score_doc_ids(
50 | dataset['query'], k=args.search_topk, task_names=dataset['task_name']
51 | )
52 | all_contents: List[str] = corpus['contents']
53 |
54 | def _map_func(example: Dict, idx: int) -> Dict:
55 | score_doc_ids: List[Tuple[float, str]] = topk_score_doc_ids[idx]
56 | # Exclude the example itself from the top-k candidates.
57 | score_doc_ids = [t for t in score_doc_ids if not all_contents[int(t[1])].startswith(example['query'])]
58 | np.random.shuffle(score_doc_ids)
59 | return {
60 | 'doc_ids': [doc_id for _, doc_id in score_doc_ids],
61 | 'doc_scores': [round(doc_score, 4) for doc_score, _ in score_doc_ids],
62 | }
63 |
64 | dataset = dataset.map(_map_func, with_indices=True, num_proc=1, desc='Add top-k candidates')
65 | dataset = dataset.filter(lambda example: len(example['doc_ids']) > 1)
66 |
67 | save_dataset(dataset, out_path, shuffle='train' in args.search_split)
68 | save_to_readable_format(in_path=out_path, corpus=corpus, shuffle=True)
69 |
70 |
71 | if __name__ == '__main__':
72 | main()
73 |
--------------------------------------------------------------------------------
/src/llms/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_llm import BaseLLM
2 | from .gpt2 import GPT2
3 | from .gpt_neo import GPTNeo
4 | from .llama import Llama
5 |
--------------------------------------------------------------------------------
/src/llms/base_llm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from abc import abstractmethod
5 | from typing import List, Optional, Union
6 |
7 |
8 | class BaseLLM(nn.Module):
9 |
10 | def __init__(self, model_name_or_path: str, *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 | self.model_name_or_path = model_name_or_path
13 |
14 | @abstractmethod
15 | def batch_score(self, input_texts: List[str], output_texts: List[str], **kwargs) -> List[float]:
16 | raise NotImplementedError
17 |
18 | def score(self, input_text: str, output_text: str, **kwargs) -> float:
19 | return self.batch_score([input_text], [output_text], **kwargs)[0]
20 |
21 | @abstractmethod
22 | def batch_decode(self, input_texts: List[str], **kwargs) -> List[str]:
23 | raise NotImplementedError
24 |
25 | def decode(self, input_text: str, **kwargs) -> str:
26 | return self.batch_decode([input_text], **kwargs)[0]
27 |
28 | def cuda(self, device: Optional[Union[int, torch.device]] = 0):
29 | self.model.to(device)
30 | return self
31 |
--------------------------------------------------------------------------------
/src/llms/gpt_neo.py:
--------------------------------------------------------------------------------
1 | from logger_config import logger
2 | from config import Arguments
3 | from llms.gpt2 import GPT2
4 |
5 |
6 | class GPTNeo(GPT2):
7 |
8 | def __init__(self, args: Arguments, model_name_or_path: str = 'EleutherAI/gpt-neo-2.7B', **kwargs):
9 | super().__init__(args, model_name_or_path, **kwargs)
10 |
--------------------------------------------------------------------------------
/src/llms/llama.py:
--------------------------------------------------------------------------------
1 | from logger_config import logger
2 | from config import Arguments
3 | from llms.gpt2 import GPT2
4 |
5 |
6 | class Llama(GPT2):
7 |
8 | def __init__(self, args: Arguments, model_name_or_path: str = 'huggyllama/llama-7b', **kwargs):
9 | super().__init__(args, model_name_or_path, **kwargs)
10 |
--------------------------------------------------------------------------------
/src/loaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .biencoder_dataloader import RetrievalDataLoader
2 | from .cross_encoder_dataloader import CrossEncoderDataLoader
3 |
--------------------------------------------------------------------------------
/src/loaders/cross_encoder_dataloader.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import torch
4 |
5 | from copy import deepcopy
6 | from functools import partial
7 | from typing import Dict, List, Optional
8 | from datasets import load_dataset, Dataset
9 | from transformers.file_utils import PaddingStrategy
10 | from transformers import PreTrainedTokenizerFast, Trainer
11 |
12 | from config import Arguments
13 | from logger_config import logger
14 | from utils import get_input_files
15 | from .loader_utils import group_doc_ids, filter_invalid_examples
16 | from data_utils import to_positive_negative_format
17 |
18 |
19 | class CrossEncoderDataset(torch.utils.data.Dataset):
20 |
21 | def __init__(self, input_files: List[str], args: Arguments,
22 | tokenizer: PreTrainedTokenizerFast):
23 | self.args = args
24 | self.input_files = input_files
25 | self.negative_size = args.train_n_passages - 1
26 | assert self.negative_size > 0
27 | self.tokenizer = tokenizer
28 | corpus_path = os.path.join(os.path.dirname(self.input_files[0]), 'passages.jsonl.gz')
29 | self.corpus: Dataset = load_dataset('json', data_files=corpus_path, split='train')
30 |
31 | self.dataset: Dataset = load_dataset('json', data_files=self.input_files, split='train')
32 | with self.args.main_process_first(desc="pre-processing"):
33 | self.dataset = filter_invalid_examples(args, self.dataset)
34 | self.dataset = self.dataset.map(
35 | partial(to_positive_negative_format,
36 | topk_as_positive=args.topk_as_positive,
37 | bottomk_as_negative=args.bottomk_as_negative),
38 | load_from_cache_file=args.world_size > 1,
39 | desc='to_positive_negative_format',
40 | remove_columns=['doc_ids', 'doc_scores']
41 | )
42 |
43 | if self.args.max_train_samples is not None:
44 | self.dataset = self.dataset.select(range(self.args.max_train_samples))
45 | # Log a few random samples from the training set:
46 | for index in random.sample(range(len(self.dataset)), 1):
47 | logger.info(f"Sample {index} of the training set: {self.dataset[index]}.")
48 |
49 | self.dataset.set_transform(self._transform_func)
50 |
51 | # use its state to decide which positives/negatives to sample
52 | self.trainer: Optional[Trainer] = None
53 |
54 | def __len__(self):
55 | return len(self.dataset)
56 |
57 | def __getitem__(self, idx):
58 | return self.dataset[idx]
59 |
60 | def _transform_func(self, examples: Dict[str, List]) -> Dict[str, List]:
61 | current_epoch = int(self.trainer.state.epoch or 0) if self.trainer is not None else 0
62 |
63 | examples = deepcopy(examples)
64 | # add some random negatives if not enough
65 | for idx in range(len(examples['query_id'])):
66 | while len(examples['negatives'][idx]['doc_id']) < self.negative_size:
67 | random_doc_id = str(random.randint(0, len(self.corpus) - 1))
68 | examples['negatives'][idx]['doc_id'].append(random_doc_id)
69 | examples['negatives'][idx]['score'].append(-100.)
70 |
71 | input_doc_ids = group_doc_ids(
72 | examples=examples,
73 | negative_size=self.negative_size,
74 | offset=current_epoch + self.args.seed
75 | )
76 | assert len(input_doc_ids) == len(examples['query']) * self.args.train_n_passages
77 |
78 | input_queries, input_docs = [], []
79 | for idx, doc_id in enumerate(input_doc_ids):
80 | input_docs.append(self.corpus[doc_id]['contents'].strip())
81 | # For reward model, the left side is the query + ground truth answer
82 | q_idx = idx // self.args.train_n_passages
83 | current_query: str = examples['query'][q_idx]
84 | answers, options = examples['answers'][q_idx], examples['options'][q_idx]
85 | if len(options) > 1:
86 | current_query += '\n' + options[ord(answers[0]) - ord('A')]
87 | # logger.info('current_query: %s', current_query)
88 | else:
89 | current_query += '\n' + random.choice(answers)
90 | input_queries.append(current_query)
91 |
92 | batch_dict = self.tokenizer(input_queries,
93 | text_pair=input_docs,
94 | max_length=self.args.reward_max_length,
95 | return_token_type_ids=False,
96 | padding=PaddingStrategy.DO_NOT_PAD,
97 | truncation=True)
98 |
99 | packed_batch_dict = {}
100 | for k in batch_dict:
101 | packed_batch_dict[k] = []
102 | assert len(examples['query']) * self.args.train_n_passages == len(batch_dict[k])
103 | for idx in range(len(examples['query'])):
104 | start = idx * self.args.train_n_passages
105 | packed_batch_dict[k].append(batch_dict[k][start:(start + self.args.train_n_passages)])
106 |
107 | return packed_batch_dict
108 |
109 |
110 | class CrossEncoderDataLoader:
111 |
112 | def __init__(self, args: Arguments, tokenizer: PreTrainedTokenizerFast):
113 | self.args = args
114 | self.tokenizer = tokenizer
115 | self.train_dataset = self._get_transformed_datasets()
116 |
117 | def set_trainer(self, trainer: Trainer):
118 | if self.train_dataset is not None:
119 | self.train_dataset.trainer = trainer
120 |
121 | def _get_transformed_datasets(self) -> CrossEncoderDataset:
122 | train_dataset = None
123 |
124 | if self.args.train_file is not None:
125 | train_input_files = get_input_files(self.args.train_file)
126 | logger.info("Train files: {}".format(train_input_files))
127 | train_dataset = CrossEncoderDataset(
128 | args=self.args,
129 | tokenizer=self.tokenizer,
130 | input_files=train_input_files,
131 | )
132 |
133 | if self.args.do_train:
134 | assert train_dataset is not None, "Training requires a train dataset"
135 |
136 | return train_dataset
137 |
--------------------------------------------------------------------------------
/src/loaders/loader_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 | from datasets import Dataset
3 |
4 | from config import Arguments
5 |
6 |
7 | def _slice_with_mod(elements: List, offset: int, cnt: int) -> List:
8 | return [elements[(offset + idx) % len(elements)] for idx in range(cnt)]
9 |
10 |
11 | def filter_invalid_examples(args: Arguments, dataset: Dataset) -> Dataset:
12 | def _filter_func(example: Dict) -> bool:
13 | if len(example['doc_ids']) <= args.topk_as_positive:
14 | return False
15 | if example['task_name'] in args.held_out_tasks:
16 | return False
17 |
18 | sorted_doc_scores = sorted(example['doc_scores'], reverse=True)
19 | if sorted_doc_scores[args.topk_as_positive - 1] <= -100.:
20 | return False
21 |
22 | return True
23 |
24 | return dataset.filter(
25 | _filter_func,
26 | load_from_cache_file=args.world_size > 1,
27 | )
28 |
29 |
30 | def group_doc_ids(examples: Dict[str, List],
31 | negative_size: int,
32 | offset: int) -> List[int]:
33 | pos_doc_ids: List[int] = []
34 | positives: List[Dict[str, List]] = examples['positives']
35 | for idx, ex_pos in enumerate(positives):
36 | all_pos_doc_ids = ex_pos['doc_id']
37 | cur_pos_doc_id = _slice_with_mod(all_pos_doc_ids, offset=offset, cnt=1)[0]
38 | pos_doc_ids.append(int(cur_pos_doc_id))
39 |
40 | neg_doc_ids: List[List[int]] = []
41 | negatives: List[Dict[str, List]] = examples['negatives']
42 | for ex_neg in negatives:
43 | cur_neg_doc_ids = _slice_with_mod(ex_neg['doc_id'],
44 | offset=offset * negative_size,
45 | cnt=negative_size)
46 | cur_neg_doc_ids = [int(doc_id) for doc_id in cur_neg_doc_ids]
47 | neg_doc_ids.append(cur_neg_doc_ids)
48 |
49 | assert len(pos_doc_ids) == len(neg_doc_ids), '{} != {}'.format(len(pos_doc_ids), len(neg_doc_ids))
50 | assert all(len(doc_ids) == negative_size for doc_ids in neg_doc_ids)
51 |
52 | input_doc_ids: List[int] = []
53 | for pos_doc_id, neg_ids in zip(pos_doc_ids, neg_doc_ids):
54 | input_doc_ids.append(pos_doc_id)
55 | input_doc_ids += neg_ids
56 |
57 | return input_doc_ids
58 |
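A toy illustration (not part of the repository) of the document layout that `group_doc_ids` produces; it assumes `src/` is on `PYTHONPATH` and uses made-up document ids:

```python
# Each query contributes one positive followed by negative_size negatives,
# so the result has length num_queries * (1 + negative_size).
from loaders.loader_utils import group_doc_ids

examples = {
    'positives': [{'doc_id': [11, 12]}, {'doc_id': [21]}],
    'negatives': [{'doc_id': [101, 102, 103]}, {'doc_id': [201, 202]}],
}
print(group_doc_ids(examples, negative_size=2, offset=0))
# -> [11, 101, 102, 21, 201, 202]
```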
--------------------------------------------------------------------------------
/src/logger_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 | from transformers.trainer_callback import TrainerCallback
5 |
6 |
7 | def _setup_logger():
8 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
9 | logger = logging.getLogger()
10 | logger.setLevel(logging.INFO)
11 | # logger.setLevel(logging.DEBUG)
12 |
13 | console_handler = logging.StreamHandler()
14 | console_handler.setFormatter(log_format)
15 |
16 | data_dir = './data/'
17 | os.makedirs(data_dir, exist_ok=True)
18 | file_handler = logging.FileHandler('{}/log.txt'.format(data_dir))
19 | file_handler.setFormatter(log_format)
20 |
21 | logger.handlers = [console_handler, file_handler]
22 |
23 | return logger
24 |
25 |
26 | logger = _setup_logger()
27 |
28 |
29 | class LoggerCallback(TrainerCallback):
30 | def on_log(self, args, state, control, logs=None, **kwargs):
31 | _ = logs.pop("total_flos", None)
32 | if state.is_world_process_zero:
33 | logger.info(logs)
34 |
--------------------------------------------------------------------------------
/src/main_eval.py:
--------------------------------------------------------------------------------
1 | from datasets import Dataset, load_dataset, DownloadMode
2 | from transformers import HfArgumentParser
3 |
4 | from config import Arguments
5 | from logger_config import logger
6 | from llms import BaseLLM
7 | from model_utils import build_llm
8 | from data_utils import log_task_statistics
9 | from llm_evaluator import LLMEvaluator
10 | from inference.inference_utils import get_prompt_save_path
11 |
12 | parser = HfArgumentParser((Arguments,))
13 | args: Arguments = parser.parse_args_into_dataclasses()[0]
14 |
15 |
16 | def main():
17 | # columns: query_id / query / answers / task_name / input_prompt
18 | eval_dataset: Dataset = load_dataset(
19 | 'json', data_files=get_prompt_save_path(args), split='train',
20 | download_mode=DownloadMode.FORCE_REDOWNLOAD
21 | )
22 | if not args.llm_eval_tasks or args.llm_eval_tasks[0] == 'all':
23 | args.llm_eval_tasks = sorted(eval_dataset.unique('task_name'))
24 | logger.info('Eval all {} tasks'.format(len(args.llm_eval_tasks)))
25 |
26 |
27 | if args.process_index <= 0:
28 | log_task_statistics(eval_dataset, split=args.llm_eval_split)
29 | logger.info('{} tasks to evaluate: {}'.format(len(args.llm_eval_tasks), args.llm_eval_tasks))
30 |
31 |
32 | llm: BaseLLM = build_llm(args=args)
33 | llm.cuda(args.process_index)
34 |
35 | evaluator: LLMEvaluator = LLMEvaluator(args=args, llm=llm)
36 |
37 | for task_name in args.llm_eval_tasks:
38 | logger.info('Evaluating task: {}'.format(task_name))
39 | evaluator.eval_single_task(eval_dataset, task_name)
40 | if args.dry_run:
41 | break
42 |
43 | logger.info('Done')
44 |
45 |
46 | if __name__ == '__main__':
47 | main()
48 |
--------------------------------------------------------------------------------
/src/model_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from datasets import Dataset
4 |
5 | from llms import BaseLLM, GPT2, GPTNeo, Llama
6 | from evaluation import BaseEval, RandomEval, DenseEval, OpenaiEval, BM25Eval, BAAIEval
7 | from config import Arguments
8 | from logger_config import logger
9 |
10 |
11 | def build_llm(args: Arguments) -> BaseLLM:
12 | model_name_or_path: str = args.llm_model_name_or_path
13 | if 'gpt2' in model_name_or_path:
14 | if args.llm_max_input_length >= 1024:
15 | args.llm_max_input_length -= max(args.llm_max_decode_length, 128)
16 | logger.warning('GPT2 models cannot handle sequences longer than 1024. '
17 | 'set to {}'.format(args.llm_max_input_length))
18 | llm = GPT2(args=args, model_name_or_path=model_name_or_path)
19 | elif 'gpt-neo' in model_name_or_path:
20 | llm = GPTNeo(args=args, model_name_or_path=model_name_or_path)
21 | elif 'llama' in model_name_or_path:
22 | llm = Llama(args=args, model_name_or_path=model_name_or_path)
23 | else:
24 | raise ValueError('Invalid model name or path: {}'.format(model_name_or_path))
25 |
26 | return llm
27 |
28 |
29 | def build_eval_model(args: Arguments, corpus: Dataset) -> BaseEval:
30 | model_name_or_path: str = args.model_name_or_path
31 | if model_name_or_path == 'random':
32 | return RandomEval(args=args, corpus=corpus)
33 | elif model_name_or_path == 'BM25':
34 | return BM25Eval(args=args, corpus=corpus)
35 | elif model_name_or_path == 'OpenAI':
36 | return OpenaiEval(args=args, corpus=corpus)
37 | elif model_name_or_path == 'BAAI':
38 | return BAAIEval(args=args, corpus=corpus)
39 | elif 'llm-retriever-base' in model_name_or_path:
40 | # LLM-R
41 | return DenseEval(args=args, corpus=corpus)
42 | else:
43 | raise TypeError(f'retrieval method {model_name_or_path} not implemented')
44 |
45 |
46 | def parse_model_id(model_name_or_path: str) -> str:
47 |     if model_name_or_path in ['random', 'BM25', 'OpenAI']:
48 | return model_name_or_path
49 | return os.path.basename(model_name_or_path.strip('/'))[-12:]
50 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .biencoder_model import BiencoderModel, BiencoderOutput
2 | from .cross_encoder_model import Reranker, RerankerForInference
3 | from .simple_encoder import SimpleEncoder
4 | from .simple_retriever import SimpleRetriever
5 |
--------------------------------------------------------------------------------
/src/models/biencoder_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from dataclasses import dataclass
7 | from typing import Optional, Dict, Tuple
8 | from torch import Tensor
9 | from transformers import PreTrainedModel, AutoModel
10 | from transformers.modeling_outputs import ModelOutput
11 |
12 | from config import Arguments
13 | from logger_config import logger
14 | from utils import dist_gather_tensor, select_grouped_indices, full_contrastive_scores_and_labels, pool
15 |
16 |
17 | @dataclass
18 | class BiencoderOutput(ModelOutput):
19 | q_reps: Optional[Tensor] = None
20 | p_reps: Optional[Tensor] = None
21 | loss: Optional[Tensor] = None
22 | labels: Optional[Tensor] = None
23 | scores: Optional[Tensor] = None
24 |
25 |
26 | class BiencoderModel(nn.Module):
27 | def __init__(self, args: Arguments,
28 | lm_q: PreTrainedModel,
29 | lm_p: PreTrainedModel):
30 | super().__init__()
31 | self.lm_q = lm_q
32 | self.lm_p = lm_p
33 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
34 | self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
35 | self.args = args
36 |
37 | from trainers import BiencoderTrainer
38 | self.trainer: Optional[BiencoderTrainer] = None
39 |
40 | self._freeze_position_embedding_if_needed(self.lm_q)
41 | self._freeze_position_embedding_if_needed(self.lm_p)
42 |
43 | def forward(self, batch_dict: Dict[str, Tensor]) -> BiencoderOutput:
44 | assert self.args.process_index >= 0
45 |
46 | scores, labels, q_reps, p_reps, all_scores, all_labels = self._compute_scores(batch_dict)
47 |
48 | start = self.args.process_index * q_reps.shape[0]
49 | group_indices = select_grouped_indices(scores=scores,
50 | group_size=self.args.train_n_passages,
51 | start=start * self.args.train_n_passages)
52 |
53 | if not self.args.do_kd_biencoder:
54 | # training biencoder from scratch
55 | loss = self.cross_entropy(scores, labels)
56 | else:
57 | # training biencoder with kd
58 | # batch_size x train_n_passage
59 | group_scores = torch.gather(input=scores, dim=1, index=group_indices)
60 | assert group_scores.shape[1] == self.args.train_n_passages
61 | group_log_scores = torch.log_softmax(group_scores, dim=-1)
62 | kd_log_target = torch.log_softmax(batch_dict['kd_labels'], dim=-1)
63 |
64 | kd_loss = self.kl_loss_fn(input=group_log_scores, target=kd_log_target)
65 | ce_loss = self.cross_entropy(scores, labels)
66 | loss = self.args.kd_cont_loss_weight * ce_loss + kd_loss
67 |
68 | total_n_psg = self.args.world_size * q_reps.shape[0] * self.args.train_n_passages
69 |
70 | return BiencoderOutput(loss=loss, q_reps=q_reps, p_reps=p_reps,
71 | labels=labels.contiguous(),
72 | scores=scores[:, :total_n_psg].contiguous())
73 |
74 | def _compute_scores(self, batch_dict: Dict[str, Tensor]) -> Tuple:
75 | embeds = self._encode(self.lm_p, batch_dict)
76 | batch_size = batch_dict['input_ids'].shape[0] // (self.args.train_n_passages + 1)
77 | q_reps = embeds[:batch_size]
78 | p_reps = embeds[batch_size:]
79 | assert p_reps.shape[0] == q_reps.shape[0] * self.args.train_n_passages
80 |
81 | all_q_reps = dist_gather_tensor(q_reps)
82 | all_p_reps = dist_gather_tensor(p_reps)
83 | assert all_p_reps.shape[0] == self.args.world_size * q_reps.shape[0] * self.args.train_n_passages
84 |
85 | all_scores, all_labels = full_contrastive_scores_and_labels(
86 | query=all_q_reps, key=all_p_reps,
87 | use_all_pairs=self.args.full_contrastive_loss
88 | )
89 | if self.args.l2_normalize:
90 | all_scores = all_scores / self.args.t
91 |
92 | start = self.args.process_index * q_reps.shape[0]
93 | local_query_indices = torch.arange(start, start + q_reps.shape[0], dtype=torch.long).to(q_reps.device)
94 | # batch_size x (world_size x batch_size x train_n_passage)
95 | scores = all_scores.index_select(dim=0, index=local_query_indices)
96 | labels = all_labels.index_select(dim=0, index=local_query_indices)
97 |
98 | return scores, labels, q_reps, p_reps, all_scores, all_labels
99 |
100 | def _encode(self, encoder: PreTrainedModel, input_dict: dict) -> Optional[torch.Tensor]:
101 | if not input_dict:
102 | return None
103 | outputs = encoder(**{k: v for k, v in input_dict.items() if k not in ['labels', 'kd_labels']}, return_dict=True)
104 | embeds = pool(last_hidden_states=outputs.last_hidden_state,
105 | attention_mask=input_dict['attention_mask'],
106 | pool_type=self.args.pool_type)
107 | if self.args.l2_normalize:
108 | embeds = F.normalize(embeds, dim=-1)
109 | return embeds.contiguous()
110 |
111 | def _freeze_position_embedding_if_needed(self, model: nn.Module):
112 | if self.args.freeze_position_embedding:
113 | for name, param in model.named_parameters():
114 | if 'position_embeddings' in name:
115 | param.requires_grad = False
116 | logger.info('Freeze {}'.format(name))
117 |
118 | def gradient_checkpointing_enable(self):
119 | self.lm_q.gradient_checkpointing_enable()
120 |
121 | @classmethod
122 | def build(cls, args: Arguments, **hf_kwargs):
123 | if os.path.isdir(args.model_name_or_path):
124 | logger.info(f'loading shared model weight from {args.model_name_or_path}')
125 | lm_q = AutoModel.from_pretrained(args.model_name_or_path)
126 | lm_p = lm_q
127 |
128 | model = cls(args=args, lm_q=lm_q, lm_p=lm_p)
129 | return model
130 |
131 | def save(self, output_dir: str):
132 | self.lm_q.save_pretrained(output_dir)
133 |
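A minimal sketch (not from the repository) of the knowledge-distillation term used in `BiencoderModel.forward`: `KLDivLoss` is constructed with `log_target=True`, so both the student scores and the teacher `kd_labels` are passed in log-space.

```python
import torch

group_scores = torch.randn(4, 8)   # batch_size x train_n_passages (student scores)
kd_labels = torch.randn(4, 8)      # teacher scores for the same passages
kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
kd_loss = kl(torch.log_softmax(group_scores, dim=-1),
             torch.log_softmax(kd_labels, dim=-1))
print(kd_loss.item())
```

The full loss then combines this with the contrastive cross-entropy term, weighted by `kd_cont_loss_weight`.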
--------------------------------------------------------------------------------
/src/models/cross_encoder_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from typing import Optional, Dict
5 | from transformers import (
6 | PreTrainedModel,
7 | AutoModelForSequenceClassification
8 | )
9 | from transformers.modeling_outputs import SequenceClassifierOutput
10 |
11 | from config import Arguments
12 |
13 |
14 | class Reranker(nn.Module):
15 | def __init__(self, hf_model: PreTrainedModel, args: Arguments):
16 | super().__init__()
17 | self.hf_model = hf_model
18 | self.args = args
19 |
20 | self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
21 |
22 | def forward(self, batch: Dict[str, torch.Tensor]) -> SequenceClassifierOutput:
23 | input_batch_dict = {k: v for k, v in batch.items() if k != 'labels'}
24 |
25 | outputs: SequenceClassifierOutput = self.hf_model(**input_batch_dict, return_dict=True)
26 | outputs.logits = outputs.logits.view(-1, self.args.train_n_passages)
27 | loss = self.cross_entropy(outputs.logits, batch['labels'])
28 | outputs.loss = loss
29 |
30 | return outputs
31 |
32 | def gradient_checkpointing_enable(self):
33 | self.hf_model.gradient_checkpointing_enable()
34 |
35 | @classmethod
36 | def from_pretrained(cls, all_args: Arguments, *args, **kwargs):
37 | hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
38 | return cls(hf_model, all_args)
39 |
40 | def save_pretrained(self, output_dir: str):
41 | self.hf_model.save_pretrained(output_dir)
42 |
43 |
44 | class RerankerForInference(nn.Module):
45 | def __init__(self, hf_model: Optional[PreTrainedModel] = None):
46 | super().__init__()
47 | self.hf_model = hf_model
48 | self.hf_model.eval()
49 |
50 | @torch.no_grad()
51 | def forward(self, batch) -> SequenceClassifierOutput:
52 | return self.hf_model(**batch)
53 |
54 | @classmethod
55 | def from_pretrained(cls, pretrained_model_name_or_path: str):
56 | hf_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)
57 | return cls(hf_model)
58 |
--------------------------------------------------------------------------------
/src/models/simple_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import tqdm
4 |
5 | from functools import partial
6 | from torch.utils.data import DataLoader
7 | from datasets import Dataset
8 | from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, PreTrainedTokenizerFast, BatchEncoding
9 | from transformers.modeling_outputs import BaseModelOutput
10 | from typing import List, Dict
11 |
12 | from utils import pool, move_to_cuda
13 |
14 |
15 | def _transform_func(tokenizer: PreTrainedTokenizerFast,
16 | examples: Dict[str, List],
17 | prompt: str = None) -> BatchEncoding:
18 | if prompt:
19 | examples['input_texts'] = [prompt + t for t in examples['input_texts']]
20 | batch_dict = tokenizer(
21 | examples['input_texts'],
22 | max_length=256,
23 | return_token_type_ids=False,
24 | padding=True,
25 | truncation=True,
26 | )
27 |
28 | return batch_dict
29 |
30 |
31 | class SimpleEncoder(torch.nn.Module):
32 | def __init__(self, model_name_or_path: str,
33 | l2_normalize: bool = True,
34 | pool_type: str = 'avg',
35 | prompt: str = 'query: '):
36 | super().__init__()
37 | self.model_name_or_path = model_name_or_path
38 | self.encoder = AutoModel.from_pretrained(model_name_or_path)
39 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
40 | self.gpu_count = torch.cuda.device_count()
41 |
42 | self.l2_normalize = l2_normalize
43 | self.pool_type = pool_type
44 | self.prompt = prompt
45 | assert self.prompt in ['', 'query: ', 'passage: ']
46 |
47 | self.encoder.eval()
48 | self.encoder.cuda()
49 |
50 |
51 | self.gpu_count = 1
52 | # if self.gpu_count > 1:
53 | # self.encoder = torch.nn.DataParallel(self.encoder)
54 |
55 | @torch.no_grad()
56 | def encode(self, sentences: List[str], **kwargs) -> torch.Tensor:
57 | dataset: Dataset = Dataset.from_dict({'input_texts': sentences})
58 | dataset.set_transform(partial(_transform_func, self.tokenizer, prompt=self.prompt))
59 |
60 | data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
61 | data_loader = DataLoader(
62 | dataset,
63 | batch_size=128 * self.gpu_count,
64 | shuffle=False,
65 | drop_last=False,
66 | num_workers=2,
67 | collate_fn=data_collator,
68 | pin_memory=True)
69 |
70 | encoded_embeds = []
71 | for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10, disable=len(sentences) < 128):
72 | batch_dict = move_to_cuda(batch_dict)
73 |
74 | with torch.cuda.amp.autocast():
75 | outputs: BaseModelOutput = self.encoder(**batch_dict)
76 | embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], self.pool_type)
77 | if self.l2_normalize:
78 | embeds = F.normalize(embeds, p=2, dim=-1)
79 | encoded_embeds.append(embeds.cpu())
80 |
81 | return torch.cat(encoded_embeds, dim=0)
82 |
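A hypothetical usage sketch (not from the repository); it assumes `src/` is on `PYTHONPATH`, a CUDA device, and an E5-style checkpoint such as `intfloat/e5-base-v2` (the checkpoint name is an assumption, any `AutoModel`-compatible encoder should work):

```python
from models.simple_encoder import SimpleEncoder

encoder = SimpleEncoder('intfloat/e5-base-v2', pool_type='avg', prompt='query: ')
embeds = encoder.encode(['what is conformal prediction?',
                         'who wrote hamlet?'])
print(embeds.shape)  # (2, hidden_size); rows are L2-normalized when l2_normalize=True
```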
--------------------------------------------------------------------------------
/src/models/simple_retriever.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import torch
4 |
5 | from typing import List, Dict, Union
6 | from datasets import Dataset
7 | from collections import defaultdict
8 |
9 | from models.simple_encoder import SimpleEncoder
10 | from logger_config import logger
11 |
12 |
13 | def _sharded_search_topk(
14 | query_embeds: torch.Tensor, top_k: int,
15 | shard_embed: torch.Tensor, shard_idx: int,
16 | idx_offset: int) -> Dict[int, List]:
17 | query_idx_to_topk: Dict[int, List] = defaultdict(list)
18 | search_batch_size = 256
19 | query_indices = list(range(query_embeds.shape[0]))
20 |
21 | for start in tqdm.tqdm(range(0, query_embeds.shape[0], search_batch_size),
22 | desc="search shard {}".format(shard_idx),
23 | mininterval=5):
24 | batch_query_embed = query_embeds[start:(start + search_batch_size)]
25 | batch_query_indices = query_indices[start:(start + search_batch_size)]
26 | batch_score = torch.mm(batch_query_embed, shard_embed.t())
27 | batch_sorted_score, batch_sorted_indices = torch.topk(batch_score, k=top_k, dim=-1, largest=True)
28 | for batch_idx, query_idx in enumerate(batch_query_indices):
29 | cur_scores = batch_sorted_score[batch_idx].cpu().tolist()
30 | cur_indices = [str(idx + idx_offset) for idx in batch_sorted_indices[batch_idx].cpu().tolist()]
31 | query_idx_to_topk[query_idx] += list(zip(cur_scores, cur_indices))
32 | query_idx_to_topk[query_idx] = sorted(query_idx_to_topk[query_idx], key=lambda t: -t[0])[:top_k]
33 |
34 | return query_idx_to_topk
35 |
36 |
37 | class SimpleRetriever:
38 |
39 | def __init__(self, encoder: SimpleEncoder,
40 | corpus: Union[Dataset, List[str]],
41 | cache_dir: str = None):
42 | self.encoder = encoder
43 |
44 | # Encode the "contents" column of the corpus
45 | if isinstance(corpus, List):
46 | corpus = Dataset.from_dict({'contents': corpus})
47 | self.corpus: Dataset = corpus
48 | logger.info(f"Corpus size: {len(self.corpus)}")
49 |
50 | self.cache_dir = cache_dir or 'tmp-{}/'.format(len(corpus))
51 | os.makedirs(self.cache_dir, exist_ok=True)
52 | logger.info(f"Cache dir: {self.cache_dir}")
53 | self.encode_shard_size = 2_000_000
54 |
55 | def search_topk(self, queries: List[str], top_k: int = 10) -> Dict[int, List]:
56 | # encode the corpus or load from cache if it already exists
57 | self._encode_corpus_if_necessary(shard_size=self.encode_shard_size)
58 |
59 | # encode the queries
60 | query_embeds = self._encode_queries(queries)
61 | if torch.cuda.is_available():
62 | query_embeds = query_embeds.cuda()
63 |
64 | # search the top-k results
65 | query_idx_to_topk: Dict[int, List] = defaultdict(list)
66 | num_shards = (len(self.corpus) + self.encode_shard_size - 1) // self.encode_shard_size
67 | idx_offset = 0
68 | for shard_idx in range(num_shards):
69 | out_path: str = self._get_out_path(shard_idx)
70 | shard_embeds = torch.load(out_path, map_location=lambda storage, loc: storage)
71 | shard_embeds = shard_embeds.to(query_embeds.device)
72 | shard_query_idx_to_topk = _sharded_search_topk(
73 | query_embeds=query_embeds,
74 | top_k=top_k,
75 | shard_embed=shard_embeds,
76 | shard_idx=shard_idx,
77 | idx_offset=idx_offset
78 | )
79 | for query_idx, shard_topk in shard_query_idx_to_topk.items():
80 | query_idx_to_topk[query_idx] += shard_topk
81 | query_idx_to_topk[query_idx] = sorted(query_idx_to_topk[query_idx], key=lambda t: -t[0])[:top_k]
82 |
83 | idx_offset += shard_embeds.shape[0]
84 |
85 | return query_idx_to_topk
86 |
87 | def encode_corpus(self):
88 | self._encode_corpus_if_necessary(shard_size=self.encode_shard_size)
89 | logger.info('Done encoding corpus')
90 |
91 | def _get_out_path(self, shard_idx: int) -> str:
92 | return '{}/shard_{}'.format(self.cache_dir, shard_idx)
93 |
94 | def _encode_corpus_if_necessary(self, shard_size: int):
95 | num_shards = (len(self.corpus) + shard_size - 1) // shard_size
96 | num_examples = 0
97 | for shard_idx in range(num_shards):
98 | out_path: str = self._get_out_path(shard_idx)
99 | if os.path.exists(out_path):
100 | logger.info('{} already exists, will skip encoding'.format(out_path))
101 | num_examples += len(torch.load(out_path, map_location=lambda storage, loc: storage))
102 | continue
103 | shard_dataset: Dataset = self.corpus.shard(
104 | num_shards=num_shards,
105 | index=shard_idx,
106 | contiguous=True
107 | )
108 | shard_embeds: torch.Tensor = self.encoder.encode(
109 | sentences=shard_dataset['contents']
110 | )
111 |
112 | num_examples += shard_embeds.shape[0]
113 | logger.info('Saving shard {} ({} examples) to {}'.format(shard_idx, len(shard_dataset), out_path))
114 | torch.save(shard_embeds, out_path)
115 |
116 | assert num_examples == len(self.corpus), \
117 | f"Number of examples in the corpus ({len(self.corpus)}) " \
118 | f"does not match the number of examples in the shards ({num_examples})"
119 |
120 | def _encode_queries(self, queries: List[str]) -> torch.Tensor:
121 | return self.encoder.encode(
122 | sentences=queries
123 | )
124 |
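A hypothetical end-to-end sketch (not from the repository) of `SimpleRetriever.search_topk`; it assumes `src/` on `PYTHONPATH`, a CUDA device, and the same assumed checkpoint as above:

```python
from models.simple_encoder import SimpleEncoder
from models.simple_retriever import SimpleRetriever

encoder = SimpleEncoder('intfloat/e5-base-v2')  # checkpoint name is an assumption
corpus = ['the cat sat on the mat',
          'paris is the capital of france']
retriever = SimpleRetriever(encoder=encoder, corpus=corpus, cache_dir='tmp-demo/')
topk = retriever.search_topk(['where is paris?'], top_k=1)
print(topk[0])  # list of (score, corpus_index_as_str) pairs, highest score first
```

Corpus embeddings are cached as shards under `cache_dir`, so repeated searches skip re-encoding.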
--------------------------------------------------------------------------------
/src/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Optional
2 |
3 |
4 | def to_letter(key: Union[str, int]) -> str:
5 | key = str(key).upper().strip()
6 | num_to_letter = {"0": "A", "1": "B", "2": "C", "3": "D"}
7 | assert key in num_to_letter or key in ['A', 'B', 'C', 'D'], f'Unknown answer key: {key}'
8 | return num_to_letter.get(key, key)
9 |
10 |
11 | def format_options(options: List[str]) -> str:
12 |     assert len(options) <= 4, f'Number of options should be at most 4, but got {len(options)}'
13 | res = 'OPTIONS: '
14 | for letter, option in zip(['A', 'B', 'C', 'D'], options):
15 | res += f'{letter}) {option} '
16 |
17 | return res.strip()
18 |
19 |
20 | # Based on https://github.com/microsoft/LMOps/blob/main/uprise/DPR/dpr/utils/tasks.py
21 |
22 | TASK_TYPE_TO_TASK_NAME = {
23 | 'close_qa': ['natural_questions', 'arc_c', 'arc_e'],
24 | 'common_reason': ['copa', 'piqa', 'hellaswag'],
25 | 'coreference': ['winogrande', 'wsc', 'wsc273'],
26 | 'nli': ['rte', 'mnli', 'mnli_m', 'mnli_mm', 'qnli', 'snli'],
27 | 'paraphrase': ['mrpc', 'paws', 'qqp'],
28 | 'reading': ['multirc', 'openbookqa', 'squad_v1', 'boolq'],
29 | 'sentiment': ['yelp', 'sentiment140', 'sst2'],
30 | 'struct2text': ['common_gen', 'e2e_nlg', 'dart'],
31 | 'summarize': ['aeslc', 'ag_news', 'gigaword'],
32 | }
33 |
34 |
35 | class App:
36 | def __init__(self):
37 | self.cls_dic = {}
38 |
39 | def add(self, key):
40 | def adder(cls):
41 | self.cls_dic[key] = cls
42 | return cls
43 |
44 | return adder
45 |
46 |
47 | task_map = App()
48 |
49 |
50 | from .base_task import BaseTask
51 | from .aeslc import Aeslc
52 | from .agnews import Ag_news
53 | from .arc import Arc_c, Arc_e
54 | from .boolq import Boolq
55 | from .common_gen import Common_gen
56 | from .copa import Copa
57 | from .dart import Dart
58 | from .e2e_nlg import E2e_nlg
59 | from .gigaword import Gigaword
60 | from .hellaswag import Hellaswag
61 | from .mnli import Mnli, Mnli_m, Mnli_mm
62 | from .mrpc import Mrpc
63 | from .multirc import Multirc
64 | from .nq import Natural_questions
65 | from .openbookqa import Openbookqa
66 | from .paws import Paws
67 | from .piqa import Piqa
68 | from .qnli import Qnli
69 | from .qqp import Qqp
70 | from .rte import Rte
71 | from .sentiment140 import Sentiment140
72 | from .snli import Snli
73 | from .squad_v1 import Squad_v1
74 | from .sst2 import Sst2
75 | from .winogrande import Winogrande
76 | from .wsc import Wsc
77 | from .wsc273 import Wsc273
78 | from .yelp import Yelp
79 |
80 |
81 | def get_metric_name_by_task_name(task_name: str) -> str:
82 | assert task_name in task_map.cls_dic, f'Unknown task name: {task_name}'
83 | return task_map.cls_dic[task_name]().metric_name
84 |
85 |
86 | def get_possible_answers_by_task_name(task_name: str) -> Optional[List[str]]:
87 | assert task_name in task_map.cls_dic, f'Unknown task name: {task_name}'
88 | return task_map.cls_dic[task_name]().possible_answers
89 |
90 |
91 | def parse_decoded_text_by_task(decoded_text: str, task_name: str) -> str:
92 | # TODO: maybe add some task-specific logics here
93 | return decoded_text.strip().split('\n')[0]
94 |
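An illustration (not part of the repository) of the `task_map` registry defined above; it assumes `src/` is on `PYTHONPATH` so that importing `tasks` registers every task class:

```python
from tasks import task_map, get_metric_name_by_task_name, format_options

print(sorted(task_map.cls_dic)[:3])          # e.g. ['aeslc', 'ag_news', 'arc_c']
print(get_metric_name_by_task_name('copa'))  # 'simple_accuracy'
print(format_options(['dog', 'cat']))        # 'OPTIONS: A) dog B) cat'
```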
--------------------------------------------------------------------------------
/src/tasks/aeslc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict, Union
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("aeslc")
9 | class Aeslc(BaseTask):
10 |
11 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
12 | split = split if split == 'train' else 'test'
13 | # For some reason, huggingface reports "checksum mismatch", so we ignore the checksum for now
14 | dataset: Dataset = load_dataset('aeslc', split=split, ignore_verifications=True)
15 | dataset = dataset.rename_column('email_body', 'body')
16 | dataset = dataset.rename_column('subject_line', 'subject')
17 |
18 | def _remove_newlines(example: Dict) -> Dict:
19 | example['body'] = ' '.join(example['body'].split())
20 | example['subject'] = ' '.join(example['subject'].split())
21 | return example
22 |
23 | dataset = dataset.map(_remove_newlines, desc='remove newlines')
24 |
25 |         # Filter logic from uprise; FLAN filters out empty examples.
26 | def _filter_func(example: Dict) -> bool:
27 | return 0 < len(example['body'].split()) <= 256 and 0 < len(example['subject'].split()) <= 256
28 |
29 | dataset = dataset.filter(_filter_func)
30 |
31 | return dataset
32 |
33 | @property
34 | def templates(self) -> List[Tuple[str, str]]:
35 | return [
36 | ("What is the subject line for this email? {body}", "{subject}"),
37 | ("Write a subject line for this message: {body}", "{subject}"),
38 | ("{body} Write a subject line for this email.", "{subject}"),
39 | ("Here is an email: {body} What is a potential subject line for this email?", "{subject}"),
40 | ("{body} Propose a subject line for this email?", "{subject}"),
41 | ("This is the content of an email: {body} What was the subject line for this email?", "{subject}"),
42 | ("This is an email {body} What is the subject of this email?", "{subject}"),
43 | ("{body} Generate a subject line for this email.", "{subject}"),
44 | ]
45 |
46 | @property
47 | def possible_answers(self) -> Optional[List[str]]:
48 | return None
49 |
50 | @property
51 | def metric_name(self) -> str:
52 | return 'rouge'
53 |
54 | @property
55 | def task_name(self) -> str:
56 | return 'aeslc'
57 |
58 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
59 | return example['subject']
60 |
61 |
--------------------------------------------------------------------------------
/src/tasks/agnews.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("ag_news")
9 | class Ag_news(BaseTask):
10 |
11 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
12 | split = split if split == 'train' else 'test'
13 | dataset = load_dataset('ag_news', split=split)
14 | return dataset
15 |
16 | @property
17 | def templates(self) -> List[Tuple[str, str]]:
18 | return [
19 | ("\"{text}\" What is this text about? World, Sports, Business, or Technology?", "{answer}"),
20 | ("\"{text}\" Which topic is this article about? World, Sports, Business, or Technology?", "{answer}"),
21 | ("\"{text}\" Which is the best summary of this article? World, Sports, Business, or Technology?",
22 | "{answer}"),
23 | ("\"{text}\" What is this text about? World, Sports, Business, or Technology?", "{answer}"),
24 | (
25 | "\"{text}\" What best summarizes the content of the above article? World, Sports, Business, or Technology?",
26 | "{answer}"),
27 | ("Which is this about? \"{text}\" World, Sports, Business, or Technology?", "{answer}"),
28 | ("Which is an appropriate title for this article? \"{text}\" World, Sports, Business, or Technology?",
29 | "{answer}"),
30 | ("Select the topic that this about: \"{text}\" World, Sports, Business, or Technology?", "{answer}"),
31 | ]
32 |
33 | @property
34 | def possible_answers(self) -> Optional[List[str]]:
35 | return ['World', 'Sports', 'Business', 'Technology']
36 |
37 | @property
38 | def metric_name(self) -> str:
39 | return 'simple_accuracy'
40 |
41 | @property
42 | def task_name(self) -> str:
43 | return 'ag_news'
44 |
--------------------------------------------------------------------------------
/src/tasks/arc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict, Union
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map, to_letter
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("arc_c")
9 | class Arc_c(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'test'
12 | dataset = load_dataset('ai2_arc', 'ARC-Challenge', split=split)
13 |
14 | # Both FLAN & uprise & structured prompting have this filter logic
15 | dataset = dataset.filter(lambda ex: len(ex['choices']['text']) == 4)
16 |
17 | def _map_func(ex: Dict) -> Dict:
18 | if ex['answerKey'] not in ['A', 'B', 'C', 'D']:
19 | ex["answerKey"] = to_letter(int(ex['answerKey']) - 1)
20 | ex['options'] = ex['choices']['text']
21 | return ex
22 |
23 | dataset = dataset.map(_map_func)
24 |
25 | return dataset
26 |
27 | @property
28 | def templates(self) -> List[Tuple[str, str]]:
29 | return [
30 | ("{question}", "{answer}"),
31 | ("Question: {question} Answer:", "{answer}"),
32 | ("Question: {question} What is the correct answer to the question from the following choices?", "{answer}"),
33 | ("Q: {question} What is the correct answer to this question?", "{answer}"),
34 | ("What is the answer? {question}", "{answer}"),
35 | ("Answer the question {question}", "{answer}"),
36 | ("{question} Pick the answer from these options.", "{answer}"),
37 | ]
38 |
39 | @property
40 | def possible_answers(self) -> Optional[List[str]]:
41 | return ['A', 'B', 'C', 'D']
42 |
43 | @property
44 | def metric_name(self) -> str:
45 | return 'simple_accuracy'
46 |
47 | @property
48 | def task_name(self) -> str:
49 | return 'arc_c'
50 |
51 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
52 | return example['answerKey']
53 |
54 |
55 | @task_map.add("arc_e")
56 | class Arc_e(Arc_c):
57 |
58 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
59 | split = split if split == 'train' else 'test'
60 | dataset = load_dataset('ai2_arc', 'ARC-Easy', split=split)
61 | dataset = dataset.filter(lambda ex: len(ex['choices']['text']) == 4)
62 |
63 | def _map_func(ex: Dict) -> Dict:
64 | if ex['answerKey'] not in ['A', 'B', 'C', 'D']:
65 | ex["answerKey"] = to_letter(int(ex['answerKey']) - 1)
66 | ex['options'] = ex['choices']['text']
67 | return ex
68 |
69 | dataset = dataset.map(_map_func)
70 | return dataset
71 |
72 | @property
73 | def task_name(self) -> str:
74 | return 'arc_e'
75 |
--------------------------------------------------------------------------------
/src/tasks/base_task.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from typing import List, Optional, Tuple, Dict, Union
4 | from datasets import Dataset
5 |
6 | from logger_config import logger
7 |
8 |
9 | class BaseTask(object):
10 |
11 | def __init__(self, template_idx: int = 0, **kwargs):
12 | self.template_idx = template_idx
13 |
14 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
15 | raise NotImplementedError
16 |
17 | def get_task_data(self, split: str) -> Optional[Dataset]:
18 | # columns: query_id / query / answers / task_name
19 | dataset = self._load_raw_data(split)
20 | if not dataset:
21 | return None
22 |
23 | logger.info('Load dataset: {}, split: {}'.format(self.task_name, split))
24 | dataset = dataset.map(self.map_single, num_proc=4)
25 | dataset = dataset.add_column(
26 | 'query_id', ['{}_{}_{}'.format(self.task_name, split, idx) for idx in range(len(dataset))]
27 | )
28 | dataset = dataset.remove_columns(
29 | column_names=[col for col in dataset.column_names
30 | if col not in ['query_id', 'query', 'answers', 'options', 'task_name']]
31 | )
32 |
33 | return dataset
34 |
35 | def get_corpus(self) -> Optional[Dataset]:
36 | # columns: contents / task_name
37 | corpus = self.get_task_data(split='train')
38 | if not corpus:
39 | return None
40 |
41 | def _map_func(example: Dict) -> Dict:
42 | answer = example['answers'][0]
43 | if len(example['options']) > 1:
44 | # multiple-choice tasks
45 | option_idx = self.possible_answers.index(answer)
46 | answer = example['options'][option_idx]
47 | return {
48 | 'contents': '\n'.join([example['query'], answer]),
49 | }
50 |
51 | corpus = corpus.map(
52 | _map_func, num_proc=4,
53 | remove_columns=[col for col in corpus.column_names if col not in ['contents', 'task_name']],
54 | desc='{} corpus'.format(self.task_name)
55 | )
56 |
57 | return corpus
58 |
59 | def get_template(self) -> Tuple[str, str]:
60 | return self.templates[self.template_idx % len(self.templates)]
61 |
62 | def map_single(self, example: Dict) -> Dict:
63 | # ("If \"{premise}\", can we conclude that \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"),
64 | query_template, answer_template = self.get_template()
65 |
66 | # find the key in {key} format using regular expression
67 | query_keys = re.findall(r'\{(\w+)\}', query_template)
68 | answer_keys = re.findall(r'\{(\w+)\}', answer_template)
69 |
70 | # replace the key with the value in the example
71 | query: str = query_template.format(**{key: example[key] for key in query_keys})
72 | assert len(answer_keys) == 1, "Only one answer key is allowed"
73 | answer_key = answer_keys[0]
74 | del answer_keys
75 |
76 | example[answer_key]: Union[str, List[str]] = self.get_answer(example)
77 | if isinstance(example[answer_key], str):
78 | answers: List[str] = [answer_template.format(**{answer_key: example[answer_key]})]
79 | elif isinstance(example[answer_key], list):
80 | answers: List[str] = [answer_template.format(**{answer_key: ans}) for ans in example[answer_key]]
81 | else:
82 | raise ValueError(f"Unknown answer type: {example[answer_key]}")
83 |
84 | return {
85 | 'query': query,
86 | 'options': example.get('options', ['']),
87 | 'answers': answers,
88 | 'task_name': self.task_name,
89 | }
90 |
91 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
92 | # Many tasks need to override this default implementation
93 | assert int(example['label']) >= 0, "label must be non-negative"
94 | return self.possible_answers[int(example['label'])]
95 |
96 | @property
97 | def templates(self) -> List[Tuple[str, str]]:
98 | raise NotImplementedError
99 |
100 | @property
101 | def possible_answers(self) -> Optional[List[str]]:
102 | raise NotImplementedError
103 |
104 | @property
105 | def metric_name(self) -> str:
106 | raise NotImplementedError
107 |
108 | @property
109 | def task_name(self) -> str:
110 | raise NotImplementedError
111 |
--------------------------------------------------------------------------------
/src/tasks/boolq.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("boolq")
9 | class Boolq(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('super_glue', 'boolq', split=split)
13 | dataset = dataset.rename_column('passage', 'text')
14 | return dataset
15 |
16 | @property
17 | def templates(self) -> List[Tuple[str, str]]:
18 | return [
19 | ("{text} Can we conclude that {question}?", "{answer}"),
20 | ("{text} Is it true that {question}?", "{answer}"),
21 | ("{text} {question}?", "{answer}"),
22 | ("Text: {text} Question: {question}?", "{answer}"),
23 | ("{text} What's the best answer to this question: {question}?", "{answer}"),
24 | ("{text} Based on the above text, what's the best answer to this question: {question}?", "{answer}"),
25 | ("{text} Answer this question, making sure that the answer is supposed by the text: {question}?",
26 | "{answer}"),
27 | ("{text} Is the following statement correct based on the text {question}", "{answer}"),
28 | ("{text} Is this statement correct \"{question}\"?", "{answer}"),
29 | ("Is it true that {question} based on the following text? {text}", "{answer}"),
30 | ]
31 |
32 | @property
33 | def possible_answers(self) -> Optional[List[str]]:
34 | return ['No', 'Yes']
35 |
36 | @property
37 | def metric_name(self) -> str:
38 | return 'simple_accuracy'
39 |
40 | @property
41 | def task_name(self) -> str:
42 | return 'boolq'
43 |
--------------------------------------------------------------------------------
/src/tasks/common_gen.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict, Union
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("common_gen")
9 | class Common_gen(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('common_gen', split=split)
13 | dataset = dataset.map(lambda ex: {'concepts': ", ".join(ex["concepts"])})
14 | return dataset
15 |
16 | @property
17 | def templates(self) -> List[Tuple[str, str]]:
18 | return [
19 | ("Concepts: {concepts}. Write a sentence that includes all these words.", "{target}"),
20 | ("Keywords: {concepts}. What is a sentence that includes all these keywords?", "{target}"),
21 | ("Here are some concepts: {concepts}. What is a sentence about these concepts?", "{target}"),
22 | ("Produce a sentence which mentions all of these concepts: {concepts}.", "{target}"),
23 | ("Write a sentence about the following things: {concepts}.", "{target}"),
24 | ("Generate a sentence that includes all the following words: {concepts}.", "{target}"),
25 | ]
26 |
27 | @property
28 | def possible_answers(self) -> Optional[List[str]]:
29 | return None
30 |
31 | @property
32 | def metric_name(self) -> str:
33 | return 'rouge'
34 |
35 | @property
36 | def task_name(self) -> str:
37 | return 'common_gen'
38 |
39 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
40 | return example['target']
41 |
--------------------------------------------------------------------------------
/src/tasks/copa.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict, Union
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map, to_letter
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("copa")
9 | class Copa(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('super_glue', 'copa', split=split)
13 |
14 | def _map_func(ex: Dict) -> Dict:
15 | ex['options'] = [ex['choice1'], ex['choice2']]
16 | return ex
17 |
18 | dataset = dataset.map(_map_func)
19 |
20 | return dataset
21 |
22 | @property
23 | def templates(self) -> List[Tuple[str, str]]:
24 | # question is either "cause" or "effect"
25 | return [
26 | ("\"{premise}\" What is the {question}?", "{answer}"),
27 | ("Here is a premise: \"{premise}\" What is the {question}?", "{answer}"),
28 | ("\"{premise}\" What is the {question} of the preceding sentence?", "{answer}"),
29 | ("\"{premise}\" What is a plausible {question}?", "{answer}"),
30 | ("Based on the following sentence, what is the {question}? \"{premise}\"", "{answer}"),
31 | ("\"{premise}\" {question}:", "{answer}"),
32 | ("What is the {question} of the following sentence? \"{premise}\"", "{answer}"),
33 | ("Answer the following question about this sentence: \"{premise}\" What is the {question}?", "{answer}"),
34 | ]
35 |
36 | @property
37 | def possible_answers(self) -> Optional[List[str]]:
38 | return ['A', 'B']
39 |
40 | @property
41 | def metric_name(self) -> str:
42 | return 'simple_accuracy'
43 |
44 | @property
45 | def task_name(self) -> str:
46 | return 'copa'
47 |
48 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
49 | return to_letter(str(example['label']))
50 |
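An illustrative call (not part of the repository) of `BaseTask.map_single` via the `Copa` task; the toy example dict mimics a `super_glue/copa` row after `Copa._map_func` has added `options`:

```python
from tasks.copa import Copa

example = {
    'premise': 'The man broke his toe.',
    'question': 'cause',
    'choice1': 'He got a hole in his sock.',
    'choice2': 'He dropped a hammer on his foot.',
    'options': ['He got a hole in his sock.', 'He dropped a hammer on his foot.'],
    'label': 1,
}
print(Copa(template_idx=0).map_single(example))
# {'query': '"The man broke his toe." What is the cause?',
#  'options': [...], 'answers': ['B'], 'task_name': 'copa'}
```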
--------------------------------------------------------------------------------
/src/tasks/dart.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from typing import Optional, List, Tuple, Dict, Union
4 | from datasets import load_dataset, Dataset
5 |
6 | from tasks import task_map
7 | from tasks.base_task import BaseTask
8 |
9 |
10 | @task_map.add("dart")
11 | class Dart(BaseTask):
12 | '''
13 | https://huggingface.co/datasets/GEM/dart
14 | '''
15 |
16 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
17 | split = split if split == 'train' else 'validation'
18 | dataset = load_dataset('GEM/dart', split=split)
19 |
20 | def _map_func(ex: Dict) -> Dict:
21 | tripleset = "; ".join([", ".join(triplet) for triplet in ex["tripleset"]])
22 | # Get rid of some undesirable cells like "[TABLECONTEXT]", "[TITLE]"
23 | tripleset = re.sub(r'\[(.*?)\]', '', tripleset)
24 | return {
25 | 'tripleset': tripleset
26 | }
27 |
28 | dataset = dataset.map(_map_func)
29 |
30 | return dataset
31 |
32 | @property
33 | def templates(self) -> List[Tuple[str, str]]:
34 | return [
35 | ("Triple: {tripleset} What is a sentence that describes this triple?", "{target}"),
36 | ("Data: {tripleset} What would a sentence about this data be like?", "{target}"),
37 | ("Generate an approximately fifteen-word sentence that describes all this data: {tripleset}", "{target}"),
38 | ("Here is some data: {tripleset}. Write a sentence that describes this data", "{target}"),
39 | ("This is some data: {tripleset}. Generate a detailed description of this data", "{target}"),
40 | ("Generate a sentence about this data: {tripleset}", "{target}"),
41 | ("Write a sentence that about [{tripleset}].", "{target}"),
42 | ("Produce a long descriptive sentence that uses all these words: {tripleset}", "{target}"),
43 | ]
44 |
45 | @property
46 | def possible_answers(self) -> Optional[List[str]]:
47 | return None
48 |
49 | @property
50 | def metric_name(self) -> str:
51 | return 'rouge'
52 |
53 | @property
54 | def task_name(self) -> str:
55 | return 'dart'
56 |
57 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
58 | return example['target']
59 |
--------------------------------------------------------------------------------
/src/tasks/e2e_nlg.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from typing import Optional, List, Tuple, Dict, Union
4 | from datasets import load_dataset, Dataset
5 |
6 | from tasks import task_map
7 | from tasks.base_task import BaseTask
8 |
9 |
10 | @task_map.add("e2e_nlg")
11 | class E2e_nlg(BaseTask):
12 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
13 | split = split if split == 'train' else 'test'
14 | dataset = load_dataset('GEM/e2e_nlg', split=split)
15 |
16 | def _map_func(ex: Dict) -> Dict:
17 | meaning_representation = re.sub(r'\[', ' = ', ex['meaning_representation'])
18 | meaning_representation = re.sub(r'\]', '', meaning_representation)
19 | return {
20 | 'meaning_representation': meaning_representation
21 | }
22 |
23 | dataset = dataset.map(_map_func)
24 |
25 | return dataset
26 |
27 | @property
28 | def templates(self) -> List[Tuple[str, str]]:
29 | return [
30 | ("Attributes: {meaning_representation}. Produce a detailed sentence about this restaurant.", "{target}"),
31 | ("Data: {meaning_representation}. Can you generate a sentence about this data?", "{target}"),
32 | ("Data: {meaning_representation}. What is a sentence that describe this data?", "{target}"),
33 | (
34 | "Here are some keywords about a restaurant: {meaning_representation}. Write a sentence that describes the following attributes of a restaurant.",
35 | "{target}"),
36 | (
37 | "Here is some data about a restaurant: {meaning_representation}. Write a sentence that includes the following data about a restaurant.",
38 | "{target}"),
39 | ("Sentence: {meaning_representation}. Can you represent the content in this sentence in data form?",
40 | "{target}"),
41 | ("Write a sentence about a restaurant with all the following attributes: {meaning_representation}.",
42 | "{target}"),
43 | ("Write a sentence that is about a restaurant with all the following properties: {meaning_representation}.",
44 | "{target}"),
45 | ("Produce a detailed sentence about a restaurant using the following words: {meaning_representation}.",
46 | "{target}"),
47 | ("Generate a descriptive sentence about a restaurant using the following words: {meaning_representation}.",
48 | "{target}"),
49 | ]
50 |
51 | @property
52 | def possible_answers(self) -> Optional[List[str]]:
53 | return None
54 |
55 | @property
56 | def metric_name(self) -> str:
57 | return 'rouge'
58 |
59 | @property
60 | def task_name(self) -> str:
61 | return 'e2e_nlg'
62 |
63 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
64 | return example['target']
65 |
--------------------------------------------------------------------------------
/src/tasks/gigaword.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict, Union
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("gigaword")
9 | class Gigaword(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'test'
12 | dataset = load_dataset('gigaword', split=split)
13 |
14 | def _filter_func(ex: Dict) -> bool:
15 | text = ''.join([ex['document'], ex['summary']])
16 | no_unk = 'UNK' not in text
17 | no_hashtag = '#' not in text
18 | return no_unk and no_hashtag
19 |
20 | dataset = dataset.filter(_filter_func)
21 | dataset = dataset.rename_column('document', 'text')
22 |
23 | return dataset
24 |
25 | @property
26 | def templates(self) -> List[Tuple[str, str]]:
27 | return [
28 | ("Write a short summary for this text: {text}", "{summary}"),
29 | ("Briefly summarize this sentence: {text}", "{summary}"),
30 | ("Generate a short summary this sentence: {text}", "{summary}"),
31 | ("What is a shorter version of this: {text}", "{summary}"),
32 | ("{text} Write a brief summary in a sentence or less", "{summary}"),
33 | ("{text} What is a very short summary of the above text?", "{summary}"),
34 | ("{text} Summarize the aforementioned text in a single phrase.", "{summary}"),
35 | ("{text} Can you generate a short summary of the above paragraph?", "{summary}"),
36 | ]
37 |
38 | @property
39 | def possible_answers(self) -> Optional[List[str]]:
40 | return None
41 |
42 | @property
43 | def metric_name(self) -> str:
44 | return 'rouge'
45 |
46 | @property
47 | def task_name(self) -> str:
48 | return 'gigaword'
49 |
50 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
51 | return example['summary']
52 |
--------------------------------------------------------------------------------
/src/tasks/hellaswag.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from typing import Optional, List, Tuple, Dict, Union
4 | from datasets import load_dataset, Dataset
5 |
6 | from tasks import task_map, to_letter
7 | from tasks.base_task import BaseTask
8 |
9 |
10 | @task_map.add("hellaswag")
11 | class Hellaswag(BaseTask):
12 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
13 | split = split if split == 'train' else 'validation'
14 | dataset = load_dataset('hellaswag', split=split)
15 |
16 | def _map_func(ex: Dict) -> Dict:
17 | ex['ctx'] = re.sub(r'\[.*?\]\s', '', ex['ctx']).strip()
18 | ex['options'] = [re.sub(r'\[.*?\]\s', '', option) for option in ex['endings']]
19 | return ex
20 |
21 | dataset = dataset.map(_map_func)
22 | dataset = dataset.rename_column('ctx', 'context')
23 |
24 | return dataset
25 |
26 | @property
27 | def templates(self) -> List[Tuple[str, str]]:
28 | return [
29 | ("What happens next in this paragraph? {context}", "{answer}"),
30 | ("Continue writing the next sentence in this paragraph: {context}", "{answer}"),
31 | ("Continue writing the next sentence. {context}", "{answer}"),
32 | ("This is a test of commonsense. Complete the next sentence: {context}", "{answer}"),
33 | ("Write the next sentence in this paragraph: {context}", "{answer}"),
34 | ("How does the next paragraph end? {context}", "{answer}"),
35 | ("What most naturally follows? {context}", "{answer}"),
36 | ("What happens next? {context}", "{answer}"),
37 | ("What is the most logical next event? {context}", "{answer}"),
38 | ("Write the next sentence in the following story. {context}", "{answer}"),
39 | ]
40 |
41 | @property
42 | def possible_answers(self) -> Optional[List[str]]:
43 | return ['A', 'B', 'C', 'D']
44 |
45 | @property
46 | def metric_name(self) -> str:
47 | return 'simple_accuracy'
48 |
49 | @property
50 | def task_name(self) -> str:
51 | return 'hellaswag'
52 |
53 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
54 | return to_letter(example['label'])
55 |
--------------------------------------------------------------------------------
/src/tasks/mnli.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("mnli")
9 | class Mnli(BaseTask):
10 |
11 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
12 | if split != 'train':
13 | return None
14 | dataset: Dataset = load_dataset('glue', 'mnli', split=split)
15 | return dataset
16 |
17 | @property
18 | def templates(self) -> List[Tuple[str, str]]:
19 | return [
20 | (
21 | "Premise: \"{premise}\" Hypothesis: \"{hypothesis}\" Does the premise entail the hypothesis? Yes, No, or Maybe?",
22 | "{answer}"),
23 | (
24 | "Premise: \"{premise}\" Hypothesis: \"{hypothesis}\" Is the hypothesis entailed by the premise? Yes, No, or Maybe?",
25 | "{answer}"),
26 | (
27 | "Here is a premise: \"{premise}\" Here is a hypothesis: \"{hypothesis}\" Is it possible to conclude that if the premise is true, then so is the hypothesis? Yes, No, or Maybe?",
28 | "{answer}"),
29 | (
30 | "Sentence 1: \"{premise}\" Sentence 2: \"{hypothesis}\" Is this second sentence entailed by the first sentence? Yes, No, or Maybe?",
31 | "{answer}"),
32 | (
33 | "Sentence 1: \"{premise}\" Sentence 2: \"{hypothesis}\" If the first sentence is true, then is the second sentence true? Yes, No, or Maybe?",
34 | "{answer}"),
35 | (
36 | "Based on the premise \"{premise}\", can we conclude the hypothesis \"{hypothesis}\" is true? Yes, No, or Maybe?",
37 | "{answer}"),
38 | (
39 | "Premise: \"{premise}\" If this premise is true, what does that tell us about whether it entails the hypothesis \"{hypothesis}\"? Yes, No, or Maybe?",
40 | "{answer}"),
41 | (
42 | "Premise: \"{premise}\" Based on this premise, is the hypothesis \"{hypothesis}\" true? Yes, No, or Maybe?",
43 | "{answer}"),
44 | ("If \"{premise}\", can we conclude that \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"),
45 | ("\"{premise}\" Does it follow that \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"),
46 | ]
47 |
48 | @property
49 | def possible_answers(self) -> Optional[List[str]]:
50 | return ['Yes', 'Maybe', 'No']
51 |
52 | @property
53 | def metric_name(self) -> str:
54 | return 'simple_accuracy'
55 |
56 | @property
57 | def task_name(self) -> str:
58 | return 'mnli'
59 |
60 |
61 | @task_map.add("mnli_m")
62 | class Mnli_m(Mnli):
63 |
64 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
65 | if split == 'train':
66 | return None
67 |
68 | return load_dataset('glue', 'mnli_matched', split='validation')
69 |
70 | @property
71 | def task_name(self) -> str:
72 | return 'mnli_m'
73 |
74 |
75 | @task_map.add("mnli_mm")
76 | class Mnli_mm(Mnli):
77 |
78 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
79 | if split == 'train':
80 | return None
81 |
82 | return load_dataset('glue', 'mnli_mismatched', split='validation')
83 |
84 | @property
85 | def task_name(self) -> str:
86 | return 'mnli_mm'
87 |
--------------------------------------------------------------------------------
/src/tasks/mrpc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("mrpc")
9 | class Mrpc(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('glue', 'mrpc', split=split)
13 | return dataset
14 |
15 | @property
16 | def templates(self) -> List[Tuple[str, str]]:
17 | return [
18 | ("Here are two sentences: {sentence1} {sentence2} Do they have the same meaning?", "{answer}"),
19 | (
20 | "Here are two sentences: {sentence1} {sentence2} Are the two sentences saying the same thing?",
21 | "{answer}"),
22 | ("{sentence1} {sentence2} Do the above sentences mean the same thing?", "{answer}"),
23 | ("{sentence1} {sentence2} Please tell me if the sentences above mean the same.", "{answer}"),
24 | ("{sentence1} {sentence2} Are these sentences conveying the same meaning?", "{answer}"),
25 | ("{sentence1} {sentence2} If the first sentence is true, is the second one also true?", "{answer}"),
26 | ("{sentence1} {sentence2} Are these two sentences paraphrases of each other?", "{answer}"),
27 | ("Do the following two sentences have the same meaning? {sentence1} {sentence2}", "{answer}"),
28 | ("Do these two sentences mean the same thing? {sentence1} {sentence2}", "{answer}"),
29 | ("Do these sentences have the same meaning? {sentence1} {sentence2}", "{answer}"),
30 | ]
31 |
32 | @property
33 | def possible_answers(self) -> Optional[List[str]]:
34 | return ['No', 'Yes']
35 |
36 | @property
37 | def metric_name(self) -> str:
38 | return 'acc_and_f1'
39 |
40 | @property
41 | def task_name(self) -> str:
42 | return 'mrpc'
43 |
--------------------------------------------------------------------------------
/src/tasks/multirc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("multirc")
9 | class Multirc(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('super_glue', 'multirc', split=split)
13 | dataset = dataset.rename_column('answer', 'response')
14 |
15 | return dataset
16 |
17 | @property
18 | def templates(self) -> List[Tuple[str, str]]:
19 | return [
20 | (
21 | "{paragraph} Question: \"{question}\" Response: \"{response}\" Does the response correctly answer the question?",
22 | "{answer}"),
23 | (
24 | "{paragraph} Question: \"{question}\" Response: \"{response}\" Based on the paragraph, is the response to the question is factually correct?",
25 | "{answer}"),
26 | ("{paragraph} Question: \"{question}\" Answer: \"{response}\" Is this answer correct?", "{answer}"),
27 | (
28 | "Paragraph: {paragraph} Question: \"{question}\" Answer: \"{response}\" Based on the paragraph, is this answer correct",
29 | "{answer}"),
30 | (
31 | "{paragraph} Based on the paragraph, does the response \"{response}\" correctly answer the question \"{question}\"?",
32 | "{answer}"),
33 | (
34 | "{paragraph} According to the above paragraph, the correct answer to the question \"{question}\" is \"{response}\"?",
35 | "{answer}"),
36 | (
37 | "{paragraph} After reading the above, is \"{response}\" the correct answer to the question \"{question}\"?",
38 | "{answer}"),
39 | ("{paragraph} Question: \"{question}\" Answer: \"{response}\" Is this answer to the question correct?",
40 | "{answer}"),
41 | ]
42 |
43 | @property
44 | def possible_answers(self) -> Optional[List[str]]:
45 | return ['No', 'Yes']
46 |
47 | @property
48 | def metric_name(self) -> str:
49 | return 'f1'
50 |
51 | @property
52 | def task_name(self) -> str:
53 | return 'multirc'
54 |
--------------------------------------------------------------------------------
/src/tasks/nq.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict, Union
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("natural_questions")
9 | class Natural_questions(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('nq_open', split=split)
13 | dataset = dataset.map(lambda ex: {'question': ex['question'] + '?'})
14 | return dataset
15 |
16 | @property
17 | def templates(self) -> List[Tuple[str, str]]:
18 | return [
19 | ("Question: {question} Answer:", "{answer}"),
20 | ("{question}", "{answer}"),
21 | ("Answer the following question: {question}", "{answer}"),
22 | ("Answer this question: {question}", "{answer}"),
23 | ("Please answer this question: {question}", "{answer}"),
24 | ("Answer the question...{question}", "{answer}"),
25 | ("What is the answer to this question? {question}", "{answer}"),
26 | ("Can you tell me the answer to {question}", "{answer}"),
27 | ("Next question: {question}", "{answer}"),
28 | ("Q: {question} A:", "{answer}"),
29 | ]
30 |
31 | @property
32 | def possible_answers(self) -> Optional[List[str]]:
33 | return None
34 |
35 | @property
36 | def metric_name(self) -> str:
37 | return 'trivia_qa'
38 |
39 | @property
40 | def task_name(self) -> str:
41 | return 'natural_questions'
42 |
43 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
44 | return example['answer']
45 |
--------------------------------------------------------------------------------
/src/tasks/openbookqa.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict, Union
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map, to_letter
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("openbookqa")
9 | class Openbookqa(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'test'
12 | dataset = load_dataset('openbookqa', 'additional', split=split)
13 |
14 | dataset = dataset.rename_column('fact1', 'fact')
15 | dataset = dataset.rename_column('question_stem', 'question')
16 |
17 | def _map_func(ex: Dict) -> Dict:
18 | ex['options'] = ex['choices']['text']
19 | return ex
20 |
21 | dataset = dataset.map(_map_func)
22 |
23 | return dataset
24 |
25 | @property
26 | def templates(self) -> List[Tuple[str, str]]:
27 | return [
28 | ("{fact} {question}", "{answer}"),
29 | ("Read this fact: \"{fact}\" Now answer this question: \"{question}\"", "{answer}"),
30 | ("Given the fact \"{fact}\", what is the answer to the question or completion \"{question}\"", "{answer}"),
31 | ("Knowing that \"{fact}\", how would one answer \"{question}\"", "{answer}"),
32 | ("Use evidence from the fact that {fact} to answer this question: \"{question}\"", "{answer}"),
33 | ("Fact: {fact} Question: {question} What's the answer?", "{answer}"),
34 | ("Use this fact to answer the question: {fact} {question}", "{answer}"),
35 | ]
36 |
37 | @property
38 | def possible_answers(self) -> Optional[List[str]]:
39 | return ['A', 'B', 'C', 'D']
40 |
41 | @property
42 | def metric_name(self) -> str:
43 | return 'simple_accuracy'
44 |
45 | @property
46 | def task_name(self) -> str:
47 | return 'openbookqa'
48 |
49 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
50 | return to_letter(example['answerKey'])
51 |
--------------------------------------------------------------------------------
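The multiple-choice tasks in this listing (openbookqa above, and piqa, winogrande and wsc273 below) convert raw dataset labels into option letters via the to_letter helper imported from the tasks package (src/tasks/__init__.py). A minimal sketch that is merely consistent with the call sites shown here — an assumption about the helper, not its actual implementation — is:

def to_letter(label) -> str:
    # Assumed behavior, inferred from call sites such as to_letter(example['label'])
    # and to_letter(example['answerKey']): pass through values that are already
    # option letters, otherwise treat the value as a 0-based option index.
    if isinstance(label, str) and label.strip().isalpha():
        return label.strip().upper()
    return chr(ord('A') + int(label))

Under this reading, to_letter(0) and to_letter('A') both yield 'A', matching the possible_answers lists ['A', 'B', ...] these tasks expose.
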
/src/tasks/paws.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("paws")
9 | class Paws(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'test'
12 | dataset = load_dataset('paws', 'labeled_final', split=split)
13 | return dataset
14 |
15 | @property
16 | def templates(self) -> List[Tuple[str, str]]:
17 | return [
18 | ("{sentence1} {sentence2} Do these sentences mean the same thing?", "{answer}"),
19 | ("{sentence1} {sentence2} Are these two sentences paraphrases of each other?", "{answer}"),
20 | ("1. {sentence1} 2. {sentence2} Are these two sentences paraphrases of each other?", "{answer}"),
21 | ("(1) {sentence1} (2) {sentence2} Do these two sentences mean the same thing?", "{answer}"),
22 | ("Sentence 1: {sentence1} Sentence 2: {sentence2} Do these two sentences convey the same information?",
23 | "{answer}"),
24 | ("Do these two sentences from wikipedia have the same meaning? {sentence1} {sentence2}", "{answer}"),
25 | ("Same meaning? {sentence1} {sentence2}", "{answer}"),
26 | ("Are these paraphrases? {sentence1} {sentence2}", "{answer}"),
27 | ("Do these mean the same? {sentence1} {sentence2}", "{answer}"),
28 | (
29 | "Please check if these have the same meaning. Answer \"yes\" if they do, otherwise \"no\". {sentence1} {sentence2}",
30 | "{answer}"),
31 | ]
32 |
33 | @property
34 | def possible_answers(self) -> Optional[List[str]]:
35 | return ['No', 'Yes']
36 |
37 | @property
38 | def metric_name(self) -> str:
39 | return 'simple_accuracy'
40 |
41 | @property
42 | def task_name(self) -> str:
43 | return 'paws'
44 |
--------------------------------------------------------------------------------
/src/tasks/piqa.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict, Union
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map, to_letter
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("piqa")
9 | class Piqa(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('piqa', split=split)
13 |
14 | def _map_func(ex: Dict) -> Dict:
15 | ex['options'] = [ex['sol1'], ex['sol2']]
16 | return ex
17 |
18 | dataset = dataset.map(_map_func)
19 |
20 | return dataset
21 |
22 | @property
23 | def templates(self) -> List[Tuple[str, str]]:
24 | return [
25 | ("Here is a goal: \"{goal}\" How would you accomplish this goal?", "{answer}"),
26 | ("Here is a goal: \"{goal}\" Which way makes more sense to accomplish this goal?", "{answer}"),
27 | ("Goal: \"{goal}\" Which of the following methods is more reasonable for accomplishing this goal?",
28 | "{answer}"),
29 |             ("Objective: \"{goal}\" Which of the following solutions is more sound in terms of naive physics reasoning?",
30 | "{answer}"),
31 | ("How do you do this: \"{goal}\"", "{answer}"),
32 | ("What is the best way to: \"{goal}\"", "{answer}"),
33 | ("Which of the following solutions is better for the following goal: \"{goal}\"", "{answer}"),
34 | ("How would someone go about accomplishing this goal? \"{goal}\"", "{answer}"),
35 | ]
36 |
37 | @property
38 | def possible_answers(self) -> Optional[List[str]]:
39 | return ['A', 'B']
40 |
41 | @property
42 | def metric_name(self) -> str:
43 | return 'simple_accuracy'
44 |
45 | @property
46 | def task_name(self) -> str:
47 | return 'piqa'
48 |
49 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
50 | return to_letter(example['label'])
51 |
--------------------------------------------------------------------------------
/src/tasks/qnli.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("qnli")
9 | class Qnli(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('glue', 'qnli', split=split)
13 | return dataset
14 |
15 | @property
16 | def templates(self) -> List[Tuple[str, str]]:
17 | return [
18 | ("Does the sentence \"{sentence}\" answer the question \"{question}\"?", "{answer}"),
19 | ("Does the sentence \"{sentence}\" provide a valid answer to the question \"{question}\"?", "{answer}"),
20 | ("Is \"{sentence}\" a good answer to the question \"{question}\"?", "{answer}"),
21 | ("Does \"{sentence}\" correctly answer the question of \"{question}\"?", "{answer}"),
22 | ("Does \"{sentence}\" contain the correct answer to \"{question}\"?", "{answer}"),
23 | ("Q: {question} A: {sentence} Does the answer correctly answer the question?", "{answer}"),
24 | (
25 | "Question: {question} Answer: {sentence} Is the question answered in a satisfactory fashion?",
26 | "{answer}"),
27 | ("Question: {question} Is {sentence} a good answer to this question?", "{answer}"),
28 | ("Question: {question} Is \"{sentence}\" the correct answer?", "{answer}"),
29 | ]
30 |
31 | @property
32 | def possible_answers(self) -> Optional[List[str]]:
33 | return ['Yes', 'No']
34 |
35 | @property
36 | def metric_name(self) -> str:
37 | return 'simple_accuracy'
38 |
39 | @property
40 | def task_name(self) -> str:
41 | return 'qnli'
42 |
--------------------------------------------------------------------------------
/src/tasks/qqp.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("qqp")
9 | class Qqp(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('glue', 'qqp', split=split)
13 |
14 | def _map_func(ex: Dict) -> Dict:
15 | ex['question1'] = ex['question1'].replace('""', '\'')
16 | ex['question2'] = ex['question2'].replace('""', '\'')
17 | return ex
18 |
19 | dataset = dataset.map(_map_func)
20 | return dataset
21 |
22 | @property
23 | def templates(self) -> List[Tuple[str, str]]:
24 | return [
25 | ("\"{question1}\" \"{question2}\" Would you say that these questions are the same?", "{answer}"),
26 | ("\"{question1}\" \"{question2}\" Do those questions have the same meaning?", "{answer}"),
27 | ("\"{question1}\" \"{question2}\" Are these two questions inquiring about the same information?", "{answer}"),
28 | ("\"{question1}\" \"{question2}\" Please tell me if those questions are the same.", "{answer}"),
29 | ("\"{question1}\" \"{question2}\" Are these two questions paraphrases of each other?", "{answer}"),
30 | ("First question: \"{question1}\" Second question: \"{question2}\" Are these two questions asking the same thing?",
31 | "{answer}"),
32 | (
33 | "Question 1: \"{question1}\" Question 2: \"{question2}\" Are questions 1 and 2 asking the same thing?",
34 | "{answer}"),
35 | ("Question 1: \"{question1}\" Question 2: \"{question2}\" Would the answer to these two questions be the same?",
36 | "{answer}"),
37 | ("Are the following two questions the same? \"{question1}\" \"{question2}\"", "{answer}"),
38 | ("Do these questions have the same meaning? \"{question1}\" \"{question2}\"", "{answer}"),
39 | ]
40 |
41 | @property
42 | def possible_answers(self) -> Optional[List[str]]:
43 | return ['No', 'Yes']
44 |
45 | @property
46 | def metric_name(self) -> str:
47 | return 'acc_and_f1'
48 |
49 | @property
50 | def task_name(self) -> str:
51 | return 'qqp'
52 |
--------------------------------------------------------------------------------
/src/tasks/rte.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("rte")
9 | class Rte(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('super_glue', 'rte', split=split)
13 | return dataset
14 |
15 | @property
16 | def templates(self) -> List[Tuple[str, str]]:
17 | return [
18 | ("{premise} Based on the paragraph above can we conclude that \"{hypothesis}\"? Yes or No?",
19 | "{answer}"),
20 | (
21 | "{premise} Based on that paragraph can we conclude that this sentence is true? {hypothesis} Yes or No?",
22 | "{answer}"),
23 | ("{premise} Can we draw the following conclusion? {hypothesis} Yes or No?", "{answer}"),
24 | ("{premise} Does this next sentence follow, given the preceding text? {hypothesis} Yes or No?",
25 | "{answer}"),
26 | ("{premise} Can we infer the following? {hypothesis} Yes or No?", "{answer}"),
27 | (
28 | "Read the following paragraph and determine if the hypothesis is true: {premise} Hypothesis: {hypothesis} Yes or No?",
29 | "{answer}"),
30 | ("Read the text and determine if the sentence is true: {premise} Sentence: {hypothesis} Yes or No?",
31 | "{answer}"),
32 | (
33 | "Can we draw the following hypothesis from the context? Context: {premise} Hypothesis: {hypothesis} Yes or No?",
34 | "{answer}"),
35 | ("Determine if the sentence is true based on the text below: {hypothesis} {premise} Yes or No?",
36 | "{answer}"),
37 | ]
38 |
39 | @property
40 | def possible_answers(self) -> Optional[List[str]]:
41 | return ['Yes', 'No']
42 |
43 | @property
44 | def metric_name(self) -> str:
45 | return 'simple_accuracy'
46 |
47 | @property
48 | def task_name(self) -> str:
49 | return 'rte'
50 |
--------------------------------------------------------------------------------
/src/tasks/sentiment140.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("sentiment140")
9 | class Sentiment140(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'test'
12 | dataset = load_dataset('sentiment140', split=split)
13 |
14 | def _map_func(ex: Dict) -> Dict:
15 | ex['label'] = 0 if int(ex['sentiment']) == 0 else 1
16 | return ex
17 |
18 | dataset = dataset.filter(lambda ex: int(ex['sentiment']) in [0, 4])
19 | dataset = dataset.map(_map_func)
20 |
21 | return dataset
22 |
23 | @property
24 | def templates(self) -> List[Tuple[str, str]]:
25 | return [
26 | ("{text} What is the sentiment of this tweet?", "{answer}"),
27 | ("{text} How would the sentiment of this tweet be described?", "{answer}"),
28 | ("{text} Describe the sentiment embodied by this tweet.", "{answer}"),
29 | ("Tweet: {text} Predict the sentiment of this tweet.", "{answer}"),
30 | ("What is the sentiment of the following tweet? Tweet:{text}", "{answer}"),
31 | ("How would one describe the sentiment of this tweet? {text}", "{answer}"),
32 | ]
33 |
34 | @property
35 | def possible_answers(self) -> Optional[List[str]]:
36 | return ['Negative', 'Positive']
37 |
38 | @property
39 | def metric_name(self) -> str:
40 | # Prior work uses two classes
41 | # (https://www.aclweb.org/anthology/C14-1008.pdf,
42 | # https://arxiv.org/pdf/1404.2188.pdf)
43 | return 'simple_accuracy'
44 |
45 | @property
46 | def task_name(self) -> str:
47 | return 'sentiment140'
48 |
--------------------------------------------------------------------------------
/src/tasks/snli.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | # define your task class
9 | @task_map.add("snli") # add your task to the task map
10 | class Snli(BaseTask):
11 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
12 | split = split if split == 'train' else 'test'
13 | dataset = load_dataset('snli', split=split)
14 | dataset = dataset.filter(lambda ex: int(ex["label"]) >= 0)
15 |
16 | return dataset
17 |
18 | @property
19 | def templates(self) -> List[Tuple[str, str]]:
20 | return [
21 | ("If \"{premise}\", does this mean that \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"),
22 | ("If \"{premise}\", can we conclude \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"),
23 | ("If \"{premise}\", does it logically follow that \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"),
24 | (
25 | "Based on the sentence \"{premise}\", is the sentence \"{hypothesis}\" a true sentence? Yes, No, or Maybe?",
26 | "{answer}"),
27 | (
28 | "Premise: {premise} Hypothesis: {hypothesis} Can we conclude that the hypothesis is true if the premise is true? Yes, No, or Maybe?",
29 | "{answer}"),
30 | (
31 | "Premise: {premise} Hypothesis: {hypothesis} Given the premise, can we conclude the hypothesis? Yes, No, or Maybe?",
32 | "{answer}"),
33 | (
34 | "Here is a premise: \"{premise}\" Here is a hypothesis: \"{hypothesis}\". Does the premise tell us whether the hypothesis is true? Yes, No, or Maybe?",
35 | "{answer}"),
36 | ("Is it possible to conclude that \"{premise}\" if \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"),
37 | ("Is the premise \"{premise}\" true if \"{hypothesis}\"? Yes, No, or Maybe?", "{answer}"),
38 | ]
39 |
40 |
41 | @property
42 | def possible_answers(self) -> Optional[List[str]]:
43 | return ['Yes', 'Maybe', 'No']
44 |
45 | @property
46 | def metric_name(self) -> str:
47 | return 'simple_accuracy'
48 |
49 | @property
50 | def task_name(self) -> str:
51 | return 'snli'
52 |
--------------------------------------------------------------------------------
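As the comments in snli.py indicate, adding a new task amounts to subclassing BaseTask and registering it in task_map. A minimal sketch following the same pattern — the task name "my_task" and the 'imdb' dataset are placeholders for illustration, not part of the repository — would be:

from typing import Optional, List, Tuple
from datasets import load_dataset, Dataset

from tasks import task_map
from tasks.base_task import BaseTask


@task_map.add("my_task")  # register the task under a unique name
class MyTask(BaseTask):
    def _load_raw_data(self, split: str) -> Optional[Dataset]:
        split = split if split == 'train' else 'test'
        return load_dataset('imdb', split=split)

    @property
    def templates(self) -> List[Tuple[str, str]]:
        # (input template, output template) pairs; placeholders are dataset columns.
        return [("{text} Is this review positive or negative?", "{answer}")]

    @property
    def possible_answers(self) -> Optional[List[str]]:
        return ['Negative', 'Positive']

    @property
    def metric_name(self) -> str:
        return 'simple_accuracy'

    @property
    def task_name(self) -> str:
        return 'my_task'
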
/src/tasks/squad_v1.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from typing import Optional, List, Tuple, Dict, Union
4 | from datasets import load_dataset, Dataset
5 |
6 | from tasks import task_map
7 | from tasks.base_task import BaseTask
8 |
9 |
10 | @task_map.add("squad_v1")
11 | class Squad_v1(BaseTask):
12 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
13 | split = split if split == 'train' else 'validation'
14 | dataset = load_dataset('squad', split=split)
15 |
16 | def _map_func(ex: Dict) -> Dict:
17 | ex['title'] = re.sub(r'_', ' ', ex['title'])
18 | return ex
19 |
20 | dataset = dataset.map(_map_func)
21 |
22 | return dataset
23 |
24 | @property
25 | def templates(self) -> List[Tuple[str, str]]:
26 | return [
27 | ("Please answer a question about the following article about {title}: {context} {question}", "{answer}"),
28 | ("Read this and answer the question {context} {question}", "{answer}"),
29 | ("{context} {question}", "{answer}"),
30 | ("Answer a question about this article: {context} {question}", "{answer}"),
31 | ("Here is a question about this article: {context} What is the answer to this question: {question}",
32 | "{answer}"),
33 | ("Article: {context} Question: {question}", "{answer}"),
34 | ("Article: {context} Now answer this question: {question}", "{answer}"),
35 | ("{title} {context} Q: {question}", "{answer}"),
36 | ]
37 |
38 | @property
39 | def possible_answers(self) -> Optional[List[str]]:
40 | return None
41 |
42 | @property
43 | def metric_name(self) -> str:
44 | return 'squad'
45 |
46 | @property
47 | def task_name(self) -> str:
48 | return 'squad_v1'
49 |
50 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
51 | return example['answers']['text']
52 |
--------------------------------------------------------------------------------
/src/tasks/sst2.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("sst2")
9 | class Sst2(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('sst2', split=split)
13 | return dataset
14 |
15 | @property
16 | def templates(self) -> List[Tuple[str, str]]:
17 | return [
18 | ("Review: \"{sentence}\" Is this movie review sentence negative or positive?", "{answer}"),
19 |             ("Short movie review: \"{sentence}\" Did the critic think positively or negatively of the movie?",
20 | "{answer}"),
21 | (
22 | "Sentence from a movie review: \"{sentence}\" Was the movie seen positively or negatively based on the preceding review?",
23 | "{answer}"),
24 | ("\"{sentence}\" How would the sentiment of this sentence be perceived?", "{answer}"),
25 | ("Is the sentiment of the following sentence positive or negative? \"{sentence}\"", "{answer}"),
26 | ("What is the sentiment of the following movie review sentence? \"{sentence}\"", "{answer}"),
27 | ("Would the following phrase be considered positive or negative? \"{sentence}\"", "{answer}"),
28 | ("Does the following review have a positive or negative opinion of the movie? \"{sentence}\"", "{answer}"),
29 | ]
30 |
31 | @property
32 | def possible_answers(self) -> Optional[List[str]]:
33 | return ['Negative', 'Positive']
34 |
35 | @property
36 | def metric_name(self) -> str:
37 | return 'simple_accuracy'
38 |
39 | @property
40 | def task_name(self) -> str:
41 | return 'sst2'
42 |
--------------------------------------------------------------------------------
/src/tasks/winogrande.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict, Union
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map, to_letter
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("winogrande")
9 | class Winogrande(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('winogrande', 'winogrande_xl', split=split)
13 |
14 | def _map_func(ex: Dict) -> Dict:
15 | cut_index = ex["sentence"].index('_')
16 | context = ex["sentence"][:cut_index]
17 | ex['context'] = context.strip()
18 |
19 | text_second = ex["sentence"][cut_index + 1:]
20 | ex['options'] = [ex["option1"] + text_second, ex["option2"] + text_second]
21 |
22 | return ex
23 |
24 | dataset = dataset.map(_map_func)
25 |
26 | return dataset
27 |
28 | @property
29 | def templates(self) -> List[Tuple[str, str]]:
30 | return [
31 | ("How does the sentence end? {context}", "{answer}"),
32 | ("Write the next sentence. {context}", "{answer}"),
33 | ("Continue the following story. {context}", "{answer}"),
34 | ("Complete the following sentence. {context}", "{answer}"),
35 | ("Continue writing the following text. {context}", "{answer}"),
36 | ("How does the sentence end? {context}", "{answer}"),
37 | ("Write the next sentence. {context}", "{answer}"),
38 | ("Continue the following story. {context}", "{answer}"),
39 | ("Complete the following sentence. {context}", "{answer}"),
40 | ("Continue writing the following text. {context}", "{answer}"),
41 | ]
42 |
43 | @property
44 | def possible_answers(self) -> Optional[List[str]]:
45 | return ['A', 'B']
46 |
47 | @property
48 | def metric_name(self) -> str:
49 | return 'simple_accuracy'
50 |
51 | @property
52 | def task_name(self) -> str:
53 | return 'winogrande'
54 |
55 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
56 | label = int(example['answer']) - 1
57 | return to_letter(label)
58 |
--------------------------------------------------------------------------------
/src/tasks/wsc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("wsc")
9 | class Wsc(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | split = split if split == 'train' else 'validation'
12 | dataset = load_dataset('super_glue', 'wsc', split=split)
13 |
14 | dataset = dataset.rename_column('text', 'context')
15 | dataset = dataset.rename_column('span1_text', 'text1')
16 | dataset = dataset.rename_column('span2_text', 'text2')
17 |
18 | return dataset
19 |
20 | @property
21 | def templates(self) -> List[Tuple[str, str]]:
22 | return [
23 | ("{context} Are \"{text1}\" and \"{text2}\" the same entity?", "{answer}"),
24 | ("{context} Do \"{text1}\" and \"{text2}\" have the same meaning?", "{answer}"),
25 | ("Given the following context {context} Are \"{text1}\" and \"{text2}\" the same?", "{answer}"),
26 | ("{context} Do \"{text2}\" and \"{text1}\" mean the same thing?", "{answer}"),
27 | ("{context} Are \"{text2}\" and \"{text1}\" the same thing in the aforementioned sentence?", "{answer}"),
28 | ("Context:{context} Is \"{text2}\" the same as \"{text1}\"?", "{answer}"),
29 | ("Consider this sentence: {context} Are \"{text2}\" and \"{text1}\" the same?", "{answer}"),
30 | ("Are \"{text1}\" and \"{text2}\" the same in this sentence? {context}", "{answer}"),
31 | ("Is \"{text1}\" the same as \"{text2}\" in this sentence? {context}", "{answer}"),
32 | ("Do \"{text1}\" and \"{text2}\" point to the same thing in the following sentence? {context}", "{answer}"),
33 | ]
34 |
35 | @property
36 | def possible_answers(self) -> Optional[List[str]]:
37 | return ['No', 'Yes']
38 |
39 | @property
40 | def metric_name(self) -> str:
41 | return 'simple_accuracy'
42 |
43 | @property
44 | def task_name(self) -> str:
45 | return 'wsc'
46 |
--------------------------------------------------------------------------------
/src/tasks/wsc273.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple, Dict, Union
2 | from datasets import load_dataset, Dataset
3 |
4 | from tasks import task_map, to_letter
5 | from tasks.base_task import BaseTask
6 |
7 |
8 | @task_map.add("wsc273")
9 | class Wsc273(BaseTask):
10 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
11 | if split == 'train':
12 | return None
13 |
14 | dataset = load_dataset('winograd_wsc', 'wsc273', split='test')
15 |
16 | def _map_func(ex: Dict) -> Dict:
17 | text_first = ex["text"][:ex["pronoun_loc"]]
18 | ex['context'] = text_first
19 |
20 | text_second = ex["text"][ex["pronoun_loc"] + len(ex["pronoun"]):]
21 | ex['options'] = [ex["options"][0] + text_second, ex["options"][1] + text_second]
22 |
23 | return ex
24 |
25 | dataset = dataset.map(_map_func)
26 |
27 | return dataset
28 |
29 | @property
30 | def templates(self) -> List[Tuple[str, str]]:
31 | return [
32 | ("{context}", "{answer}"),
33 | ("Complete the passage. {context}", "{answer}"),
34 | ("How does this following sentence end? {context}", "{answer}"),
35 | ("What is the most logical completion for the following text? {context}", "{answer}"),
36 | ("How does this text end? {context}", "{answer}"),
37 | ("What happens next? {context}", "{answer}"),
38 | ("Complete the following sentence. {context}", "{answer}"),
39 | ("Fill in the remainder of the sentence. {context}", "{answer}"),
40 | ("What is the next event? {context}", "{answer}"),
41 | ("Complete the rest of the sentence. {context}", "{answer}"),
42 | ]
43 |
44 | @property
45 | def possible_answers(self) -> Optional[List[str]]:
46 | return ['A', 'B']
47 |
48 | @property
49 | def metric_name(self) -> str:
50 | return 'simple_accuracy'
51 |
52 | @property
53 | def task_name(self) -> str:
54 | return 'wsc273'
55 |
56 | def get_answer(self, example: Dict) -> Union[str, List[str]]:
57 | return to_letter(example['label'])
58 |
--------------------------------------------------------------------------------
/src/tasks/yelp.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from typing import Optional, List, Tuple, Dict
4 | from datasets import load_dataset, Dataset
5 |
6 | from tasks import task_map
7 | from tasks.base_task import BaseTask
8 |
9 |
10 | @task_map.add("yelp")
11 | class Yelp(BaseTask):
12 | def _load_raw_data(self, split: str) -> Optional[Dataset]:
13 | split = split if split == 'train' else 'test'
14 | dataset = load_dataset('yelp_polarity', split=split)
15 |
16 | def _map_func(ex: Dict) -> Dict:
17 | ex['text'] = re.sub(r'\\\"', '', ex['text'])
18 | ex['text'] = re.sub(r'\\n\\n', ' ', ex['text'])
19 | return ex
20 |
21 | dataset = dataset.map(_map_func)
22 | dataset = dataset.filter(lambda ex: 0 < len(ex['text'].split()) <= 256)
23 |
24 | return dataset
25 |
26 | @property
27 | def templates(self) -> List[Tuple[str, str]]:
28 | return [
29 | ("{text} Is this review positive or negative?", "{answer}"),
30 | ("{text} What is the sentiment of this review?", "{answer}"),
31 | ("{text} Was this review given positively or negatively?", "{answer}"),
32 | ("{text} How would this review be described in terms of sentiment?", "{answer}"),
33 | ("Is the following review positive or negative? {text}", "{answer}"),
34 | ("What is the sentiment of the following review? {text}", "{answer}"),
35 | ("How might one describe the sentiment of this review? {text}", "{answer}"),
36 | ]
37 |
38 | @property
39 | def possible_answers(self) -> Optional[List[str]]:
40 | return ['Negative', 'Positive']
41 |
42 | @property
43 | def metric_name(self) -> str:
44 | return 'simple_accuracy'
45 |
46 | @property
47 | def task_name(self) -> str:
48 | return 'yelp'
49 |
--------------------------------------------------------------------------------
/src/train_biencoder.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from transformers.utils.logging import enable_explicit_format, set_verbosity_info, set_verbosity_warning
4 | from transformers.trainer_callback import PrinterCallback
5 | from transformers import (
6 | AutoTokenizer,
7 | HfArgumentParser,
8 | Trainer,
9 | set_seed,
10 | PreTrainedTokenizerFast
11 | )
12 |
13 | from logger_config import logger, LoggerCallback
14 | from config import Arguments
15 | from trainers import BiencoderTrainer
16 | from loaders import RetrievalDataLoader
17 | from collators import BiencoderCollator
18 | from models import BiencoderModel
19 |
20 |
21 | def _common_setup(args: Arguments):
22 | set_verbosity_info()
23 | if args.process_index > 0:
24 | logger.setLevel(logging.WARNING)
25 | set_verbosity_warning()
26 | enable_explicit_format()
27 | set_seed(args.seed)
28 |
29 |
30 | def main():
31 | parser = HfArgumentParser((Arguments,))
32 | args: Arguments = parser.parse_args_into_dataclasses()[0]
33 | _common_setup(args)
34 | logger.info('Args={}'.format(str(args)))
35 |
36 | tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
37 | model: BiencoderModel = BiencoderModel.build(args=args)
38 | logger.info(model)
39 | logger.info('Vocab size: {}'.format(len(tokenizer)))
40 |
41 | data_collator = BiencoderCollator(
42 | args=args,
43 | tokenizer=tokenizer,
44 | pad_to_multiple_of=8 if args.fp16 else None)
45 |
46 | retrieval_data_loader = RetrievalDataLoader(args=args, tokenizer=tokenizer)
47 | train_dataset = retrieval_data_loader.train_dataset
48 |
49 | trainer: Trainer = BiencoderTrainer(
50 | model=model,
51 | args=args,
52 | train_dataset=train_dataset if args.do_train else None,
53 | data_collator=data_collator,
54 | tokenizer=tokenizer,
55 | )
56 | trainer.remove_callback(PrinterCallback)
57 | trainer.add_callback(LoggerCallback)
58 | retrieval_data_loader.set_trainer(trainer)
59 | model.trainer = trainer
60 |
61 | if args.do_train:
62 | train_result = trainer.train()
63 | trainer.save_model()
64 |
65 | metrics = train_result.metrics
66 | metrics["train_samples"] = len(train_dataset)
67 |
68 | trainer.log_metrics("train", metrics)
69 | trainer.save_metrics("train", metrics)
70 |
71 | return
72 |
73 |
74 | if __name__ == "__main__":
75 | main()
76 |
--------------------------------------------------------------------------------
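Because train_biencoder.py parses a single Arguments dataclass (src/config.py) with HfArgumentParser, every dataclass field becomes a command-line flag; the parsed object is also passed straight to the Trainer, so it presumably extends transformers.TrainingArguments. A toy sketch of the same parsing pattern, with placeholder fields rather than the repository's real ones:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ToyArguments:
    # Placeholder fields; the real flags are defined in src/config.py.
    model_name_or_path: str = field(default='bert-base-uncased')
    do_train: bool = field(default=False)


if __name__ == '__main__':
    args: ToyArguments = HfArgumentParser((ToyArguments,)).parse_args_into_dataclasses()[0]
    # Invocation sketch: python toy.py --model_name_or_path gpt2 --do_train
    print(args)
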
/src/train_cross_encoder.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from transformers.utils.logging import enable_explicit_format, set_verbosity_info, set_verbosity_warning
4 | from transformers.trainer_callback import PrinterCallback
5 | from transformers import (
6 | AutoTokenizer,
7 | HfArgumentParser,
8 | Trainer,
9 | set_seed,
10 | PreTrainedTokenizerFast
11 | )
12 |
13 | from logger_config import logger, LoggerCallback
14 | from config import Arguments
15 | from trainers.reward_trainer import RewardTrainer
16 | from loaders import CrossEncoderDataLoader
17 | from collators import CrossEncoderCollator
18 | from models import Reranker
19 |
20 |
21 | def _common_setup(args: Arguments):
22 | set_verbosity_info()
23 | if args.process_index > 0:
24 | logger.setLevel(logging.WARNING)
25 | set_verbosity_warning()
26 | enable_explicit_format()
27 | set_seed(args.seed)
28 |
29 |
30 | def main():
31 | parser = HfArgumentParser((Arguments,))
32 | args: Arguments = parser.parse_args_into_dataclasses()[0]
33 | _common_setup(args)
34 | logger.info('Args={}'.format(str(args)))
35 |
36 | tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
37 |
38 | model: Reranker = Reranker.from_pretrained(
39 | all_args=args,
40 | pretrained_model_name_or_path=args.model_name_or_path,
41 | num_labels=1)
42 |
43 | logger.info(model)
44 | logger.info('Vocab size: {}'.format(len(tokenizer)))
45 |
46 | data_collator = CrossEncoderCollator(
47 | tokenizer=tokenizer,
48 | pad_to_multiple_of=8 if args.fp16 else None)
49 |
50 | reward_data_loader = CrossEncoderDataLoader(args=args, tokenizer=tokenizer)
51 | train_dataset = reward_data_loader.train_dataset
52 |
53 | trainer: Trainer = RewardTrainer(
54 | model=model,
55 | args=args,
56 | train_dataset=train_dataset if args.do_train else None,
57 | data_collator=data_collator,
58 | tokenizer=tokenizer,
59 | )
60 | trainer.remove_callback(PrinterCallback)
61 | trainer.add_callback(LoggerCallback)
62 | reward_data_loader.trainer = trainer
63 |
64 | if args.do_train:
65 | train_result = trainer.train()
66 | trainer.save_model()
67 |
68 | metrics = train_result.metrics
69 | metrics["train_samples"] = len(train_dataset)
70 |
71 | trainer.log_metrics("train", metrics)
72 | trainer.save_metrics("train", metrics)
73 |
74 | return
75 |
76 |
77 | if __name__ == "__main__":
78 | main()
79 |
--------------------------------------------------------------------------------
/src/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .biencoder_trainer import BiencoderTrainer
2 | from .reward_trainer import RewardTrainer
3 |
--------------------------------------------------------------------------------
/src/trainers/biencoder_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from typing import Optional
4 | from transformers.trainer import Trainer
5 |
6 | from logger_config import logger
7 | from evaluation.metrics import accuracy, batch_mrr
8 | from models import BiencoderOutput, BiencoderModel
9 | from utils import AverageMeter
10 |
11 |
12 | class BiencoderTrainer(Trainer):
13 | def __init__(self, *pargs, **kwargs):
14 | super(BiencoderTrainer, self).__init__(*pargs, **kwargs)
15 | self.model: BiencoderModel
16 |
17 | self.acc1_meter = AverageMeter('Acc@1', round_digits=2)
18 | self.acc3_meter = AverageMeter('Acc@3', round_digits=2)
19 | self.mrr_meter = AverageMeter('mrr', round_digits=2)
20 | self.last_epoch = 0
21 |
22 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
23 | output_dir = output_dir if output_dir is not None else self.args.output_dir
24 | os.makedirs(output_dir, exist_ok=True)
25 | logger.info("Saving model checkpoint to {}".format(output_dir))
26 | self.model.save(output_dir)
27 | if self.tokenizer is not None:
28 | self.tokenizer.save_pretrained(output_dir)
29 |
30 | def compute_loss(self, model, inputs, return_outputs=False):
31 | outputs: BiencoderOutput = model(inputs)
32 | loss = outputs.loss
33 |
34 | if self.model.training:
35 | step_acc1, step_acc3 = accuracy(output=outputs.scores.detach(), target=outputs.labels, topk=(1, 3))
36 | step_mrr = batch_mrr(output=outputs.scores.detach(), target=outputs.labels)
37 |
38 | self.acc1_meter.update(step_acc1)
39 | self.acc3_meter.update(step_acc3)
40 | self.mrr_meter.update(step_mrr)
41 |
42 | if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
43 | logger.info('step: {}, {}, {}, {}'.format(self.state.global_step, self.mrr_meter, self.acc1_meter, self.acc3_meter))
44 |
45 | self._reset_meters_if_needed()
46 |
47 | return (loss, outputs) if return_outputs else loss
48 |
49 | def _reset_meters_if_needed(self):
50 | if int(self.state.epoch) != self.last_epoch:
51 | self.last_epoch = int(self.state.epoch)
52 | self.acc1_meter.reset()
53 | self.acc3_meter.reset()
54 | self.mrr_meter.reset()
55 |
--------------------------------------------------------------------------------
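BiencoderTrainer above (and RewardTrainer below) track running Acc@k / MRR / accuracy with AverageMeter objects from src/utils.py. A sketch that matches only the interface the trainers use (name, round_digits, update, reset and string formatting) — an assumption, not the repository's implementation — is:

class AverageMeter:
    # Hypothetical running-average tracker; the actual class lives in src/utils.py.
    def __init__(self, name: str, round_digits: int = 3):
        self.name = name
        self.round_digits = round_digits
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1):
        self.sum += float(value) * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.sum / max(self.count, 1)

    def __str__(self) -> str:
        # Produces strings like 'Acc@1=93.25', as seen in the trainer log lines.
        return '{}={}'.format(self.name, round(self.avg, self.round_digits))
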
/src/trainers/reward_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from typing import Optional
4 | from transformers.trainer import Trainer, IntervalStrategy
5 | from transformers.modeling_outputs import SequenceClassifierOutput
6 |
7 | from logger_config import logger
8 | from evaluation.metrics import accuracy
9 | from utils import AverageMeter
10 |
11 |
12 | class RewardTrainer(Trainer):
13 |
14 | def __init__(self, *pargs, **kwargs):
15 | super(RewardTrainer, self).__init__(*pargs, **kwargs)
16 |
17 | self.acc_meter = AverageMeter('acc', round_digits=2)
18 | self.last_epoch = 0
19 |
20 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
21 | output_dir = output_dir if output_dir is not None else self.args.output_dir
22 | os.makedirs(output_dir, exist_ok=True)
23 | logger.info("Saving model checkpoint to {}".format(output_dir))
24 |
25 | self.model.save_pretrained(output_dir)
26 |
27 | if self.tokenizer is not None and self.is_world_process_zero():
28 | self.tokenizer.save_pretrained(output_dir)
29 |
30 | def compute_loss(self, model, inputs, return_outputs=False):
31 | outputs: SequenceClassifierOutput = model(inputs)
32 | loss = outputs.loss
33 |
34 | if self.model.training:
35 | labels = inputs['labels']
36 | step_acc = accuracy(output=outputs.logits.detach(), target=labels)[0]
37 | self.acc_meter.update(step_acc)
38 | if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
39 | logger.info('step: {}, {}'.format(self.state.global_step, self.acc_meter))
40 |
41 | self._reset_meters_if_needed()
42 |
43 | return (loss, outputs) if return_outputs else loss
44 |
45 | def _reset_meters_if_needed(self):
46 | if int(self.state.epoch) != self.last_epoch:
47 | self.last_epoch = int(self.state.epoch)
48 | self.acc_meter.reset()
49 | if self.args.save_strategy == IntervalStrategy.STEPS and self.state.global_step % self.args.save_steps == 0:
50 | self.acc_meter.reset()
51 |
--------------------------------------------------------------------------------