├── pairs
│   ├── __init__.py
│   ├── eval_dataset.py
│   ├── pairs_eval.py
│   ├── llama2.py
│   ├── full_preference_matrix.py
│   ├── mistral.py
│   ├── prompts.py
│   ├── openai_api.py
│   ├── sorting.py
│   └── utils.py
├── figs
│   └── zepo.png
├── run_zepo.sh
├── install.sh
├── LICENSE
├── README.md
├── init_prompts.json
├── models
│   ├── llama2.py
│   ├── openai_api.py
│   ├── mistral.py
│   └── llama3.py
├── prompts.py
├── zepo.py
├── pairwise_comparison.py
└── utils.py

--------------------------------------------------------------------------------
/pairs/__init__.py:
--------------------------------------------------------------------------------
1 | from .pairs_eval import get_corr_df
2 | 
--------------------------------------------------------------------------------
/figs/zepo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cambridgeltl/zepo/HEAD/figs/zepo.png
--------------------------------------------------------------------------------
/run_zepo.sh:
--------------------------------------------------------------------------------
1 | python3 zepo.py \
2 | --dataset='SummEval' \
3 | --engine='meta-llama/Meta-Llama-3-8B-Instruct' \
4 | --aspect_name='coherence' \
5 | --eval_data_num=10 \
6 | --sample_num=5 \
7 | --epoch_num=5 \
8 | --batch_size=4
9 | 
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | conda create -n zepo python=3.10 -y
2 | conda activate zepo
3 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
4 | pip install transformers
5 | pip install wandb
6 | pip install openai
7 | pip install jupyter
8 | pip install accelerate
9 | conda install -c conda-forge cudatoolkit-dev
10 | pip install flash-attn --no-build-isolation
11 | pip install bitsandbytes
12 | 
13 | pip install torchtext
14 | pip install pandas
15 | pip install datasets
16 | pip install sentencepiece
17 | pip install mistralai
18 | pip install scikit-learn
19 | pip install scipy
20 | 
21 | pip install sacrebleu
22 | pip install rouge
23 | pip install termcolor
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Cambridge Language Technology Lab
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code for Fairer Preferences Elicit Improved Human-Aligned Large Language Model Judgments
2 | 
3 | ![zepo](figs/zepo.png)
4 | **Link to paper**:
5 | [Fairer Preferences Elicit Improved Human-Aligned Large Language Model Judgments](https://arxiv.org/abs/2406.11370). (Accepted to EMNLP 2024 Main)
6 | 
7 | ## Installation
8 | 
9 | Run `source install.sh` to quickly create a virtual environment named `zepo` and install all dependencies.
10 | 
11 | ```
12 | source install.sh
13 | ```
14 | 
15 | Set up the OpenAI API key for calling GPT-3.5-turbo.
16 | 
17 | ```
18 | export OPENAI_API_KEY=[YOUR_KEY]
19 | ```
20 | 
21 | ## Quick Start: ZEPO
22 | 
23 | We provide a default script to reproduce our main experiments on optimizing the prompts for COH-SummEval.
24 | 
25 | ```
26 | sh run_zepo.sh
27 | ```
28 | 
29 | ## Prompt Optimization
30 | 
31 | 1. **Design your own prompts:** We provide the default instructions to be optimized in ```init_prompts.json``` and the prompt templates for the corresponding tasks in ```prompts.py```. It is interesting to explore different initialization prompts and inspect their fairness.
32 | 
33 | 2. **Use your own datasets:** Feel free to check the data formats that we have implemented in ```data/```; it is easy to adapt other datasets to the same format and implement your own dataloader for conducting pairwise comparisons.
34 | 
35 | 3. **Support more advanced LLM prompt optimizers:** We have only studied a basic LLM optimizer that paraphrases the prompt and performs greedy search in ```zepo.py```. More advanced LLM optimizers could enable more exploitation-driven prompt search over more diverse instruction formats.
36 | 
37 | 4. **ZEPO mitigates biases:** Ideally, given a sufficient number of random samples in pairwise comparisons, ZEPO can always optimize toward fairness, such that all representative biases, including position bias, verbosity bias, and self-preference bias, can be mitigated in one go. We welcome extension studies on the impact of these biases and the extent to which they are mitigated by ZEPO.
38 | 
39 | We acknowledge that some scripts were adapted from PairS.
40 | 
41 | ## Citation
42 | 
43 | If you find our work to be useful, please cite:
44 | 
45 | ```
46 | @article{zhou2024fairer,
47 |   title={Fairer Preferences Elicit Improved Human-Aligned Large Language Model Judgments},
48 |   author={Zhou, Han and Wan, Xingchen and Liu, Yinhong and Collier, Nigel and Vulić, Ivan and Korhonen, Anna},
49 |   journal={arXiv preprint arXiv:2406.11370},
50 |   year={2024}
51 | }
52 | ```
53 | 
--------------------------------------------------------------------------------
/init_prompts.json:
--------------------------------------------------------------------------------
1 | {
2 | "coherence": "Evaluate and compare the coherence of the two summary candidates for the given source text. Consider coherence aspects such as clarity and logical flow. A summary is coherent if it accurately captures the key information from the article, and presents them in a clear manner. Which summary candidate has better coherence? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only.",
3 | "fluency": "Evaluate and compare the fluency of the two summary candidates for the given source text. Which summary candidate has better fluency?
If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only.", 4 | "consistency": "Evaluate and compare the consistency of the two summary candidates for the given source text. A summary is consistent with the article if it faithfully reflects the main points, facts, and tone of the article. A summary is inconsistent if it introduces any errors, contradictions, or distortions of the original article. Which summary candidate has better consistency? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only.", 5 | "relevance": "Evaluate and compare the relevance of the two summary candidates for the given source text. A summary is relevant if it captures the main points from the article, without leaving out any crucial details or adding any unnecessary or inaccurate ones. A summary is more relevant if it uses the same or similar terms and expressions as the article. A summary is less relevant if it omits some of the key facts from the article, or if it introduces irrelevant information that is not supported by the article. Which summary candidate has better relevance? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only.", 6 | "informativeness": "Evaluate and compare the informativeness of the two summary candidates for the given source text. Evaluate how each summary converts their input text to natural language text, without omitting, adding, or distorting any facts. Which summary candidate has better informativeness? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only.", 7 | "engaging": "Evaluate and compare the engagement of the two response candidates for the given dialog history. Which response candidate is more engaging? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only.", 8 | "natural": "Evaluate and compare the naturalness of the two response candidates for the given dialog history. Which response candidate is more natural? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only.", 9 | "overall": "Evaluate and compare the overall response of the two candidates for the given dialog history. Which response candidate has better quality? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only." 
10 | } -------------------------------------------------------------------------------- /pairs/eval_dataset.py: -------------------------------------------------------------------------------- 1 | from utils import ( 2 | shuffle_lists, 3 | calculate_correlation, 4 | load_newsroom, 5 | load_summEval, 6 | calculate_uncertainty, 7 | load_sf_data, 8 | CompareResultObject, 9 | insert_index_to_anchors, 10 | ) 11 | import random 12 | from sorting import merge_sort_indices 13 | import numpy as np 14 | from tqdm import tqdm 15 | import json 16 | 17 | 18 | if __name__ == "__main__": 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--dataset", type=str, default="SumEval") 23 | parser.add_argument("--save_path", type=str, default="./results.jsonl") 24 | parser.add_argument("--aspect", type=str, default="coherence") 25 | parser.add_argument("--eval_method", type=str, default="pairwise comparison") 26 | parser.add_argument("--scaling_anchor_size", type=int, default=0) 27 | parser.add_argument("--eval_size", type=int, default=300) 28 | parser.add_argument( 29 | "--engine", type=str, default="mistralai/Mistral-7B-Instruct-v0.1" 30 | ) 31 | parser.add_argument("--confidence_beam", action="store_true") 32 | parser.add_argument("--prob_gap", type=float, default=0.15) 33 | parser.add_argument("--beam_size", type=int, default=100) 34 | parser.add_argument("--with_input", action="store_true") 35 | parser.add_argument("--calibration", action="store_true") 36 | args = parser.parse_args() 37 | 38 | print("aspect:", args.aspect) 39 | print("engine:", args.engine) 40 | print("dataset:", args.dataset) 41 | print("confidence_beam:", args.confidence_beam) 42 | print("beam_size:", args.beam_size) 43 | print("calibration:", args.calibration) 44 | 45 | params = { 46 | "dataset": args.dataset, 47 | "engine": args.engine, 48 | "aspect": args.aspect, 49 | "eval_method": args.eval_method, 50 | "confidence_beam": args.confidence_beam, 51 | "beam_size": args.beam_size, 52 | "api_call": 0, 53 | "prob_gap": args.prob_gap, 54 | "with_input": args.with_input, 55 | "compare_log": {}, 56 | "calibration": args.calibration, 57 | } 58 | # Load the dataset 59 | if args.dataset == "SumEval": 60 | summ_eval_path = "data/SummEval/model_annotations.aligned.paired.jsonl" 61 | input_doc, output_doc, scores_doc = load_summEval( 62 | summ_eval_path, flat_output=False 63 | ) 64 | elif args.dataset == "newsroom": 65 | newsroom_path = "data/newsroom/newsroom.json" 66 | input_doc, output_doc, scores_doc = load_newsroom( 67 | newsroom_path, flat_output=False 68 | ) 69 | else: 70 | print("Dataset not supported.") 71 | assert False 72 | 73 | scores_doc = scores_doc[args.aspect] 74 | ranking_indices_list = [] 75 | scores_list = [] 76 | progress_bar = tqdm(total=len(input_doc), desc="Processing") 77 | base_idx_cnt = 0 78 | spearman_corr_list, kendall_tau_list = [], [] 79 | for input, output, scores in zip(input_doc, output_doc, scores_doc): 80 | input, output, scores = shuffle_lists(input, output, scores) 81 | ranking_indices = merge_sort_indices(input, output, params) 82 | ranking_indices_list.append([idx + base_idx_cnt for idx in ranking_indices]) 83 | scores_list.append(scores) 84 | base_idx_cnt += len(input) 85 | progress_bar.update(1) 86 | print(np.array(scores)[ranking_indices]) 87 | spearman_corr, kendall_tau, mae = calculate_correlation( 88 | np.array(scores)[ranking_indices], list(range(len(scores))) 89 | ) 90 | spearman_corr_list.append(spearman_corr) 91 | kendall_tau_list.append(kendall_tau) 92 | 
print("api_call:", params["api_call"]) 93 | params["api_call"] = 0 94 | 95 | ranking_indices_flatten = np.array(ranking_indices_list).T.flatten().tolist() 96 | scores_flatten = np.array(scores_list).flatten() 97 | # Save the results if needed 98 | results = { 99 | "aspect": args.aspect, 100 | "confidence_beam": args.confidence_beam, 101 | "beam_size": params["beam_size"], 102 | "engine": args.engine, 103 | "dataset": args.dataset, 104 | "human_scores": scores_flatten.tolist(), 105 | "gpt_ranking": ranking_indices_flatten, 106 | "compare_log": {str(key): val for key, val in params["compare_log"].items()}, 107 | "spearmans:": np.average(spearman_corr_list).tolist(), 108 | "kendall_tau": np.average(kendall_tau_list).tolist(), 109 | } 110 | 111 | progress_bar.close() 112 | print("---------------------------------") 113 | print("spearmans:", np.mean(spearman_corr_list)) 114 | print("kendall_tau:", np.mean(kendall_tau_list)) 115 | print("aspect:", args.aspect) 116 | print("engine:", args.engine) 117 | print("dataset:", args.dataset) 118 | print("confidence_beam:", args.confidence_beam) 119 | print("beam_size:", params["beam_size"]) 120 | print("calibration:", args.calibration) 121 | -------------------------------------------------------------------------------- /models/llama2.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import sys 6 | 7 | sys.path.append("../") 8 | from utils import CompareResultObject, calculate_uncertainty 9 | 10 | 11 | device = "cuda" 12 | 13 | 14 | def is_integer_string(s): 15 | return s.isdigit() 16 | 17 | 18 | class Llama2ModelLocal: 19 | def __init__(self, params): 20 | self.model_name = params["model"] 21 | self.temperature = params["temperature"] if "temperature" in params else 0 22 | self.max_tokens = params["max_tokens"] if "max_tokens" in params else 64 23 | self.do_sample = params["do_sample"] if "do_sample" in params else False 24 | self.device = device 25 | if "cache_dir" not in params: 26 | params["cache_dir"] = None 27 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) 28 | load_in_8bit = True if "13" in self.model_name else False 29 | self.model = AutoModelForCausalLM.from_pretrained( 30 | self.model_name, 31 | torch_dtype=torch.bfloat16, 32 | device_map=self.device, 33 | load_in_8bit=load_in_8bit, 34 | attn_implementation="flash_attention_2", 35 | ) 36 | self.model.eval() 37 | self.A_ids = self.tokenizer.convert_tokens_to_ids(["A", "▁A"]) # A: 330 38 | self.B_ids = self.tokenizer.convert_tokens_to_ids(["B", "▁B"]) # B: 365 39 | self.C_ids = self.tokenizer.convert_tokens_to_ids(["C", "▁C"]) # C: 40 | self.score_ids = self.tokenizer.convert_tokens_to_ids(["1", "2", "3", "4", "5"]) 41 | 42 | def rate_score(self, prompt): 43 | sequence, output = self.local_model_chat_completion(prompt) 44 | score, logprobs = self.extract_score(sequence, output.logits) 45 | print(score) 46 | return score, logprobs 47 | 48 | def compare(self, prompt) -> CompareResultObject: 49 | # model run locally 50 | # print(prompt) 51 | sequence, output = self.local_model_chat_completion(prompt) 52 | # print(self.tokenizer.batch_decode(sequence)) 53 | compare_result = self.extract_probs(sequence, output.logits) 54 | return compare_result 55 | 56 | def generate(self, prompt, chat_system_instruction=None, num_samples=1): 57 | sequence, output = self.local_model_chat_completion( 58 | 
prompt, chat_system_instruction, num_samples 59 | ) 60 | # Skip special tokens and role tokens 61 | generated_text = self.tokenizer.batch_decode( 62 | sequence[:, 1:], skip_special_tokens=True 63 | ) 64 | return generated_text 65 | 66 | def extract_score(self, sequence, logits): 67 | """ 68 | sequence: [batch_size, seq_len] 69 | logits: seq_len x [batch_size, vocab_size] 70 | output: int score 71 | """ 72 | for idx, token_id in enumerate(sequence[0]): 73 | logit = logits[idx][0] 74 | logprobs = F.log_softmax(logit, dim=-1).cpu() 75 | score_logprobs = logprobs[self.score_ids].tolist() 76 | token = self.tokenizer.decode(token_id) 77 | if is_integer_string(token): 78 | return int(token), score_logprobs 79 | print("Failed to extract score") 80 | print(self.tokenizer.batch_decode(sequence)) 81 | return 3, [np.log(0.2)] * 5 82 | 83 | def extract_probs(self, sequence, logits) -> CompareResultObject: 84 | """ 85 | sequence: [batch_size, seq_len] 86 | logits: seq_len x [batch_size, vocab_size] 87 | output: compare_result_object 88 | """ 89 | # First token logit 90 | for idx, token_id in enumerate(sequence[0]): 91 | if token_id in self.A_ids or token_id in self.B_ids: 92 | logit = logits[idx] 93 | probs = F.softmax(logit, dim=-1)[0] 94 | prob_A = sum([probs[a_id].item() for a_id in self.A_ids]) 95 | prob_B = sum([probs[b_id].item() for b_id in self.B_ids]) 96 | prob_C = sum([probs[c_id].item() for c_id in self.C_ids]) 97 | uncertainty = calculate_uncertainty([prob_A, prob_B]) 98 | compare_result = CompareResultObject( 99 | raw_prob_A=prob_A, 100 | raw_prob_B=prob_B, 101 | raw_prob_C=prob_C, 102 | uncertainty=uncertainty, 103 | ) 104 | return compare_result 105 | print("Failed to extract probs") 106 | print(self.tokenizer.batch_decode(sequence)) 107 | return CompareResultObject(raw_prob_A=0.5, raw_prob_B=0.5, uncertainty=1) 108 | 109 | def local_model_chat_completion( 110 | self, prompt, chat_system_instruction=None, num_samples=1 111 | ): 112 | 113 | msg = Llama2ModelLocal.get_chat_message(prompt, chat_system_instruction) 114 | input = self.tokenizer.apply_chat_template( 115 | msg, return_tensors="pt", return_dict=True 116 | ) 117 | 118 | input = {k: v.expand(num_samples, -1) for k, v in input.items()} 119 | input_ids = input["input_ids"].to(device) 120 | 121 | output = self.model.generate( 122 | inputs=input_ids, 123 | return_dict_in_generate=True, 124 | output_logits=True, 125 | max_new_tokens=self.max_tokens, 126 | do_sample=self.do_sample, 127 | temperature=self.temperature, 128 | ) 129 | 130 | newly_generated_tokens = output.sequences[:, input_ids.shape[-1] :] 131 | return newly_generated_tokens, output 132 | 133 | @staticmethod 134 | def get_chat_message(prompt, chat_system_instruction=None): 135 | if chat_system_instruction: 136 | message = [ 137 | {"role": "system", "content": chat_system_instruction}, 138 | {"role": "user", "content": prompt}, 139 | ] 140 | else: 141 | message = [{"role": "user", "content": prompt}] 142 | return message 143 | 144 | 145 | if __name__ == "__main__": 146 | model = Llama2ModelLocal({"model": "meta-llama/Llama-2-7b-chat-hf"}) 147 | -------------------------------------------------------------------------------- /pairs/pairs_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import sys 4 | from .utils import load_summEval, load_newsroom, load_TopicalChat 5 | from .sorting import merge_sort_indices 6 | from .utils import calculate_correlation 7 | from tqdm import tqdm 8 | import os 9 
| import pandas as pd 10 | 11 | 12 | def load_matrix(matrix_path): 13 | with open(matrix_path, "r") as f: 14 | matrix = json.load(f) 15 | return matrix 16 | 17 | 18 | def load_rsults(results_path): 19 | with open(results_path, "r") as f: 20 | results = json.load(f) 21 | return results 22 | 23 | 24 | def get_prob_gap(results): 25 | prob_gap_list = [] 26 | norm_prob_gap_list = [] 27 | cal_list = [[], []] 28 | for data_id, id_result in results.items(): 29 | for key, values in id_result.items(): 30 | if key == "prob_list": 31 | for v in values: 32 | prob_gap_list.append(np.abs(v[0] - v[1])) 33 | if key == "norm_prob_list": 34 | for v in values: 35 | norm_prob_gap_list.append(np.abs(v[0] - v[1])) 36 | if key == "logit_list": 37 | for v in values: 38 | cal_list[0].append(v[0]) 39 | cal_list[1].append(v[1]) 40 | base_prior = np.mean([np.mean(cal) for cal in cal_list]) 41 | bc_prior = np.mean([np.abs(base_prior - np.mean(cal)) for cal in cal_list]) 42 | return np.mean(prob_gap_list), np.mean(norm_prob_gap_list), bc_prior 43 | 44 | 45 | def convert_to_small_matrix(matrix): 46 | rows, cols = matrix.shape 47 | mask = np.ones(matrix.shape, dtype=bool) 48 | np.fill_diagonal(mask, False) 49 | return matrix[mask].reshape(rows, cols - 1) 50 | 51 | 52 | def get_a_rate(args, preference_matrixs): 53 | if args.do_permutate: 54 | A_rate_list = [] 55 | for matrix in preference_matrixs: 56 | matrix = np.triu(matrix) 57 | small_matrix = convert_to_small_matrix(np.array(matrix)) 58 | A_count = np.sum(np.logical_and(small_matrix > 0.5, small_matrix != 0)) 59 | B_count = np.sum(np.logical_and(small_matrix < 0.5, small_matrix != 0)) 60 | A_rate = A_count / (A_count + B_count) 61 | A_rate_list.append(A_rate) 62 | else: 63 | A_rate_list = [] 64 | for matrix in preference_matrixs: 65 | small_matrix = convert_to_small_matrix(np.array(matrix)) 66 | A_count = np.sum(small_matrix > 0.5) 67 | B_count = np.sum(small_matrix < 0.5) 68 | A_rate = A_count / (A_count + B_count) 69 | A_rate_list.append(A_rate) 70 | 71 | print("Average A rate:", np.mean(A_rate_list)) 72 | return np.mean(A_rate_list) 73 | 74 | 75 | def compute_spearman(args, preference_matrixs, scores_doc): 76 | params = { 77 | # 'dataset': args.dataset, 78 | # 'engine': "meta-llama/Llama-2-7b-chat-hf", 79 | # 'engine': "meta-llama/Llama-2-13b-chat-hf", 80 | "engine": "mistralai/Mistral-7B-Instruct-v0.1", 81 | # 'engine': 'gpt-3.5-turbo', 82 | # 'aspect': args.aspect, 83 | # 'eval_method': args.eval_method, 84 | "confidence_beam": False, 85 | "beam_size": 1, 86 | "api_call": 0, 87 | "prob_gap": 0.1, 88 | # 'with_input': args.with_input, 89 | # 'calibration': args.calibration, 90 | # 'compare_log': {}, 91 | } 92 | 93 | if args.dataset == "SummEval": 94 | n_candidate = 16 95 | elif args.dataset == "newsroom": 96 | n_candidate = 7 97 | else: 98 | n_candidate = 5 99 | repeat_times = 100 100 | spearman_log = [] 101 | tau_log = [] 102 | total_comparison_log = [] 103 | for _ in range(repeat_times): 104 | spearman_list = [] 105 | tau_list = [] 106 | params["api_call"] = 0 107 | for i, _ in enumerate(preference_matrixs): 108 | ranking_indices = merge_sort_indices( 109 | preference_matrixs[i], params, permutate=args.do_permutate 110 | ) 111 | rho, tau = calculate_correlation( 112 | list(range(n_candidate)), scores_doc[i][ranking_indices] 113 | ) 114 | spearman_list.append(rho) 115 | tau_list.append(tau) 116 | 117 | spearman_log.append(spearman_list) 118 | tau_log.append(tau_list) 119 | total_comparison_log.append(params["api_call"] / len(preference_matrixs)) 120 | 121 | 
spearman_avg = np.mean(spearman_log) 122 | spearman_std = np.std(np.mean(spearman_log, axis=-1)) 123 | comparison_avg = np.mean(total_comparison_log) 124 | a_rate = get_a_rate(args, preference_matrixs) 125 | print( 126 | f"spearman_avg: {spearman_avg}, spearman_std: {spearman_std}, comparison_avg: {comparison_avg}, A rate: {a_rate}" 127 | ) 128 | return spearman_avg, spearman_std, comparison_avg, a_rate 129 | 130 | 131 | def get_corr_df(args, saving_dir, test_list_id=range(0, 10)): 132 | if args.dataset == "SummEval": 133 | SummEval_path = "./data/model_annotations.aligned.paired.jsonl" 134 | input_doc, output_doc, scores_doc = load_summEval( 135 | SummEval_path, flat_output=False, truncate_num_for_eval=args.eval_data_num 136 | ) 137 | elif args.dataset == "newsroom": 138 | newsroom_path = "./data/newsroom/newsroom.json" 139 | input_doc, output_doc, scores_doc = load_newsroom( 140 | newsroom_path, flat_output=False, truncate_num_for_eval=args.eval_data_num 141 | ) 142 | elif args.dataset == "TopicalChat": 143 | TC_path = "data/topicalchat_usr.json" 144 | input_doc, output_doc, scores_doc = load_TopicalChat( 145 | TC_path, truncate_num_for_eval=args.eval_data_num 146 | ) 147 | scores_doc = np.array(scores_doc[args.aspect_name]) 148 | scores_doc = np.round(scores_doc, 1) 149 | collect_corr = [] 150 | collect_a_rate = [] 151 | collect_norm_prob_gap = [] 152 | collect_bc = [] 153 | for test_id in test_list_id: 154 | print(f"Test id: {test_id}") 155 | test_matrix = load_matrix( 156 | f"{saving_dir}_preference_matrix_log_cot_False_{args.aspect_name}_{test_id}.json" 157 | ) 158 | if args.do_permutate: 159 | test_matrix = [ 160 | (np.array(matrix) + 1 - np.array(matrix).T) / 2 161 | for matrix in test_matrix 162 | ] 163 | results = compute_spearman(args, test_matrix, scores_doc) 164 | collect_corr.append(results[0]) 165 | collect_a_rate.append(results[3]) 166 | prob_results = load_rsults( 167 | f"{saving_dir}_compare_result_log_cot_False_{args.aspect_name}_{test_id}.json" 168 | ) 169 | prob_gap, norm_prob_gap, bc = get_prob_gap(prob_results) 170 | collect_norm_prob_gap.append(norm_prob_gap) 171 | collect_bc.append(bc) 172 | 173 | mod_a = [abs(a - 0.5) for a in collect_a_rate] 174 | df = pd.DataFrame( 175 | { 176 | "Test id": test_list_id, 177 | "Spearman": collect_corr, 178 | "A rate": collect_a_rate, 179 | "Norm Gap": collect_norm_prob_gap, 180 | "BC": collect_bc, 181 | "Fairness": [-1 * a for a in mod_a], 182 | } 183 | ) 184 | df = df.assign(task="PairS") 185 | return df 186 | -------------------------------------------------------------------------------- /prompts.py: -------------------------------------------------------------------------------- 1 | from textwrap import dedent 2 | 3 | 4 | def get_pairwise_prompt_template(dataset, use_instruction=None): 5 | if dataset == "SummEval": 6 | prompt = dedent( 7 | """\ 8 | Source text: {{ input }} 9 | 10 | Summary candidate A: {{ output_1 }} 11 | 12 | Summary candidate B: {{ output_2 }} 13 | 14 | Question: Evaluate and compare the coherence of the two summary candidates for the given source text. \ 15 | Which summary candidate has better coherence? \ 16 | If the candidate A is better, please return 'A'. \ 17 | If the candidate B is better, please return 'B'. \ 18 | You must return the choice only. 
19 | Answer: """ 20 | ) 21 | if use_instruction: 22 | prompt = dedent( 23 | """\ 24 | Source text: {{ input }} 25 | 26 | Summary candidate A: {{ output_1 }} 27 | 28 | Summary candidate B: {{ output_2 }} 29 | 30 | Question: {{ instruction }} 31 | Answer: """ 32 | ) 33 | 34 | elif dataset == "newsroom": 35 | prompt = dedent( 36 | """\ 37 | Source text: {{ input }} 38 | 39 | Summary candidate A: {{ output_1 }} 40 | 41 | Summary candidate B: {{ output_2 }} 42 | 43 | Question: Evaluate and compare the coherence of the two summary candidates for the given source text. \ 44 | Which summary candidate has better coherence? \ 45 | If the candidate A is better, please return 'A'. \ 46 | If the candidate B is better, please return 'B'. \ 47 | You must return the choice only. 48 | Answer: """ 49 | ) 50 | if use_instruction: 51 | prompt = dedent( 52 | """\ 53 | Source text: {{ input }} 54 | 55 | Summary candidate A: {{ output_1 }} 56 | 57 | Summary candidate B: {{ output_2 }} 58 | 59 | Question: {{ instruction }} 60 | Answer: """ 61 | ) 62 | 63 | elif dataset == "TopicalChat": 64 | prompt = dedent( 65 | """\ 66 | Dialog history: 67 | {{ input }} 68 | 69 | Response candidate A: {{ output_1 }} 70 | Response candidate B: {{ output_2 }} 71 | 72 | Question: Which response is overall better for the given dialog history? \ 73 | Please consider aspects including naturalness, understandability, context consistency and knowledge richness. \ 74 | If the candidate A is better, please return 'A'. \ 75 | If the candidate B is better, please return 'B'. \ 76 | You must return the choice only. 77 | Answer: """ 78 | ) 79 | if use_instruction: 80 | prompt = dedent( 81 | """\ 82 | Dialog history: {{ input }} 83 | 84 | Response candidate A: {{ output_1 }} 85 | Response candidate B: {{ output_2 }} 86 | 87 | Question: {{ instruction }} 88 | Answer: """ 89 | ) 90 | 91 | elif dataset == "GSM8k": 92 | prompt = dedent( 93 | """\ 94 | Math question: {{ input }} 95 | 96 | Solution candidate A: {{ output_1 }} 97 | 98 | Solution candidate B: {{ output_2 }} 99 | 100 | Instruction: Compare the quality of the two solution candidates for the given math question. \ 101 | Which solution candidate is better explained and more logical? \ 102 | If the candidate A is better, please return 'A'. \ 103 | If the candidate B is better, please return 'B'. \ 104 | You must only return your choice and make no explanation. 105 | Answer: """ 106 | ) 107 | 108 | else: 109 | assert False, f"Invalid dataset: {dataset}" 110 | 111 | return prompt 112 | 113 | 114 | def get_pointwise_prompt_template(dataset, with_input): 115 | if with_input: 116 | prompt = dedent( 117 | """\ 118 | Evaluate the overall quality of the following output candidate for the given input. 119 | 120 | Input: {{ input }} 121 | 122 | Output candidate: {{ output }} 123 | 124 | Question: How would you rate the overall quality of the output candidate? \ 125 | Please provide a score between 1 and 10. \ 126 | You must return the score only. 127 | Answer: """ 128 | ) 129 | else: 130 | prompt = dedent( 131 | """\ 132 | Evaluate the overall quality of the following output candidate. 133 | 134 | Output candidate: {{ output }} 135 | 136 | Question: How would you rate the overall quality of the output candidate? \ 137 | Please provide a score between 1 and 10. \ 138 | You must return the score only. 
139 | Answer: """ 140 | ) 141 | return prompt 142 | 143 | 144 | def get_cot_compare_prompt_template(dataset): 145 | if dataset == "SummEval": 146 | prompt = dedent( 147 | """\ 148 | Source text: {{ input }} 149 | 150 | Summary candidate A: {{ output_1 }} 151 | 152 | Summary candidate B: {{ output_2 }} 153 | 154 | Instruction: Please briefly analyse and compare the coherence of the two summary candidates for the given source text, \ 155 | and then conclude which candidate is more coherent.""" 156 | ) 157 | 158 | elif dataset == "TopicalChat": 159 | prompt = dedent( 160 | """\ 161 | Dialog history: 162 | {{ input }} 163 | 164 | Response candidate A: {{ output_1 }} 165 | Response candidate B: {{ output_2 }} 166 | 167 | Question: Which response is overall better for the given dialog history? \ 168 | Please consider aspects including naturalness, understandability, context consistency and knowledge richness. \ 169 | If the candidate A is better, please return 'A'. \ 170 | If the candidate B is better, please return 'B'. \ 171 | You must return the choice only. 172 | Answer: """ 173 | ) 174 | 175 | elif dataset == "GSM8k": 176 | prompt = dedent( 177 | """\ 178 | Math question: {{ input }} 179 | 180 | Solution candidate A: {{ output_1 }} 181 | 182 | Solution candidate B: {{ output_2 }} 183 | 184 | Instruction: Analyse and compare the quality of the two solution candidates for the given math question. \ 185 | Please briefly discuss the strengths and weaknesses of both solution candidates and conclude which is more logical and correct?""" 186 | ) 187 | 188 | else: 189 | assert False, f"Invalid dataset: {dataset}" 190 | return prompt 191 | 192 | 193 | def get_cot_eval_prompt_template(): 194 | prompt = dedent( 195 | """\ 196 | {{ cot_response}} 197 | 198 | Based on the above evaluation, which candidate is preferred according to the analysis? 199 | If the candidate A is preferred, please return 'A'. \ 200 | If the candidate B is preferred, please return 'B'. \ 201 | If both candidates are equally preferred, please return 'C'. \ 202 | You must return the choice only. 203 | Answer: """ 204 | ) 205 | 206 | return prompt 207 | -------------------------------------------------------------------------------- /zepo.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from models.openai_api import OpenAIChatModel 3 | from pairwise_comparison import pairwise_compare 4 | import pairs 5 | import argparse 6 | import wandb 7 | import numpy as np 8 | import json 9 | import os 10 | 11 | 12 | openai_api_key = os.environ.get("OPENAI_API_KEY") 13 | 14 | 15 | def get_instruction(args: argparse.Namespace, _instruction: str, iteration: int = 5) -> List[str]: 16 | """ 17 | Generate paraphrased instructions using OpenAI API. 18 | 19 | Args: 20 | args: command line arguments including aspect name. 21 | _instruction: initial instruction to paraphrase. 22 | iteration: number of paraphrased instructions to generate. 23 | 24 | Returns: 25 | List of paraphrased instructions. 26 | """ 27 | example_prompt = f"""\ 28 | Paraphrase the following instruction for a pairwise comparison task. Do not change the keyword "{args.aspect_name}". Be diverse and creative in paraphrasing. Return the instruction only. 
\ 29 | 30 | Input: {_instruction}\ 31 | 32 | Output: 33 | """ 34 | model = OpenAIChatModel( 35 | {"engine": "gpt-3.5-turbo", "temperature": 0.9}, api_key=openai_api_key 36 | ) 37 | prompts = [example_prompt] * iteration 38 | results = model.generate(prompts) 39 | return results 40 | 41 | 42 | def zepo(args: argparse.Namespace): 43 | # Initialize variables 44 | optimize_metric = "Fairness" 45 | 46 | init_instruction_dict = json.load(open("init_prompts.json")) 47 | 48 | init_instruction = init_instruction_dict[args.aspect_name] 49 | instruction_set = get_instruction( 50 | args, init_instruction, iteration=args.sample_num - 1 51 | ) 52 | instruction_set = [init_instruction] + instruction_set 53 | collect_instruction = [] 54 | collect_instruction += instruction_set 55 | collect_results = {} 56 | log_new_instruction = [] 57 | print("Initial instruction set: ", instruction_set) 58 | best_metric = -99 59 | best_corr = 0 60 | wandb.init( 61 | project="zepo", 62 | config={ 63 | "dataset": args.dataset, 64 | "aspect_name": args.aspect_name, 65 | "engine": args.engine, 66 | "batch_size": args.batch_size, 67 | "sample_num": args.sample_num, 68 | "eval_data_num": args.eval_data_num, 69 | "epoch_num": args.epoch_num, 70 | "instruction": init_instruction, 71 | "instruction_set": instruction_set, 72 | "best_metric": best_metric, 73 | "best_corr": best_corr, 74 | }, 75 | ) 76 | 77 | # Optimize instructions over multiple epochs 78 | for epoch in range(args.epoch_num): 79 | args.saving_dir = f"results/{args.engine}/permutate_{args.do_permutate}/{args.dataset}/{args.aspect_name}/{epoch}/" 80 | 81 | # Evaluation instructions in pairwise comparisons 82 | pairwise_compare(args, instruction_set, round_id=epoch) 83 | saving_dir = f"results/{args.engine}/permutate_{args.do_permutate}/{args.dataset}/{args.aspect_name}/{epoch}/" 84 | saving_path = f"{saving_dir}{args.engine.split('/')[-1]}" 85 | 86 | # Retrieve fairness 87 | df = pairs.pairs_eval.get_corr_df( 88 | args, saving_path, test_list_id=range(0, args.sample_num) 89 | ) 90 | df.to_csv(f"{saving_dir}{args.engine.split('/')[-1]}_results.csv") 91 | print(df) 92 | best_id = df[optimize_metric].idxmax() 93 | new_metric = df[optimize_metric].max() 94 | 95 | # Greedy selection of the best instruction 96 | if new_metric > best_metric: 97 | best_metric = new_metric 98 | new_instruction = instruction_set[best_id] 99 | log_new_instruction.append(new_instruction) 100 | wandb.log({}) 101 | print(f"Best instruction: ", new_instruction) 102 | new_corr = df["Spearman"][best_id] 103 | if new_corr > best_corr: 104 | best_corr = new_corr 105 | print("Best Correlation: ", best_corr) 106 | 107 | # Generate new set of instructions for the next epoch 108 | if epoch != args.epoch_num - 1: 109 | instruction_set = get_instruction( 110 | args, new_instruction, iteration=args.sample_num 111 | ) 112 | print(f"New instruction set at epoch {epoch+1}: ", instruction_set) 113 | collect_instruction += instruction_set 114 | 115 | wandb.log( 116 | { 117 | "best_corr": best_corr, 118 | "best_metric": best_metric, 119 | "instruction": new_instruction, 120 | "instruction_set": instruction_set, 121 | "epoch": epoch, 122 | } 123 | ) 124 | 125 | # evaluate the final instruction 126 | print("Final instruction: ", new_instruction) 127 | args.eval_data_num = 100 128 | epoch = "final" 129 | args.saving_dir = f"results/{args.engine}/permutate_{args.do_permutate}/{args.dataset}/{args.aspect_name}/{epoch}/" 130 | pairwise_compare(args, [new_instruction], round_id=epoch) 131 | saving_path = 
f"{args.saving_dir}{args.engine.split('/')[-1]}" 132 | df = pairs.pairs_eval.get_corr_df(args, saving_path, test_list_id=[0]) 133 | best_id = df[optimize_metric].idxmax() 134 | print(f"Best instruction id: {best_id}") 135 | best_corr = df["Spearman"][best_id] 136 | print("Best Correlation: ", best_corr) 137 | collect_results["test corr"] = best_corr 138 | collect_results["instruction set"] = collect_instruction 139 | collect_results["final instruction"] = new_instruction 140 | collect_results["log best instruction"] = log_new_instruction 141 | 142 | # compare with the initial instruction 143 | epoch = "init" 144 | args.saving_dir = f"results/{args.engine}/permutate_{args.do_permutate}/{args.dataset}/{args.aspect_name}/{epoch}/" 145 | pairwise_compare(args, [init_instruction], round_id=epoch) 146 | saving_path = f"{args.saving_dir}{args.engine.split('/')[-1]}" 147 | df = pairs.pairs_eval.get_corr_df(args, saving_path, test_list_id=[0]) 148 | best_id = df[optimize_metric].idxmax() 149 | best_corr = df["Spearman"][best_id] 150 | print("Init Correlation: ", best_corr) 151 | collect_results["init corr"] = best_corr 152 | 153 | # Save the file 154 | saving_path = f"{args.saving_dir}{args.engine.split('/')[-1]}_{args.aspect_name}_{args.eval_data_num}_{args.sample_num}_{args.epoch_num}_results.json" 155 | with open(saving_path, "w") as f: 156 | json.dump(collect_results, f, indent=4) 157 | f.close() 158 | 159 | 160 | if __name__ == "__main__": 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument("--dataset", type=str, default="SummEval") 163 | parser.add_argument("--aspect_name", type=str, default="coherence") 164 | parser.add_argument( 165 | "--engine", type=str, default="mistralai/Mistral-7B-Instruct-v0.1" 166 | ) 167 | parser.add_argument("--batch_size", type=int, default=6) 168 | parser.add_argument("--worker_num", type=int, default=1) 169 | parser.add_argument("--sample_num", type=int, default=5) 170 | parser.add_argument("--eval_data_num", type=int, default=5) 171 | parser.add_argument("--epoch_num", type=int, default=5) 172 | parser.add_argument("--do_cot", action="store_true", default=False) 173 | parser.add_argument("--do_permutate", action="store_true", default=False) 174 | parser.add_argument("--saving_dir", type=str, default="results/") 175 | 176 | args = parser.parse_args() 177 | zepo(args) 178 | -------------------------------------------------------------------------------- /models/openai_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from openai import OpenAI 4 | from tqdm import tqdm 5 | from concurrent.futures import ThreadPoolExecutor 6 | import datetime 7 | import numpy as np 8 | 9 | 10 | openai_api_key = os.environ.get("OPENAI_API_KEY") 11 | 12 | 13 | class Timer(object): 14 | def __init__(self): 15 | self.__start = time.time() 16 | 17 | def start(self): 18 | self.__start = time.time() 19 | 20 | def get_time(self, restart=True, format=False): 21 | end = time.time() 22 | span = end - self.__start 23 | if restart: 24 | self.__start = end 25 | if format: 26 | return self.format(span) 27 | else: 28 | return span 29 | 30 | def format(self, seconds): 31 | return datetime.timedelta(seconds=int(seconds)) 32 | 33 | def print(self, name): 34 | print(name, self.get_time()) 35 | 36 | 37 | class OpenAIChatModel: 38 | def __init__(self, params={}, api_key=None): 39 | self.api_key = api_key 40 | if "engine" not in params: 41 | params["engine"] = "gpt-3.5-turbo" 42 | if "temperature" not in params: 43 | 
params["temperature"] = 0 44 | if "max_tokens" not in params: 45 | params["max_tokens"] = 512 46 | if "logprobs" not in params: 47 | params["logprobs"] = True 48 | if "top_logprobs" not in params: 49 | params["top_logprobs"] = 5 50 | if "attempt_num" not in params: 51 | params["attempt_num"] = 10 52 | if "do_sample" not in params: 53 | params["do_sample"] = False 54 | if "top_p" not in params: 55 | params["top_p"] = 1 56 | if "chat_system_instruction" not in params: 57 | params["chat_system_instruction"] = None 58 | 59 | self.params = params 60 | if not api_key: 61 | api_key = os.getenv("OPENAI_API_KEY") 62 | self.client = OpenAI(api_key=api_key) 63 | 64 | def generate(self, prompts, chat_system_instruction=None, max_workers=4): 65 | self.params["chat_system_instruction"] = chat_system_instruction 66 | 67 | response_list = self.multi_threading_openai_chat_completion( 68 | prompts, self.single_openai_api_call_generate, max_workers=max_workers 69 | ) 70 | return response_list 71 | 72 | def call_openai_chat_completion(self, prompt): 73 | if self.params["chat_system_instruction"]: 74 | msg = [ 75 | {"role": "system", "content": self.params["chat_system_instruction"]} 76 | ] 77 | else: 78 | msg = [] 79 | msg.append({"role": "user", "content": prompt}) 80 | attempt = 0 81 | while True: 82 | try: 83 | response = self.client.chat.completions.create( 84 | model=self.params["engine"], 85 | messages=msg, 86 | temperature=self.params["temperature"], 87 | max_tokens=self.params["max_tokens"], 88 | logprobs=self.params["logprobs"], 89 | top_logprobs=( 90 | self.params["top_logprobs"] if self.params["logprobs"] else None 91 | ), 92 | ) 93 | return response 94 | 95 | except Exception as e: 96 | print(e) 97 | print(response) 98 | attempt += 1 99 | if attempt >= self.params["attempt_num"]: 100 | return None 101 | wait_sec = 1 102 | time.sleep(wait_sec) 103 | 104 | def multi_threading_openai_chat_completion( 105 | self, prompts, single_thread_func_handler, max_workers=4 106 | ): 107 | inputs = [{"prompt": prompt} for prompt in prompts] 108 | timer = Timer() 109 | print(f"using model_{self.params['engine']}") 110 | print("Processing queires") 111 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 112 | futures = list( 113 | tqdm( 114 | executor.map(lambda x: single_thread_func_handler(x), inputs), 115 | total=len(prompts), 116 | ) 117 | ) 118 | print( 119 | "Average time after {0} samples: {1}".format( 120 | len(prompts), timer.get_time(restart=False) / len(prompts) 121 | ) 122 | ) 123 | print("Processed queries") 124 | 125 | result_list = [input["result"] for input in inputs] 126 | return result_list 127 | 128 | def single_openai_api_call_generate(self, input): 129 | response = self.call_openai_chat_completion(input["prompt"]) 130 | result = response.choices[0].message.content 131 | input["result"] = result 132 | 133 | 134 | if __name__ == "__main__": 135 | 136 | example_prompt = """\ 137 | Evaluate and compare the coherence of the two following summary candidates for the given input source text. 138 | 139 | Input source text: Paul Merson has restarted his row with Andros Townsend after the Tottenham midfielder was brought on with only seven minutes remaining in his team's 0-0 draw with Burnley on Sunday. 'Just been watching the game, did you miss the coach? #RubberDub #7minutes,' Merson put on Twitter. Merson initially angered Townsend for writing in his Sky Sports column that 'if Andros Townsend can get in (the England team) then it opens it up to anybody.' 
Paul Merson had another dig at Andros Townsend after his appearance for Tottenham against Burnley Townsend was brought on in the 83rd minute for Tottenham as they drew 0-0 against Burnley Andros Townsend scores England's equaliser in their 1-1 friendly draw with Italy in Turin on Tuesday night The former Arsenal man was proven wrong when Townsend hit a stunning equaliser for England against Italy and he duly admitted his mistake. 'It's not as though I was watching hoping he wouldn't score for England, I'm genuinely pleased for him and fair play to him – it was a great goal,' Merson said. 'It's just a matter of opinion, and my opinion was that he got pulled off after half an hour at Manchester United in front of Roy Hodgson, so he shouldn't have been in the squad. 'When I'm wrong, I hold my hands up. I don't have a problem with doing that - I'll always be the first to admit when I'm wrong.' Townsend hit back at Merson on Twitter after scoring for England against Italy Sky Sports pundit Merson (centre) criticised Townsend's call-up to the England squad last week Townsend hit back at Merson after netting for England in Turin on Wednesday, saying 'Not bad for a player that should be 'nowhere near the squad' ay @PaulMerse?' Any bad feeling between the pair seemed to have passed but Merson was unable to resist having another dig at Townsend after Tottenham drew at Turf Moor. 140 | 141 | Compare the following outputs: 142 | 143 | Summary candidate A: paul merson was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley . andros townsend scored the tottenham midfielder in the 89th minute . paul merson had another dig at andros townsend after his appearance . the midfielder had been brought on to the england squad last week . click here for all the latest arsenal news news . 144 | 145 | Summary candidate B: paul merson has restarted his row with andros townsend . the tottenham midfielder was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley . andros townsend scores england 's equaliser in their 1-1 friendly draw with italy in turin . 146 | 147 | Question: Which summary candidate has better coherence? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only. 
148 | Answer: \ 149 | """ 150 | 151 | prompts = [example_prompt] * 3 152 | model = OpenAIChatModel({"engine": "gpt-3.5-turbo"}, api_key=openai_api_key) 153 | result = model.generate(prompts) 154 | print(result) 155 | -------------------------------------------------------------------------------- /pairs/llama2.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaForCausalLM, AutoTokenizer 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from utils import CompareResultObject, calculate_uncertainty 6 | 7 | 8 | device = "cuda" 9 | 10 | 11 | def is_integer_string(s): 12 | return s.isdigit() 13 | 14 | 15 | class Llama2ModelLocal: 16 | def __init__(self, params): 17 | self.model_name = params["model"] 18 | self.device = device 19 | if "cache_dir" not in params: 20 | params["cache_dir"] = None 21 | self.tokenizer = AutoTokenizer.from_pretrained( 22 | self.model_name, padding_side="left", cache_dir=params["cache_dir"] 23 | ) # base_model 24 | self.tokenizer.pad_token = self.tokenizer.eos_token 25 | load_in_8bit = True if "13" in self.model_name else False 26 | self.model = LlamaForCausalLM.from_pretrained( 27 | self.model_name, 28 | load_in_8bit=load_in_8bit, 29 | torch_dtype=torch.bfloat16, 30 | device_map=self.device, 31 | cache_dir=params["cache_dir"], 32 | attn_implementation="flash_attention_2", 33 | ) 34 | self.model.eval() 35 | self.A_ids = self.tokenizer.convert_tokens_to_ids(["A", "▁A"]) # A: 330 36 | self.B_ids = self.tokenizer.convert_tokens_to_ids(["B", "▁B"]) # B: 365 37 | self.C_ids = self.tokenizer.convert_tokens_to_ids(["C", "▁C"]) # C: 38 | self.score_ids = self.tokenizer.convert_tokens_to_ids(["1", "2", "3", "4", "5"]) 39 | 40 | def rate_score(self, prompts): 41 | sequence, output = self.local_model_chat_completion(prompts) 42 | # print(self.tokenizer.batch_decode(sequence)) 43 | # score, logprobs = self.extract_score(sequence, output.logits) 44 | scores, logprobs = [], [] 45 | for idx in range(sequence.shape[0]): 46 | seq_logits = [ 47 | logits[idx] for logits in output.logits 48 | ] # convert to [seq_len, vocab_size] 49 | score, logprob = self.extract_score(sequence[idx], seq_logits) 50 | scores.append(score) 51 | logprobs.append(logprob) 52 | return scores, logprobs 53 | 54 | def compare(self, prompts): 55 | """ 56 | prompts: [batch_size, seq_len] 57 | output: a list of compare_result_object, [batch_size] 58 | """ 59 | sequence, output = self.local_model_chat_completion(prompts) 60 | compare_results = [] 61 | for idx in range(sequence.shape[0]): 62 | seq_logits = [ 63 | logits[idx] for logits in output.logits 64 | ] # convert to [seq_len, vocab_size] 65 | compare_result = self.extract_probs(sequence[idx], seq_logits) 66 | compare_results.append(compare_result) 67 | return compare_results 68 | 69 | def extract_score(self, sequence, logits): 70 | """ 71 | sequence: [batch_size, seq_len] 72 | logits: seq_len x [batch_size, vocab_size] 73 | output: int score 74 | """ 75 | for idx, token_id in enumerate(sequence): 76 | logit = logits[idx] 77 | logprobs = F.log_softmax(logit, dim=-1).cpu() 78 | score_logprobs = logprobs[self.score_ids].tolist() 79 | token = self.tokenizer.decode(token_id) 80 | if is_integer_string(token): 81 | return int(token), score_logprobs 82 | print("Failed to extract score") 83 | print(self.tokenizer.batch_decode(sequence)) 84 | return 3, [np.log(0.2)] * 5 85 | 86 | def extract_probs(self, sequence, logits) -> CompareResultObject: 87 | """ 88 | sequence: 
[batch_size, seq_len] 89 | logits: seq_len x [batch_size, vocab_size] 90 | output: compare_result_object 91 | """ 92 | # First token logit 93 | # print(self.tokenizer.batch_decode(sequence)) 94 | # print(self.tokenizer.batch_decode(sequence)) 95 | for idx, token_id in enumerate(sequence): 96 | if token_id in self.A_ids or token_id in self.B_ids: 97 | logit = logits[idx] 98 | probs = F.softmax(logit, dim=-1) 99 | prob_A = sum([probs[a_id].item() for a_id in self.A_ids]) 100 | prob_B = sum([probs[b_id].item() for b_id in self.B_ids]) 101 | prob_C = sum([probs[c_id].item() for c_id in self.C_ids]) 102 | uncertainty = calculate_uncertainty([prob_A, prob_B]) 103 | compare_result = CompareResultObject( 104 | raw_prob_A=prob_A, 105 | raw_prob_B=prob_B, 106 | raw_prob_C=prob_C, 107 | uncertainty=uncertainty, 108 | ) 109 | return compare_result 110 | print("Failed to extract probs") 111 | print(self.tokenizer.batch_decode([sequence])) 112 | return CompareResultObject(raw_prob_A=0.5, raw_prob_B=0.5, uncertainty=1) 113 | 114 | # def local_model_chat_completion(self, prompt): 115 | # msg = Llama2ModelLocal.get_chat_message(prompt) 116 | # input = self.tokenizer.apply_chat_template(msg, return_tensors="pt", return_dict=True) 117 | # # input = self.tokenizer.apply_chat_template(msg, tokenize=False) 118 | # # input = self.tokenizer(input, return_tensors="pt", return_dict=True) 119 | # input = input.to(device) 120 | 121 | # output = self.model.generate( 122 | # inputs=input.input_ids, 123 | # return_dict_in_generate=True, 124 | # output_logits=True, 125 | # max_new_tokens=32, 126 | # do_sample=False, 127 | # temperature=None, 128 | # top_p=None 129 | # ) 130 | 131 | # newly_generated_tokens = output.sequences[:,input.input_ids.shape[1]:] 132 | # return newly_generated_tokens, output 133 | 134 | def local_model_chat_completion( 135 | self, prompts, chat_system_instruction=None, num_samples=1 136 | ): 137 | # if num_samples>1: 138 | # prompts = [prompts]*num_samples 139 | messages = [] 140 | for prompt in prompts: 141 | msg = Llama2ModelLocal.get_chat_message(prompt, chat_system_instruction) 142 | msg = self.tokenizer.apply_chat_template( 143 | msg, tokenize=False 144 | ) # return_tensors="pt", return_dict=True) 145 | messages.append(msg) 146 | 147 | input = self.tokenizer(messages, return_tensors="pt", padding=True) 148 | input = input.to(device) 149 | output = self.model.generate( 150 | **input, 151 | return_dict_in_generate=True, 152 | output_logits=True, 153 | max_new_tokens=32, 154 | do_sample=False, 155 | temperature=None, 156 | ) 157 | 158 | newly_generated_tokens = output.sequences[:, input.input_ids.shape[-1] :] 159 | return newly_generated_tokens, output 160 | 161 | @staticmethod 162 | def get_chat_message(prompt, chat_system_instruction=None): 163 | if chat_system_instruction: 164 | message = [ 165 | {"role": "system", "content": chat_system_instruction}, 166 | {"role": "user", "content": prompt}, 167 | ] 168 | else: 169 | message = [{"role": "user", "content": prompt}] 170 | return message 171 | 172 | 173 | if __name__ == "__main__": 174 | 175 | params = { 176 | "model": "meta-llama/Llama-2-7b-chat-hf", 177 | "cache_dir": None, 178 | "eval_size": 50, 179 | "template": "score", 180 | "aspect": "coherence", 181 | "dataset": "SumEval", 182 | "with_input": True, 183 | } 184 | 185 | prompt = "The quick brown fox jumps over the lazy dog." 
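# NOTE (editor's usage hint): rate_score() and compare() in this class are batch-oriented;
# local_model_chat_completion() iterates over its `prompts` argument, so a single prompt
# should be wrapped in a list, e.g. `model.rate_score([prompt])`. Passing a bare string
# would be iterated character by character when building the chat messages.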
186 | model = Llama2ModelLocal(params) 187 | 188 | score, logprobs = model.rate_score(prompt) 189 | -------------------------------------------------------------------------------- /pairs/full_preference_matrix.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | from utils import load_summEval 7 | from prompts import get_prompt_template, get_aspect_instruction 8 | import json 9 | import torch 10 | from utils import load_summEval 11 | from llama2 import Llama2ModelLocal 12 | from mistral import MistralModelLocal 13 | from openai_api import OpenAIChatModel 14 | from jinja2 import Environment 15 | 16 | 17 | 18 | def pairwise_non_diagonal_to_list(size): 19 | ''' 20 | Convert a Square pairwise matrix to a list of non-diagonal elements 21 | ''' 22 | rows, cols = size, size 23 | result = [] 24 | for i in range(rows): 25 | for j in range(cols): 26 | if i != j: # Check if the element is not on the diagonal 27 | result.append((i, j)) 28 | return result 29 | 30 | 31 | def list_to_pairwise(non_diagonal_list, size, include_diagonal=False): 32 | ''' 33 | Convert a list of non-diagonal elements back to a pairwise matrix 34 | Square matrix 35 | ''' 36 | rows, cols = size, size 37 | matrix = [[0] * cols for _ in range(rows)] 38 | index = 0 39 | for i in range(rows): 40 | for j in range(cols): 41 | if i != j or include_diagonal: 42 | # if i != j: # Check if the element is not on the diagonal 43 | matrix[i][j] = non_diagonal_list[index] 44 | index += 1 45 | return np.array(matrix) 46 | 47 | 48 | def compute_pairwise_preference_matrix(model, input, output, prompt_template, worker_num=1, batch_size=1): 49 | ''' 50 | worker_num is for closed source models async parallel processing 51 | batch_size is for open source models parallel processing 52 | ''' 53 | task_instruction = get_aspect_instruction('coherence', eval_method='pairwise comparison', dataset='SummEval') 54 | 55 | response_size = len(output) 56 | full_pairwise_list = pairwise_non_diagonal_to_list(size=response_size) 57 | mask_matrix = np.ones((response_size, response_size)) 58 | mask_matrix = np.triu(mask_matrix, k=1).astype(bool) 59 | mask_matrix_list = mask_matrix.flatten().tolist() 60 | 61 | prompts = [] 62 | for pair, mask in zip(full_pairwise_list, mask_matrix_list): 63 | # if mask == False: 64 | # continue 65 | prompt = prompt_template.render( 66 | instruction=task_instruction, 67 | input=input, 68 | output_1=output[pair[0]], 69 | output_2=output[pair[1]], 70 | aspect="coherence", 71 | ) 72 | prompts.append(prompt) 73 | # print(len(prompts)) 74 | # assert False 75 | 76 | if worker_num > 1: # Parallel processing for closed source models 77 | compare_result_list = model.compare(prompts) 78 | else: 79 | compare_result_list = [] 80 | for i in tqdm(range(0, len(prompts), batch_size)): 81 | batch_prompts = prompts[i:i+batch_size] 82 | compare_results = model.compare(batch_prompts) 83 | compare_result_list.extend(compare_results) 84 | 85 | pairwise_preference_list = [compare_result.prob_A for compare_result in compare_result_list] 86 | 87 | 88 | # Count if prob is balanced 89 | A_cnt,B_cnt,C_cnt = 0,0,0 90 | for prob in pairwise_preference_list: 91 | if prob > 0.5: 92 | A_cnt +=1 93 | elif prob < 0.5: 94 | B_cnt +=1 95 | else: 96 | C_cnt +=1 97 | print('A_cnt:', A_cnt, 'B_cnt:', B_cnt, 'C_cnt:', C_cnt) 98 | 99 | # prompt_idx = 0 100 | # full_pairwise_preference_list = [] 101 | # for mask in mask_matrix_list: 102 | # if mask 
== True: 103 | # full_pairwise_preference_list.append(pairwise_preference_list[prompt_idx]) 104 | # prompt_idx += 1 105 | # else: 106 | # full_pairwise_preference_list.append(0) 107 | # pairwise_preference_list = full_pairwise_preference_list 108 | 109 | pairwise_preference_matrix = list_to_pairwise(pairwise_preference_list, response_size, include_diagonal=False) 110 | return pairwise_preference_matrix 111 | 112 | 113 | def main(args): 114 | if args.dataset == 'SummEval': 115 | SummEval_path = '/home/yinhong/Documents/source/PairS/data/SummEval/model_annotations.aligned.paired.jsonl' 116 | input_doc, output_doc, scores = load_summEval(SummEval_path, flat_output=False) 117 | scores = scores['coherence'] 118 | 119 | # Load model 120 | if 'mistral' in args.engine: 121 | model = MistralModelLocal({'model': args.engine}) 122 | # elif 'Llama-3' in args.engine: 123 | # model = Llama3ModelLocal({'model': args.engine, 'cot': args.do_cot}) 124 | elif 'Llama-2' in args.engine: 125 | model = Llama2ModelLocal({'model': args.engine}) 126 | elif 'gpt' in args.engine: 127 | model = OpenAIChatModel({'model': args.engine}) 128 | 129 | prompt_template = get_prompt_template( 130 | prompt_name="pairwise comparison", 131 | aspect='coherence', 132 | dataset=args.dataset, 133 | model_name=None, 134 | with_input=True 135 | ) 136 | environment = Environment() 137 | prompt_template = environment.from_string(prompt_template) 138 | 139 | 140 | # intrans_list, trans_list = [], [] 141 | pairwise_preference_matrix_log = [] 142 | # preference_matrix_list = [] 143 | for i in range(len(input_doc)): 144 | print('Data point:', i+1, 'out of', len(input_doc), 'data points.') 145 | input = input_doc[i][0] 146 | output = output_doc[i] 147 | # score = scores[i] 148 | pairwise_preference_matrix = compute_pairwise_preference_matrix( 149 | model, 150 | input, 151 | output, 152 | prompt_template, 153 | worker_num=args.worker_num, 154 | batch_size=args.batch_size 155 | ) 156 | pairwise_preference_matrix_log.append(pairwise_preference_matrix.tolist()) 157 | 158 | 159 | # Release the model from GPU 160 | del model 161 | torch.cuda.empty_cache() 162 | 163 | # Save list of matrices to JSONL file 164 | with open(f'{args.engine.split('/')[-1]}_preference_matrix_log.json', 'w') as f: 165 | json.dump(pairwise_preference_matrix_log, f) 166 | f.close() 167 | 168 | 169 | if __name__ == '__main__': 170 | import argparse 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument('--dataset', type=str, default='SummEval') 173 | parser.add_argument('--engine', type=str, default='mistralai/Mistral-7B-Instruct-v0.1') 174 | parser.add_argument('--worker_num', type=int, default=1) 175 | parser.add_argument('--batch_size', type=int, default=1) 176 | # parser.add_argument('--do_cot', action='store_true') 177 | args = parser.parse_args() 178 | 179 | # with open(f'{args.engine.split('/')[-1]}_preference_matrix_log.json', 'w') as f: 180 | # pass 181 | # f.close() 182 | 183 | if args.engine == 'full': 184 | engine_list = [ 185 | # 'mistralai/Mistral-7B-Instruct-v0.1', 186 | # 'meta-llama/Llama-2-7b-chat-hf', 187 | # 'meta-llama/Llama-2-13b-chat-hf', 188 | # 'meta-llama/Meta-Llama-3-8B-Instruct', 189 | 'gpt-3.5-turbo' 190 | ] 191 | else: 192 | engine_list = [args.engine] 193 | 194 | results_to_report = [] 195 | for engine in engine_list: 196 | args.engine = engine 197 | print('Engine: ', engine) 198 | if 'gpt' in engine: 199 | args.worker_num = 6 200 | else: 201 | args.worker_num = 1 202 | # trans, intrans = 203 | main(args) 204 | # report = { 205 | # 
'engine': engine, 206 | # 'dataset': args.dataset, 207 | # 'transitivity': trans, 208 | # 'intransitivity': intrans 209 | # } 210 | # results_to_report.append(report) 211 | print('==========================================================') 212 | 213 | # for report in results_to_report: 214 | # print(report) -------------------------------------------------------------------------------- /models/mistral.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import sys 6 | 7 | sys.path.append("../") 8 | from utils import CompareResultObject, calculate_uncertainty 9 | 10 | 11 | device = "cuda" 12 | 13 | 14 | def is_integer_string(s): 15 | return s.isdigit() 16 | 17 | 18 | class MistralModelLocal: 19 | def __init__(self, params): 20 | self.model_name = params["model"] 21 | self.temperature = params["temperature"] if "temperature" in params else 0 22 | self.max_tokens = params["max_tokens"] if "max_tokens" in params else 128 23 | self.do_sample = params["do_sample"] if "do_sample" in params else False 24 | self.device = device 25 | self.model = AutoModelForCausalLM.from_pretrained( 26 | self.model_name, 27 | device_map=self.device, 28 | attn_implementation="flash_attention_2", # flash attention is not easy to install 29 | torch_dtype=torch.bfloat16, 30 | ) 31 | self.tokenizer = AutoTokenizer.from_pretrained( 32 | self.model_name 33 | ) # , cache_dir="models") 34 | self.tokenizer.pad_token = self.tokenizer.eos_token 35 | self.A_ids = self.tokenizer.convert_tokens_to_ids(["A", "▁A"]) # A: 330 36 | self.B_ids = self.tokenizer.convert_tokens_to_ids(["B", "▁B"]) # B: 365 37 | self.C_ids = self.tokenizer.convert_tokens_to_ids(["C", "▁C"]) # C: 38 | self.score_ids = self.tokenizer.convert_tokens_to_ids(["1", "2", "3", "4", "5"]) 39 | 40 | def compare(self, prompts): 41 | sequence, output = self.local_model_chat_completion(prompts) 42 | compare_results = [] 43 | for idx in range(sequence.shape[0]): 44 | seq_logits = [ 45 | logits[idx] for logits in output.logits 46 | ] # convert to [seq_len, vocab_size] 47 | compare_result = self.extract_probs(sequence[idx], seq_logits) 48 | compare_results.append(compare_result) 49 | return compare_results 50 | 51 | def generate(self, prompts): 52 | sequence, output = self.local_model_chat_completion(prompts) 53 | # Skip special tokens and role tokens 54 | generated_text = self.tokenizer.batch_decode(sequence, skip_special_tokens=True) 55 | return generated_text 56 | 57 | def rate_score(self, prompts): 58 | sequence, output = self.local_model_chat_completion(prompts) 59 | # print(output.logits) 60 | # return sequence, output.logits 61 | scores, logprobs = [], [] 62 | for idx in range(sequence.shape[0]): 63 | seq_logits = [ 64 | logits[idx] for logits in output.logits 65 | ] # convert to [seq_len, vocab_size] 66 | score, logprob = self.extract_score(sequence[idx], seq_logits) 67 | scores.append(score) 68 | logprobs.append(logprob) 69 | return scores, logprobs 70 | 71 | def extract_score(self, sequence, logits): 72 | """ 73 | sequence: [batch_size, seq_len] 74 | logits: seq_len x [batch_size, vocab_size] 75 | output: int score 76 | """ 77 | for idx, token_id in enumerate(sequence): 78 | logit = logits[idx] 79 | logprobs = F.log_softmax(logit, dim=-1).cpu() 80 | score_logprobs = logprobs[self.score_ids].tolist() 81 | token = self.tokenizer.decode(token_id) 82 | if is_integer_string(token): 83 | 
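                # Note (descriptive comment, not in the original file): the first integer-valued token in the
                # generation is taken as the 1-5 rating, and the log-probabilities over the '1'-'5' token ids
                # are returned alongside it.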
return int(token), score_logprobs 84 | print("Failed to extract score") 85 | print(self.tokenizer.batch_decode(sequence)) 86 | return 3, [np.log(0.2)] * 5 87 | 88 | def extract_probs(self, sequence, logits) -> CompareResultObject: 89 | """ 90 | sequence: [batch_size, seq_len] 91 | logits: seq_len x [batch_size, vocab_size] 92 | output: compare_result_object 93 | """ 94 | # First token logit 95 | for idx, token_id in enumerate(sequence): 96 | if token_id in self.A_ids or token_id in self.B_ids: 97 | logit = logits[idx] 98 | probs = F.softmax(logit, dim=-1) 99 | prob_A = sum([probs[a_id].item() for a_id in self.A_ids]) 100 | prob_B = sum([probs[b_id].item() for b_id in self.B_ids]) 101 | prob_C = sum([probs[c_id].item() for c_id in self.C_ids]) 102 | logit_A = sum([logit[a_id].item() for a_id in self.A_ids]) 103 | logit_B = sum([logit[b_id].item() for b_id in self.B_ids]) 104 | logit_C = sum([logit[c_id].item() for c_id in self.C_ids]) 105 | uncertainty = calculate_uncertainty([prob_A, prob_B]) 106 | compare_result = CompareResultObject( 107 | raw_prob_A=prob_A, 108 | raw_prob_B=prob_B, 109 | raw_prob_C=prob_C, 110 | uncertainty=uncertainty, 111 | logit_A=logit_A, 112 | logit_B=logit_B, 113 | logit_C=logit_C, 114 | ) 115 | return compare_result 116 | print("Failed to extract probs") 117 | print(self.tokenizer.decode(sequence)) 118 | return CompareResultObject(raw_prob_A=0.5, raw_prob_B=0.5, uncertainty=1) 119 | 120 | def local_model_chat_completion(self, prompts): 121 | messages = [] 122 | for prompt in prompts: 123 | msg = MistralModelLocal.get_chat_message(prompt) 124 | msg = self.tokenizer.apply_chat_template( 125 | msg, tokenize=False 126 | ) # return_tensors="pt", return_dict=True) 127 | messages.append(msg) 128 | 129 | input = self.tokenizer(messages, return_tensors="pt", padding=True) 130 | input = input.to(device) 131 | output = self.model.generate( 132 | **input, 133 | return_dict_in_generate=True, 134 | pad_token_id=self.tokenizer.eos_token_id, 135 | output_logits=True, 136 | max_new_tokens=self.max_tokens, 137 | do_sample=self.do_sample, 138 | temperature=None, 139 | top_p=None 140 | ) 141 | 142 | newly_generated_tokens = output.sequences[:, input.input_ids.shape[-1] :] 143 | return newly_generated_tokens, output 144 | 145 | @staticmethod 146 | def get_chat_message(prompt, chat_system_instruction=None): 147 | if chat_system_instruction: 148 | message = [ 149 | # {'role': 'assistant', 'content': chat_system_instruction}, 150 | {"role": "user", "content": prompt}, 151 | ] 152 | else: 153 | message = [{"role": "user", "content": prompt}] 154 | return message 155 | 156 | 157 | if __name__ == "__main__": 158 | example_prompt = """\ 159 | Evaluate and compare the coherence of the two following summary candidates for the given input source text. 160 | 161 | Input source text: Paul Merson has restarted his row with Andros Townsend after the Tottenham midfielder was brought on with only seven minutes remaining in his team's 0-0 draw with Burnley on Sunday. 'Just been watching the game, did you miss the coach? #RubberDub #7minutes,' Merson put on Twitter. Merson initially angered Townsend for writing in his Sky Sports column that 'if Andros Townsend can get in (the England team) then it opens it up to anybody.' 
Paul Merson had another dig at Andros Townsend after his appearance for Tottenham against Burnley Townsend was brought on in the 83rd minute for Tottenham as they drew 0-0 against Burnley Andros Townsend scores England's equaliser in their 1-1 friendly draw with Italy in Turin on Tuesday night The former Arsenal man was proven wrong when Townsend hit a stunning equaliser for England against Italy and he duly admitted his mistake. 'It's not as though I was watching hoping he wouldn't score for England, I'm genuinely pleased for him and fair play to him – it was a great goal,' Merson said. 'It's just a matter of opinion, and my opinion was that he got pulled off after half an hour at Manchester United in front of Roy Hodgson, so he shouldn't have been in the squad. 'When I'm wrong, I hold my hands up. I don't have a problem with doing that - I'll always be the first to admit when I'm wrong.' Townsend hit back at Merson on Twitter after scoring for England against Italy Sky Sports pundit Merson (centre) criticised Townsend's call-up to the England squad last week Townsend hit back at Merson after netting for England in Turin on Wednesday, saying 'Not bad for a player that should be 'nowhere near the squad' ay @PaulMerse?' Any bad feeling between the pair seemed to have passed but Merson was unable to resist having another dig at Townsend after Tottenham drew at Turf Moor. 162 | 163 | Compare the following outputs: 164 | 165 | Summary candidate A: paul merson was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley . andros townsend scored the tottenham midfielder in the 89th minute . paul merson had another dig at andros townsend after his appearance . the midfielder had been brought on to the england squad last week . click here for all the latest arsenal news news . 166 | 167 | Summary candidate B: paul merson has restarted his row with andros townsend . the tottenham midfielder was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley . andros townsend scores england 's equaliser in their 1-1 friendly draw with italy in turin . 168 | 169 | Question: Which summary candidate has better coherence? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only. 
170 | Answer: \ 171 | """ 172 | 173 | import os 174 | 175 | model = MistralModelLocal({"model": "mistralai/Mistral-7B-Instruct-v0.1"}) 176 | print(example_prompt) 177 | result = model.compare(example_prompt) 178 | -------------------------------------------------------------------------------- /models/llama3.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from transformers import LlamaForCausalLM, AutoTokenizer 3 | import torch 4 | import numpy as np 5 | from transformers import logging 6 | import torch.nn.functional as F 7 | import sys 8 | 9 | sys.path.append("../") 10 | from utils import CompareResultObject, calculate_uncertainty 11 | from prompts import get_cot_eval_prompt_template 12 | 13 | logging.set_verbosity_error() 14 | 15 | 16 | device = "cuda" 17 | 18 | 19 | def is_integer_string(s): 20 | return s.isdigit() 21 | 22 | 23 | class Llama3ModelLocal: 24 | def __init__(self, params): 25 | self.model_name = params["model"] 26 | self.temperature = params["temperature"] if "temperature" in params else 0 27 | self.max_tokens = params["max_tokens"] if "max_tokens" in params else 128 28 | self.do_sample = params["do_sample"] if "do_sample" in params else False 29 | self.do_cot = params["cot"] if "cot" in params else False 30 | self.cot_eval_template = None 31 | if self.do_cot: 32 | self.max_tokens = 350 33 | self.device = device 34 | if "cache_dir" not in params: 35 | params["cache_dir"] = None 36 | self.tokenizer = AutoTokenizer.from_pretrained( 37 | self.model_name, padding_side="left" 38 | ) 39 | self.tokenizer.pad_token = self.tokenizer.eos_token 40 | 41 | self.model = LlamaForCausalLM.from_pretrained( 42 | self.model_name, 43 | torch_dtype=torch.bfloat16, 44 | device_map=self.device, 45 | attn_implementation="flash_attention_2", 46 | ) 47 | self.model.eval() 48 | self.A_ids = self.tokenizer.convert_tokens_to_ids(["A", "ĠA"]) 49 | self.B_ids = self.tokenizer.convert_tokens_to_ids(["B", "ĠB"]) 50 | self.C_ids = self.tokenizer.convert_tokens_to_ids(["C", "ĠC"]) 51 | self.score_ids = self.tokenizer.convert_tokens_to_ids(["1", "2", "3", "4", "5"]) 52 | print("Model {} loaded".format(self.model_name)) 53 | 54 | def rate_score(self, prompt): 55 | sequence, output = self.local_model_chat_completion(prompt) 56 | score, logprobs = self.extract_score(sequence, output.logits) 57 | print(score) 58 | return score, logprobs 59 | 60 | def compare(self, prompts) -> List[CompareResultObject]: 61 | """ 62 | prompts: [batch_size, seq_len] 63 | output: a list of compare_result_object, [batch_size] 64 | """ 65 | sequence, output = self.local_model_chat_completion(prompts) 66 | compare_results = [] 67 | for idx in range(sequence.shape[0]): 68 | seq_logits = [ 69 | logits[idx] for logits in output.logits 70 | ] # convert to [seq_len, vocab_size] 71 | compare_result = self.extract_probs(sequence[idx], seq_logits) 72 | compare_results.append(compare_result) 73 | return compare_results 74 | 75 | def generate(self, prompts): 76 | sequence, output = self.local_model_chat_completion(prompts) 77 | # Skip special tokens and role tokens 78 | generated_text = self.tokenizer.batch_decode( 79 | sequence[:, 4:], skip_special_tokens=True 80 | ) 81 | return generated_text 82 | 83 | def debug_generate(self, prompts): 84 | A_id = self.tokenizer.convert_tokens_to_ids("A") 85 | B_id = self.tokenizer.convert_tokens_to_ids("B") 86 | print(A_id, B_id) 87 | sequence, output = self.local_model_chat_completion(prompts) 88 | generated_text = 
self.tokenizer.batch_decode( 89 | sequence[:, 4:], skip_special_tokens=True 90 | ) 91 | prob = torch.softmax(output.logits[4], dim=-1) 92 | print(prob) 93 | a_prob = prob[0, A_id] 94 | b_prob = prob[0, B_id] 95 | 96 | a_normp = a_prob / (a_prob + b_prob) 97 | b_normp = b_prob / (a_prob + b_prob) 98 | print(a_normp, b_normp) 99 | 100 | return sequence, output, generated_text, [a_normp, b_normp] 101 | 102 | def extract_score(self, sequence, logits): 103 | """ 104 | sequence: [batch_size, seq_len] 105 | logits: seq_len x [batch_size, vocab_size] 106 | output: int score 107 | """ 108 | for idx, token_id in enumerate(sequence[0]): 109 | logit = logits[idx][0] 110 | logprobs = F.log_softmax(logit, dim=-1).cpu() 111 | score_logprobs = logprobs[self.score_ids].tolist() 112 | token = self.tokenizer.decode(token_id) 113 | if is_integer_string(token): 114 | return int(token), score_logprobs 115 | print("Failed to extract score") 116 | print(self.tokenizer.batch_decode(sequence)) 117 | return 3, [np.log(0.2)] * 5 118 | 119 | def extract_probs(self, sequence, logits) -> CompareResultObject: 120 | """ 121 | sequence: [seq_len] 122 | logits: seq_len x [vocab_size] 123 | output: compare_result_object 124 | """ 125 | # First token logit 126 | for idx, token_id in enumerate(sequence): 127 | if token_id in self.A_ids or token_id in self.B_ids: 128 | logit = logits[idx] 129 | probs = F.softmax(logit, dim=-1) 130 | prob_A = sum([probs[a_id].item() for a_id in self.A_ids]) 131 | prob_B = sum([probs[b_id].item() for b_id in self.B_ids]) 132 | prob_C = sum([probs[c_id].item() for c_id in self.C_ids]) 133 | logit_A = sum([logit[a_id].item() for a_id in self.A_ids]) 134 | logit_B = sum([logit[b_id].item() for b_id in self.B_ids]) 135 | logit_C = sum([logit[c_id].item() for c_id in self.C_ids]) 136 | uncertainty = calculate_uncertainty([prob_A, prob_B]) 137 | compare_result = CompareResultObject( 138 | raw_prob_A=prob_A, 139 | raw_prob_B=prob_B, 140 | raw_prob_C=prob_C, 141 | uncertainty=uncertainty, 142 | logit_A=logit_A, 143 | logit_B=logit_B, 144 | logit_C=logit_C, 145 | ) 146 | return compare_result 147 | print("Failed to extract probs") 148 | print(self.tokenizer.batch_decode([sequence])) 149 | return CompareResultObject(raw_prob_A=0.5, raw_prob_B=0.5, uncertainty=1) 150 | 151 | def cot_compare(self, decoded_sequence): 152 | """ 153 | input: 154 | decoded_sequence: [batch_size, seq_len] 155 | output: 156 | compare_response: [batch_size, seq_len] 157 | """ 158 | pass 159 | 160 | def local_model_chat_completion(self, prompts): 161 | messages = [] 162 | for prompt in prompts: 163 | msg = Llama3ModelLocal.get_chat_message(prompt) 164 | msg = self.tokenizer.apply_chat_template( 165 | msg, tokenize=False 166 | ) # return_tensors="pt", return_dict=True) 167 | messages.append(msg) 168 | 169 | input = self.tokenizer(messages, return_tensors="pt", padding=True) 170 | input = input.to(device) 171 | output = self.model.generate( 172 | **input, 173 | return_dict_in_generate=True, 174 | pad_token_id=self.tokenizer.eos_token_id, 175 | output_logits=True, 176 | max_new_tokens=self.max_tokens, 177 | do_sample=self.do_sample, 178 | temperature=None, 179 | top_p=None 180 | ) 181 | 182 | newly_generated_tokens = output.sequences[:, input.input_ids.shape[-1] :] 183 | return newly_generated_tokens, output 184 | 185 | @staticmethod 186 | def get_chat_message(prompt, chat_system_instruction=None): 187 | if chat_system_instruction: 188 | message = [ 189 | {"role": "system", "content": chat_system_instruction}, 190 | {"role": "user", 
"content": prompt}, 191 | ] 192 | else: 193 | message = [{"role": "user", "content": prompt}] 194 | return message 195 | 196 | 197 | if __name__ == "__main__": 198 | example_prompt = """\ 199 | Evaluate and compare the coherence of the two following summary candidates for the given input source text. 200 | 201 | Input source text: Paul Merson has restarted his row with Andros Townsend after the Tottenham midfielder was brought on with only seven minutes remaining in his team's 0-0 draw with Burnley on Sunday. 'Just been watching the game, did you miss the coach? #RubberDub #7minutes,' Merson put on Twitter. Merson initially angered Townsend for writing in his Sky Sports column that 'if Andros Townsend can get in (the England team) then it opens it up to anybody.' Paul Merson had another dig at Andros Townsend after his appearance for Tottenham against Burnley Townsend was brought on in the 83rd minute for Tottenham as they drew 0-0 against Burnley Andros Townsend scores England's equaliser in their 1-1 friendly draw with Italy in Turin on Tuesday night The former Arsenal man was proven wrong when Townsend hit a stunning equaliser for England against Italy and he duly admitted his mistake. 'It's not as though I was watching hoping he wouldn't score for England, I'm genuinely pleased for him and fair play to him – it was a great goal,' Merson said. 'It's just a matter of opinion, and my opinion was that he got pulled off after half an hour at Manchester United in front of Roy Hodgson, so he shouldn't have been in the squad. 'When I'm wrong, I hold my hands up. I don't have a problem with doing that - I'll always be the first to admit when I'm wrong.' Townsend hit back at Merson on Twitter after scoring for England against Italy Sky Sports pundit Merson (centre) criticised Townsend's call-up to the England squad last week Townsend hit back at Merson after netting for England in Turin on Wednesday, saying 'Not bad for a player that should be 'nowhere near the squad' ay @PaulMerse?' Any bad feeling between the pair seemed to have passed but Merson was unable to resist having another dig at Townsend after Tottenham drew at Turf Moor. 202 | 203 | Compare the following outputs: 204 | 205 | Summary candidate A: paul merson was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley . andros townsend scored the tottenham midfielder in the 89th minute . paul merson had another dig at andros townsend after his appearance . the midfielder had been brought on to the england squad last week . click here for all the latest arsenal news news . 206 | 207 | Summary candidate B: paul merson has restarted his row with andros townsend . the tottenham midfielder was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley . andros townsend scores england 's equaliser in their 1-1 friendly draw with italy in turin . 208 | 209 | Question: Which summary candidate has better coherence? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only. 
210 | Answer: \ 211 | """ 212 | 213 | import os 214 | 215 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 216 | 217 | model = Llama3ModelLocal({"model": "meta-llama/Meta-Llama-3-8B"}) 218 | 219 | result = model.compare(example_prompt) 220 | 221 | print(result.prob_A) 222 | -------------------------------------------------------------------------------- /pairs/mistral.py: -------------------------------------------------------------------------------- 1 | from mistralai.client import MistralClient 2 | from mistralai.models.chat_completion import ChatMessage 3 | import os 4 | import time 5 | import random 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | import torch 8 | import torch.nn.functional as F 9 | from utils import CompareResultObject, calculate_uncertainty 10 | import numpy as np 11 | 12 | 13 | device = "cuda" 14 | 15 | 16 | def is_integer_string(s): 17 | return s.isdigit() 18 | 19 | 20 | class MistralModelLocal: 21 | def __init__(self, params): 22 | self.model_name = params["model"] 23 | self.device = device 24 | self.model = AutoModelForCausalLM.from_pretrained( 25 | self.model_name, 26 | # cache_dir="models", 27 | device_map=self.device, 28 | attn_implementation="flash_attention_2", 29 | torch_dtype=torch.bfloat16, 30 | ) 31 | self.tokenizer = AutoTokenizer.from_pretrained( 32 | self.model_name, 33 | padding_side="left", 34 | ) # , cache_dir="models") 35 | self.tokenizer.pad_token = self.tokenizer.eos_token 36 | self.A_ids = self.tokenizer.convert_tokens_to_ids(["A", "▁A"]) # A: 330 37 | self.B_ids = self.tokenizer.convert_tokens_to_ids(["B", "▁B"]) # B: 365 38 | self.C_ids = self.tokenizer.convert_tokens_to_ids(["C", "▁C"]) # C: 39 | self.score_ids = self.tokenizer.convert_tokens_to_ids(["1", "2", "3", "4", "5"]) 40 | 41 | def compare(self, prompts): 42 | """ 43 | prompts: [batch_size, seq_len] 44 | output: a list of compare_result_object, [batch_size] 45 | """ 46 | sequence, output = self.local_model_chat_completion(prompts) 47 | compare_results = [] 48 | for idx in range(sequence.shape[0]): 49 | seq_logits = [ 50 | logits[idx] for logits in output.logits 51 | ] # convert to [seq_len, vocab_size] 52 | compare_result = self.extract_probs(sequence[idx], seq_logits) 53 | compare_results.append(compare_result) 54 | return compare_results 55 | 56 | # def mistral_rate(self, msg): 57 | # # model run locally 58 | # sequence, output = self.local_model_chat_completion(msg) 59 | # # print(sequence) 60 | # # print(self.tokenizer.batch_decode(sequence)) 61 | # score, logprobs = self.extract_score(sequence, output.logits) 62 | # return score, logprobs 63 | 64 | def rate_score(self, prompts): 65 | sequence, output = self.local_model_chat_completion(prompts) 66 | # print(output.logits) 67 | # return sequence, output.logits 68 | scores, logprobs = [], [] 69 | for idx in range(sequence.shape[0]): 70 | seq_logits = [ 71 | logits[idx] for logits in output.logits 72 | ] # convert to [seq_len, vocab_size] 73 | score, logprob = self.extract_score(sequence[idx], seq_logits) 74 | scores.append(score) 75 | logprobs.append(logprob) 76 | return scores, logprobs 77 | 78 | def extract_score(self, sequence, logits): 79 | """ 80 | sequence: [batch_size, seq_len] 81 | logits: seq_len x [batch_size, vocab_size] 82 | output: int score 83 | """ 84 | for idx, token_id in enumerate(sequence): 85 | logit = logits[idx] 86 | logprobs = F.log_softmax(logit, dim=-1).cpu() 87 | score_logprobs = logprobs[self.score_ids].tolist() 88 | token = self.tokenizer.decode(token_id) 89 | if 
is_integer_string(token): 90 | return int(token), score_logprobs 91 | print("Failed to extract score") 92 | print(self.tokenizer.batch_decode(sequence)) 93 | return 3, [np.log(0.2)] * 5 94 | # only string in the response: 95 | 96 | def extract_probs(self, sequence, logits) -> CompareResultObject: 97 | """ 98 | sequence: [batch_size, seq_len] 99 | logits: seq_len x [batch_size, vocab_size] 100 | output: compare_result_object 101 | """ 102 | # First token logit 103 | # print(self.tokenizer.batch_decode(sequence)) 104 | for idx, token_id in enumerate(sequence): 105 | if token_id in self.A_ids or token_id in self.B_ids: 106 | logit = logits[idx] 107 | probs = F.softmax(logit, dim=-1) 108 | prob_A = sum([probs[a_id].item() for a_id in self.A_ids]) 109 | prob_B = sum([probs[b_id].item() for b_id in self.B_ids]) 110 | prob_C = sum([probs[c_id].item() for c_id in self.C_ids]) 111 | # print(sequence) 112 | # print('raw prob_A: ', prob_A, 'raw prob_B: ', prob_B) 113 | # prob_A, prob_B = prob_A/(prob_A+prob_B), prob_B/(prob_A+prob_B) 114 | uncertainty = calculate_uncertainty([prob_A, prob_B]) 115 | compare_result = CompareResultObject( 116 | raw_prob_A=prob_A, 117 | raw_prob_B=prob_B, 118 | raw_prob_C=prob_C, 119 | uncertainty=uncertainty, 120 | ) 121 | return compare_result 122 | print("Failed to extract probs") 123 | print(self.tokenizer.batch_decode(sequence)) 124 | return CompareResultObject(raw_prob_A=0.5, raw_prob_B=0.5, uncertainty=1) 125 | 126 | # def local_model_chat_completion(self, msg): 127 | # msg_encoded = self.tokenizer.apply_chat_template(msg, return_tensors="pt", return_dict=True) 128 | # model_inputs = msg_encoded.to(self.device) 129 | # output = self.model.generate(**model_inputs, 130 | # max_new_tokens=64, 131 | # pad_token_id=self.tokenizer.eos_token_id, 132 | # do_sample=False, 133 | # return_dict_in_generate=True, 134 | # output_logits=True) 135 | # newly_generated_tokens = output.sequences[:,model_inputs.input_ids.shape[1]:] 136 | # return newly_generated_tokens, output 137 | 138 | def local_model_chat_completion(self, prompts, num_samples=1): 139 | # if num_samples>1: 140 | # prompts = [prompts]*num_samples 141 | messages = [] 142 | for prompt in prompts: 143 | msg = MistralModelLocal.get_chat_message(prompt) 144 | msg = self.tokenizer.apply_chat_template( 145 | msg, tokenize=False 146 | ) # return_tensors="pt", return_dict=True) 147 | messages.append(msg) 148 | 149 | input = self.tokenizer(messages, return_tensors="pt", padding=True) 150 | input = input.to(device) 151 | output = self.model.generate( 152 | **input, 153 | return_dict_in_generate=True, 154 | pad_token_id=self.tokenizer.eos_token_id, 155 | output_logits=True, 156 | max_new_tokens=32, 157 | do_sample=False, 158 | temperature=None, 159 | top_p=None 160 | ) 161 | 162 | newly_generated_tokens = output.sequences[:, input.input_ids.shape[-1] :] 163 | return newly_generated_tokens, output 164 | 165 | @staticmethod 166 | def get_chat_message(prompt): 167 | return [{"role": "user", "content": prompt}] 168 | 169 | 170 | class MistralModel: 171 | def __init__(self, params, api_key=None): 172 | self.model_name = params["model"] 173 | self.device = device 174 | if not api_key: 175 | self.api_key = os.environ.get("MISTRAL_API_KEY") 176 | else: 177 | self.api_key = api_key 178 | 179 | def mistral_compare(self, msg): 180 | # model run through api 181 | api_params = { 182 | "model": self.model_name, 183 | "temperature": 0.0, 184 | "max_tokens": 32, 185 | "attempt_num": 10, 186 | } 187 | chat_response = 
MistralModel.call_mistral_chat_completion( 188 | msg, api_params, api_key=self.api_key 189 | ) 190 | # print(chat_response) 191 | compare_result = self.extract_choice(chat_response) 192 | return compare_result 193 | 194 | def extract_choice(self, response) -> CompareResultObject: 195 | """ 196 | response: decoded str. 197 | output: compare_result_object 198 | """ 199 | # only string in the response: 200 | choice = MistralModel.first_appears_first(response, "A", "B") 201 | if choice not in ["A", "B"]: 202 | print("Failed to extract choice") 203 | return CompareResultObject(raw_prob_A=0.5, raw_prob_B=0.5, uncertainty=1) 204 | return CompareResultObject( 205 | raw_prob_A=float(choice == "A"), 206 | raw_prob_B=float(choice == "B"), 207 | ) 208 | 209 | @staticmethod 210 | def first_appears_first(string, char1, char2): 211 | index1 = string.find(char1) 212 | index2 = string.find(char2) 213 | if index1 == -1 and index2 == -1: 214 | return None # Neither character appears in the string 215 | elif index1 == -1: 216 | return char2 # Only char2 appears in the string 217 | elif index2 == -1: 218 | return char1 # Only char1 appears in the string 219 | elif index1 < index2: 220 | return char1 221 | else: 222 | return char2 223 | 224 | @staticmethod 225 | def call_mistral_chat_completion(msg, api_params, api_key=None): 226 | if "model" not in api_params: 227 | api_params["model"] = "mistral-large-latest" 228 | if "temperature" not in api_params: 229 | api_params["temperature"] = 0.0 230 | if "max_tokens" not in api_params: 231 | api_params["max_tokens"] = 32 232 | if "attempt_num" not in api_params: 233 | api_params["attempt_num"] = 10 234 | if not api_key: 235 | api_key = os.environ.get("MISTRAL_API_KEY") 236 | 237 | client = MistralClient(api_key=api_key) 238 | 239 | attempt = 0 240 | while attempt < api_params["attempt_num"]: 241 | try: 242 | chat_response = client.chat( 243 | model=api_params["model"], 244 | messages=msg, 245 | temperature=api_params["temperature"], 246 | max_tokens=api_params["max_tokens"], 247 | ) 248 | # For now the api can only return response string. 
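                # Descriptive comment (added): because only the decoded text is available from the API,
                # mistral_compare() relies on extract_choice(), which maps the first 'A'/'B' found in the
                # reply to a hard 0/1 preference rather than token-level probabilities.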
249 | return chat_response.choices[0].message.content 250 | except Exception as e: 251 | print(e) 252 | attempt += 1 253 | api_params["temperature"] += 0.1 254 | time.sleep(0.2) 255 | # Fail cases 256 | print("Fail case: Default randomly selection.") 257 | return random.choice(["A", "B"]) 258 | 259 | @staticmethod 260 | def get_mistral_chat_message( 261 | prompt, aspect, with_input=False, eval_method="pairwise comparison" 262 | ): 263 | if eval_method == "score": 264 | messages = [ 265 | {"role": "user", "content": prompt}, 266 | ] 267 | return messages 268 | else: 269 | messages = [ 270 | {"role": "user", "content": prompt}, 271 | ] 272 | return messages 273 | -------------------------------------------------------------------------------- /pairwise_comparison.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | from utils import load_TopicalChat, load_summEval, load_gsm8k, load_newsroom 4 | from tqdm import tqdm 5 | from prompts import ( 6 | get_pairwise_prompt_template, 7 | get_cot_compare_prompt_template, 8 | get_cot_eval_prompt_template, 9 | ) 10 | from jinja2 import Environment 11 | from models.llama2 import Llama2ModelLocal 12 | from models.llama3 import Llama3ModelLocal 13 | from models.mistral import MistralModelLocal 14 | from models.openai_api import OpenAIChatModel 15 | import os 16 | 17 | 18 | def pairwise_non_diagonal_to_list(size): 19 | """ 20 | Convert a Square pairwise matrix to a list of non-diagonal elements 21 | """ 22 | rows, cols = size, size 23 | result = [] 24 | for i in range(rows): 25 | for j in range(cols): 26 | if i != j: # Check if the element is not on the diagonal 27 | result.append((i, j)) 28 | return result 29 | 30 | 31 | def list_to_pairwise_non_diagonal(non_diagonal_list, size): 32 | """ 33 | Convert a list of non-diagonal elements back to a pairwise matrix 34 | Square matrix 35 | """ 36 | rows, cols = size, size 37 | matrix = [[0] * cols for _ in range(rows)] 38 | index = 0 39 | for i in range(rows): 40 | for j in range(cols): 41 | if i != j: # Check if the element is not on the diagonal 42 | matrix[i][j] = non_diagonal_list[index] 43 | index += 1 44 | return np.array(matrix) 45 | 46 | 47 | def compute_pairwise_preference_matrix( 48 | model, 49 | input, 50 | output, 51 | prompt_templates, 52 | do_cot=False, 53 | worker_num=1, 54 | batch_size=1, 55 | instruction=None, 56 | ): 57 | """ 58 | worker_num is for closed source models async parallel processing 59 | batch_size is for open source models parallel processing 60 | """ 61 | response_size = len(output) 62 | full_pairwise_list = pairwise_non_diagonal_to_list(size=response_size) 63 | 64 | prompts = [] 65 | count = 0 66 | for pair in full_pairwise_list: 67 | if instruction: 68 | prompt = prompt_templates[0].render( 69 | input=input, 70 | output_1=output[pair[0]], 71 | output_2=output[pair[1]], 72 | instruction=instruction, 73 | ) 74 | else: 75 | prompt = prompt_templates[0].render( 76 | input=input, 77 | output_1=output[pair[0]], 78 | output_2=output[pair[1]], 79 | ) 80 | if count == 0: 81 | # print('Prompt:', prompt) 82 | count += 1 83 | prompts.append(prompt) 84 | 85 | # If CoT, generate first 86 | if do_cot: 87 | model.max_tokens = 256 88 | if worker_num > 1: # Parallel processing for closed source models 89 | prompts = model.generate(prompts) 90 | else: 91 | analysis = [] 92 | for i in tqdm(range(0, len(prompts), batch_size)): 93 | batch_prompts = prompts[i : i + batch_size] 94 | analysis.extend(model.generate(batch_prompts)) 
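        # Descriptive comment (added): second CoT stage — each free-form analysis is wrapped into the
        # evaluation template (prompt_templates[1]), which asks for a final 'A'/'B' verdict that is then
        # scored by model.compare() below.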
95 | eval_promtps = [ 96 | prompt_templates[1].render(cot_response=decoded_sequence) 97 | for decoded_sequence in analysis 98 | ] 99 | prompts = eval_promtps 100 | 101 | model.max_tokens = 32 102 | if worker_num > 1: # Parallel processing for closed source models 103 | compare_result_list = model.compare(prompts, max_workers=worker_num) 104 | else: 105 | compare_result_list = [] 106 | for i in tqdm(range(0, len(prompts), batch_size)): 107 | batch_prompts = prompts[i : i + batch_size] 108 | compare_results = model.compare(batch_prompts) 109 | compare_result_list.extend(compare_results) 110 | 111 | pairwise_preference_list = ( 112 | [] 113 | ) # [compare_result.prob_A for compare_result in compare_result_list] 114 | norm_prob_list = [] 115 | logit_list = [] 116 | prob_list = [] 117 | for compare_result in compare_result_list: 118 | norm_prob_list.append( 119 | [compare_result.prob_A, compare_result.prob_B, compare_result.prob_C] 120 | ) 121 | logit_list.append( 122 | [compare_result.logit_A, compare_result.logit_B, compare_result.logit_C] 123 | ) 124 | prob_list.append( 125 | [ 126 | compare_result.raw_prob_A, 127 | compare_result.raw_prob_B, 128 | compare_result.raw_prob_C, 129 | ] 130 | ) 131 | if ( 132 | compare_result.prob_C > compare_result.prob_A 133 | and compare_result.prob_C > compare_result.prob_B 134 | ): 135 | pairwise_preference_list.append(0.5) 136 | else: 137 | pairwise_preference_list.append(compare_result.prob_A) 138 | 139 | # Count if prob is balanced 140 | A_cnt, B_cnt, C_cnt = 0, 0, 0 141 | for prob in pairwise_preference_list: 142 | if prob > 0.5: 143 | A_cnt += 1 144 | elif prob < 0.5: 145 | B_cnt += 1 146 | else: 147 | C_cnt += 1 148 | print("A_cnt:", A_cnt, "B_cnt:", B_cnt, "C_cnt:", C_cnt) 149 | 150 | pairwise_preference_matrix = list_to_pairwise_non_diagonal( 151 | pairwise_preference_list, response_size 152 | ) 153 | log_results = { 154 | "norm_prob_list": norm_prob_list, 155 | "logit_list": logit_list, 156 | "prob_list": prob_list, 157 | } 158 | return pairwise_preference_matrix, log_results 159 | 160 | 161 | def pairwise_compare(args, instruction_list, round_id): 162 | # Load datasets 163 | aspect_name = args.aspect_name 164 | if args.dataset == "TopicalChat": 165 | TC_path = "data/topicalchat_usr.json" 166 | input_doc, output_doc, scores = load_TopicalChat( 167 | TC_path, truncate_num_for_eval=args.eval_data_num 168 | ) 169 | scores = scores[args.aspect_name] 170 | 171 | elif args.dataset == "SummEval": 172 | SummEval_path = "data/model_annotations.aligned.paired.jsonl" 173 | input_doc, output_doc, scores = load_summEval( 174 | SummEval_path, flat_output=False, truncate_num_for_eval=args.eval_data_num 175 | ) 176 | scores = scores[args.aspect_name] 177 | 178 | elif args.dataset == "newsroom": 179 | newsroom_path = "data/newsroom/newsroom.json" 180 | input_doc, output_doc, scores = load_newsroom( 181 | newsroom_path, flat_output=False, truncate_num_for_eval=args.eval_data_num 182 | ) 183 | scores = scores[args.aspect_name] 184 | 185 | elif args.dataset == "GSM8k": 186 | GSM8k_path = "gsm8k_augment/{}_test_responses.jsonl".format( 187 | args.engine.split("/")[-1] 188 | ) 189 | input, output_doc = load_gsm8k(GSM8k_path, cot=False) 190 | input_doc = [[i] for i in input] 191 | response_size = len(output_doc[0]) 192 | 193 | # Load model 194 | if "mistral" in args.engine: 195 | model = MistralModelLocal({"model": args.engine}) 196 | elif "Llama-3" in args.engine: 197 | model = Llama3ModelLocal({"model": args.engine, "cot": args.do_cot}) 198 | elif "Llama-2" in 
args.engine: 199 | model = Llama2ModelLocal({"model": args.engine}) 200 | elif "gpt" in args.engine: 201 | model = OpenAIChatModel({"model": args.engine}) 202 | saving_dir = args.saving_dir 203 | # Load prompt template 204 | for i_id, instruction in enumerate(instruction_list): 205 | if i_id > -1 and i_id < 21: 206 | if args.do_cot: 207 | prompt_template = get_cot_compare_prompt_template(dataset=args.dataset) 208 | cot_eval_template = get_cot_eval_prompt_template() 209 | print("Prompt template:", cot_eval_template) 210 | environment = Environment() 211 | prompt_template = environment.from_string(prompt_template) 212 | environment = Environment() 213 | cot_eval_template = environment.from_string(cot_eval_template) 214 | prompt_templates = [prompt_template, cot_eval_template] 215 | else: 216 | prompt_template = get_pairwise_prompt_template( 217 | dataset=args.dataset, use_instruction=True 218 | ) 219 | print("Prompt template:", prompt_template) 220 | environment = Environment() 221 | prompt_template = environment.from_string(prompt_template) 222 | prompt_templates = [prompt_template] 223 | pairwise_preference_matrix_log = [] 224 | compare_result_log = {} 225 | for i in range(len(input_doc)): 226 | print("Data point:", i + 1, "out of", len(input_doc), "data points.") 227 | input = input_doc[i][0] 228 | output = output_doc[i] 229 | pairwise_preference_matrix, compare_results = ( 230 | compute_pairwise_preference_matrix( 231 | model, 232 | input, 233 | output, 234 | prompt_templates, 235 | do_cot=args.do_cot, 236 | worker_num=args.worker_num, 237 | batch_size=args.batch_size, 238 | instruction=instruction, 239 | ) 240 | ) 241 | pairwise_preference_matrix_log.append( 242 | pairwise_preference_matrix.tolist() 243 | ) 244 | compare_result_log[i] = compare_results 245 | print("Model: ", args.engine) 246 | 247 | if not os.path.exists(saving_dir): 248 | os.makedirs(saving_dir) 249 | saving_path = f"{saving_dir}{args.engine.split('/')[-1]}_preference_matrix_log_cot_{args.do_cot}_{aspect_name}_{i_id}.json" 250 | with open(saving_path, "w") as f: 251 | json.dump(pairwise_preference_matrix_log, f, indent=4) 252 | f.close() 253 | saving_path = f"{saving_dir}{args.engine.split('/')[-1]}_compare_result_log_cot_{args.do_cot}_{aspect_name}_{i_id}.json" 254 | with open(saving_path, "w") as f: 255 | json.dump(compare_result_log, f, indent=4) 256 | f.close() 257 | saving_path = ( 258 | f"{saving_dir}{args.engine.split('/')[-1]}_instruction_set_{aspect_name}.json" 259 | ) 260 | with open(saving_path, "w") as f: 261 | json.dump(instruction_list, f, indent=4) 262 | 263 | 264 | if __name__ == "__main__": 265 | import argparse 266 | 267 | parser = argparse.ArgumentParser() 268 | parser.add_argument("--dataset", type=str, default="SummEval") 269 | parser.add_argument( 270 | "--engine", type=str, default="mistralai/Mistral-7B-Instruct-v0.1" 271 | ) 272 | parser.add_argument("--worker_num", type=int, default=1) 273 | parser.add_argument("--batch_size", type=int, default=1) 274 | parser.add_argument("--do_cot", action="store_true", default=False) 275 | parser.add_argument("--aspect_name", type=str, default="coherence") 276 | parser.add_argument("--eval_data_num", type=int, default=10) 277 | 278 | args = parser.parse_args() 279 | 280 | if args.engine == "full": 281 | engine_list = [ 282 | "mistralai/Mistral-7B-Instruct-v0.1", 283 | "meta-llama/Llama-2-7b-chat-hf", 284 | "meta-llama/Llama-2-13b-chat-hf", 285 | "meta-llama/Meta-Llama-3-8B-Instruct", 286 | "gpt-3.5-turbo", 287 | ] 288 | else: 289 | engine_list = 
[args.engine] 290 | 291 | results_to_report = [] 292 | for engine in engine_list: 293 | args.engine = engine 294 | print("Engine: ", engine) 295 | if "gpt" in engine: 296 | args.worker_num = 8 297 | else: 298 | args.worker_num = 1 299 | pairwise_compare(args) 300 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import json 3 | import numpy as np 4 | import math 5 | import scipy 6 | 7 | 8 | class CompareResultObject: 9 | def __init__( 10 | self, 11 | raw_prob_A=0, 12 | raw_prob_B=0, 13 | raw_prob_C=0, 14 | uncertainty=1, 15 | logit_A=0, 16 | logit_B=0, 17 | logit_C=0, 18 | ): 19 | self.raw_prob_A = raw_prob_A 20 | self.raw_prob_B = raw_prob_B 21 | self.raw_prob_C = raw_prob_C 22 | prob_sum = raw_prob_A + raw_prob_B + raw_prob_C 23 | self.prob_A = raw_prob_A / prob_sum 24 | self.prob_B = raw_prob_B / prob_sum 25 | self.prob_C = raw_prob_C / prob_sum 26 | self.uncertainty = uncertainty 27 | self.logit_A = logit_A 28 | self.logit_B = logit_B 29 | self.logit_C = logit_C 30 | 31 | def calibraet_shift(self, shifts): 32 | shifted_prob_A = self.raw_prob_A / np.exp(shifts["A"]) 33 | shifted_prob_B = self.raw_prob_B / np.exp(shifts["B"]) 34 | shifted_prob_C = self.raw_prob_C / np.exp(shifts["C"]) 35 | prob_sum = shifted_prob_A + shifted_prob_B + shifted_prob_C 36 | self.prob_A = shifted_prob_A / prob_sum 37 | self.prob_B = shifted_prob_B / prob_sum 38 | self.prob_C = shifted_prob_C / prob_sum 39 | 40 | def __str__(self) -> str: 41 | string = f"prob_A: {round(self.prob_A,2)}, prob_B: {round(self.prob_B,2)}, prob_C: {round(self.prob_C,2)}, uncertainty: {round(self.uncertainty,3)} \n" 42 | string += f"raw_prob_A: {round(self.raw_prob_A,2)}, raw_prob_B: {round(self.raw_prob_B,2)}, raw_prob_C: {round(self.raw_prob_C,2)}" 43 | return string 44 | 45 | def __getitem__(self, key): 46 | return getattr(self, key, None) 47 | 48 | def __setitem__(self, key, value): 49 | setattr(self, key, value) 50 | 51 | def to_json(self): 52 | return { 53 | "prob_A": float(self.prob_A), 54 | "prob_B": float(self.prob_B), 55 | "prob_C": float(self.prob_C), 56 | "uncertainty": float(self.uncertainty), 57 | "raw_prob_A": float(self.raw_prob_A), 58 | "raw_prob_B": float(self.raw_prob_B), 59 | "raw_prob_C": float(self.raw_prob_C), 60 | } 61 | 62 | @staticmethod 63 | def from_json(json_obj): 64 | instance = CompareResultObject( 65 | json_obj["prob_A"], 66 | json_obj["prob_B"], 67 | json_obj["prob_C"], 68 | json_obj["uncertainty"], 69 | ) 70 | instance.prob_A = json_obj["prob_A"] 71 | instance.prob_B = json_obj["prob_B"] 72 | instance.prob_C = json_obj["prob_C"] 73 | return instance 74 | 75 | 76 | def calculate_uncertainty(probablities): 77 | probablities = np.array(probablities) + 1e-9 78 | entropy = -np.sum(probablities * np.log(probablities)) 79 | return entropy 80 | 81 | 82 | ############################################################################ 83 | ###### Load datasets 84 | ############################################################################ 85 | def load_jsonl(file_path): 86 | with open(file_path) as f: 87 | lines = f.readlines() 88 | return [json.loads(line) for line in lines] 89 | 90 | 91 | def load_json(file_path): 92 | with open(file_path) as f: 93 | return json.load(f) 94 | 95 | 96 | def load_gsm8k(data_path, cot=False): 97 | data = load_jsonl(data_path) 98 | if cot: 99 | questions = [dp["question"] + "\nLet's think step by step." 
for dp in data] 100 | else: 101 | questions = [dp["question"] for dp in data] 102 | # responses_doc = [dp['responses'] for dp in data] 103 | responses_doc = [] 104 | for dp in data: 105 | responses = [] 106 | for r in dp["responses"]: 107 | responses.append(r.replace("\n\n", "\n")) 108 | responses_doc.append(responses) 109 | 110 | return questions, responses_doc 111 | 112 | 113 | def load_summEval(path, flat_output=True, truncate_num_for_eval=None): 114 | data_summ_eval = load_jsonl(path) 115 | 116 | if truncate_num_for_eval: 117 | data_summ_eval = data_summ_eval[: 16 * truncate_num_for_eval] 118 | 119 | input = [] 120 | for i in range(len(data_summ_eval)): 121 | input.append(data_summ_eval[i]["text"]) 122 | 123 | output = [] 124 | for i in range(len(data_summ_eval)): 125 | output.append(data_summ_eval[i]["decoded"]) 126 | 127 | # coherence 128 | coherence_scores = [] 129 | for i in range(len(data_summ_eval)): 130 | coherence = [ 131 | anootation["coherence"] 132 | for anootation in data_summ_eval[i]["expert_annotations"] 133 | ] 134 | coherence_scores.append(round(sum(coherence) / len(coherence), 0)) 135 | # turker_annotations 136 | # fluency 137 | fluency_scores = [] 138 | for i in range(len(data_summ_eval)): 139 | fluency = [ 140 | anootation["fluency"] 141 | for anootation in data_summ_eval[i]["expert_annotations"] 142 | ] 143 | fluency_scores.append(round(sum(fluency) / len(fluency), 1)) 144 | 145 | # relevance 146 | relevance_scores = [] 147 | for i in range(len(data_summ_eval)): 148 | relevance = [ 149 | anootation["relevance"] 150 | for anootation in data_summ_eval[i]["expert_annotations"] 151 | ] 152 | relevance_scores.append(round(sum(relevance) / len(relevance), 1)) 153 | 154 | # consistency 155 | consistency_scores = [] 156 | for i in range(len(data_summ_eval)): 157 | consistency = [ 158 | anootation["consistency"] 159 | for anootation in data_summ_eval[i]["expert_annotations"] 160 | ] 161 | consistency_scores.append(round(sum(consistency) / len(consistency), 1)) 162 | 163 | if flat_output: 164 | return ( 165 | input, 166 | output, 167 | { 168 | "coherence": coherence_scores, 169 | "fluency": fluency_scores, 170 | "relevance": relevance_scores, 171 | "consistency": consistency_scores, 172 | }, 173 | ) 174 | else: 175 | candidate_num = 16 176 | ( 177 | input_doc, 178 | output_doc, 179 | coherence_doc, 180 | fluency_doc, 181 | relevance_doc, 182 | consistency_doc, 183 | ) = ([], [], [], [], [], []) 184 | for i in range(0, len(input), candidate_num): 185 | input_doc.append(input[i : i + candidate_num]) 186 | output_doc.append(output[i : i + candidate_num]) 187 | coherence_doc.append(coherence_scores[i : i + candidate_num]) 188 | fluency_doc.append(fluency_scores[i : i + candidate_num]) 189 | relevance_doc.append(relevance_scores[i : i + candidate_num]) 190 | consistency_doc.append(consistency_scores[i : i + candidate_num]) 191 | 192 | return ( 193 | input_doc, 194 | output_doc, 195 | { 196 | "coherence": coherence_doc, 197 | "fluency": fluency_doc, 198 | "relevance": relevance_doc, 199 | "consistency": consistency_doc, 200 | }, 201 | ) 202 | 203 | 204 | def load_newsroom(path, flat_output=True, truncate_num_for_eval=None): 205 | with open(path, "r") as file: 206 | newsroom = json.load(file) 207 | file.close() 208 | 209 | data = newsroom 210 | if truncate_num_for_eval: 211 | data = data[: 7 * truncate_num_for_eval] 212 | input = [dp["source"].replace("
\n\n
", " ") for dp in data] 213 | output = [dp["system_output"] for dp in data] 214 | coherence = [round(dp["scores"]["coherence"], 1) for dp in data] 215 | fluency = [round(dp["scores"]["fluency"], 1) for dp in data] 216 | informativeness = [round(dp["scores"]["informativeness"], 1) for dp in data] 217 | relevance = [round(dp["scores"]["relevance"], 1) for dp in data] 218 | if flat_output: 219 | return ( 220 | input, 221 | output, 222 | { 223 | "coherence": coherence, 224 | "fluency": fluency, 225 | "informativeness": informativeness, 226 | "relevance": relevance, 227 | }, 228 | ) 229 | else: 230 | candidate_num = 7 231 | ( 232 | input_doc, 233 | output_doc, 234 | coherence_doc, 235 | fluency_doc, 236 | informativeness_doc, 237 | relevance_doc, 238 | ) = ([], [], [], [], [], []) 239 | for i in range(0, len(input), candidate_num): 240 | input_doc.append(input[i : i + candidate_num]) 241 | output_doc.append(output[i : i + candidate_num]) 242 | coherence_doc.append(coherence[i : i + candidate_num]) 243 | fluency_doc.append(fluency[i : i + candidate_num]) 244 | informativeness_doc.append(informativeness[i : i + candidate_num]) 245 | relevance_doc.append(relevance[i : i + candidate_num]) 246 | 247 | return ( 248 | input_doc, 249 | output_doc, 250 | { 251 | "coherence": coherence_doc, 252 | "fluency": fluency_doc, 253 | "informativeness": informativeness_doc, 254 | "relevance": relevance_doc, 255 | }, 256 | ) 257 | 258 | 259 | def load_TopicalChat(path, truncate_num_for_eval=None): 260 | data = load_json(path) 261 | if truncate_num_for_eval: 262 | data = data[: 5 * truncate_num_for_eval] 263 | input_doc = [] 264 | output_doc = [] 265 | overall_doc = [] 266 | natural_doc = [] 267 | engaging_doc = [] 268 | for i in range(len(data)): 269 | input = [] 270 | facts = [] 271 | output = [] 272 | natural = [] 273 | overall = [] 274 | engaging = [] 275 | for r in data[i]["responses"]: 276 | # Process input string to conversational format 277 | input_string = data[i]["context"] 278 | input_list = input_string.split("\n") 279 | if input_list[-1] == "": 280 | input_list = input_list[:-1] 281 | 282 | input_list_with_user = [] 283 | for idx, line in enumerate(input_list): 284 | if idx % 2 == 0: 285 | input_list_with_user.append("Person 1: " + line) 286 | next_round_person = "Person 2: " 287 | if idx % 2 == 1: 288 | input_list_with_user.append("Person 2: " + line) 289 | next_round_person = "Person 1: " 290 | 291 | input.append("\n".join(input_list_with_user)) 292 | facts.append(data[i]["fact"]) 293 | output.append(next_round_person + r["response"].strip()) 294 | natural.append(round(sum(r["Natural"]) / len(r["Natural"]), 0)) 295 | overall.append(round(sum(r["Overall"]) / len(r["Overall"]), 0)) 296 | engaging.append(round(sum(r["Engaging"]) / len(r["Engaging"]), 0)) 297 | input_doc.append(input) 298 | output_doc.append(output) 299 | overall_doc.append(overall) 300 | natural_doc.append(natural) 301 | engaging_doc.append(engaging) 302 | 303 | return ( 304 | input_doc, 305 | output_doc, 306 | {"overall": overall_doc, "natural": natural_doc, "engaging": engaging_doc}, 307 | ) 308 | 309 | 310 | ############################################################################ 311 | ###### Helper Functions 312 | ############################################################################ 313 | 314 | 315 | def shuffle_lists(*args): 316 | """Shuffle multiple lists together and return the shuffled lists.""" 317 | # Check if all lists are of the same length 318 | if len(set(map(len, args))) != 1: 319 | raise ValueError("All lists 
must be of the same length") 320 | 321 | # Combine the lists element-wise 322 | combined_lists = list(zip(*args)) 323 | random.shuffle(combined_lists) 324 | 325 | # Unzip the combined list into separate lists 326 | shuffled_lists = zip(*combined_lists) 327 | return [list(lst) for lst in shuffled_lists] 328 | 329 | 330 | def calculate_correlation(reference_score, predicted_score, print_result=True): 331 | spearman_corr, _ = scipy.stats.spearmanr(reference_score, predicted_score) 332 | 333 | if math.isnan(spearman_corr): 334 | spearman_corr = ( 335 | 1 336 | if all(element == reference_score[0] for element in reference_score) 337 | else 0 338 | ) 339 | kendall_tau, _ = scipy.stats.kendalltau(reference_score, predicted_score) 340 | 341 | if print_result: 342 | print("Spearmans correlation: %.3f" % spearman_corr) 343 | print("Kendall tau: %.3f" % kendall_tau) 344 | # mae = mean_absolute_error(reference_score, predicted_score) 345 | # print('MAE: %.3f' % mae) 346 | return spearman_corr, kendall_tau # , mae 347 | -------------------------------------------------------------------------------- /pairs/prompts.py: -------------------------------------------------------------------------------- 1 | from textwrap import dedent 2 | 3 | 4 | def get_prompt_template( 5 | prompt_name, 6 | aspect="coherence", 7 | dataset="SummEval", 8 | model_name=None, 9 | with_input=False, 10 | ): 11 | adj_lookup = { 12 | "coherence": "coherent", 13 | "fluency": "fluent", 14 | "relevance": "relevant", 15 | "informativeness": "informative", 16 | "overall": "overall high-quality", 17 | "naturalness": "natural", 18 | "sensible": "sensible", 19 | "surprise": "surprising", 20 | "complexity": "complex", 21 | "consistency": "consistent", 22 | } 23 | 24 | if dataset in ["SummEval", "newsroom"]: 25 | task = "summarization" 26 | description = "summarization" 27 | text_name = ("summary", "Summary") 28 | text_names = "summaries" 29 | input_name = ("source text", "Source text") 30 | elif dataset in ["sfhot", "sfres"]: 31 | task = "d2t" 32 | description = "data-to-text generation" 33 | text_name = ("text", "Text") 34 | text_names = "texts" 35 | input_name = ("data", "Data") 36 | elif dataset in ["hanna"]: 37 | task = "d2t" 38 | description = "creative writing" 39 | text_name = ("story", "Story") 40 | text_names = "stories" 41 | 42 | ###################################################### Pairwise comparison ###################################################### 43 | if prompt_name == "pairwise comparison": 44 | if with_input: 45 | prompt = dedent( 46 | f"""\ 47 | {input_name[1]}: {{{{ input }}}} 48 | 49 | {{{{instruction}}}} 50 | {text_name[1]} candidate A: {{{{ output_1 }}}} 51 | {text_name[1]} candidate B: {{{{ output_2 }}}} 52 | 53 | Question: Which {text_name[0]} candidate is more {adj_lookup[aspect]}? \ 54 | If the {text_name[0]} A is more {adj_lookup[aspect]}, please return 'A'. \ 55 | If the {text_name[0]} B is more {adj_lookup[aspect]}, please return 'B'. \ 56 | Plese only return the choice. 57 | Answer: """ 58 | ) 59 | else: 60 | prompt = dedent( 61 | f"""\ 62 | {{{{instruction}}}} 63 | Which {text_name[0]} is more {adj_lookup[aspect]}? 64 | 65 | {text_name[1]} A: {{{{ output_1 }}}} 66 | 67 | {text_name[1]} B: {{{{ output_2 }}}} 68 | 69 | Question: If the {text_name[1]} A is more {adj_lookup[aspect]}, please return "A". \ 70 | If the {text_name[1]} B is more {adj_lookup[aspect]}, please return "B". You must only return the choice. 
71 | Answer: """ 72 | ) 73 | 74 | ###################################################### Pairwise comparison 3-way ###################################################### 75 | elif prompt_name == "pairwise comparison 3-way": 76 | if with_input: 77 | prompt = dedent( 78 | f"""\ 79 | {{{{instruction}}}} 80 | 81 | {input_name[1]}: {{{{ input_1 }}}} 82 | 83 | Evaluate and compare the following {text_names}: 84 | 85 | {text_name[1]} A: {{{{ output_1 }}}} 86 | 87 | {text_name[1]} B: {{{{ output_2 }}}} 88 | 89 | Question: Which {text_name[0]} is more {adj_lookup[aspect]}? \ 90 | If the {text_name[0]} A is more {adj_lookup[aspect]}, please return 'A'. If the {text_name[0]} B is more {adj_lookup[aspect]}, please return 'B'. \ 91 | If both {text_names} are equally {adj_lookup[aspect]}, please return 'C'. Plese only return the choice. 92 | Answer: """ 93 | ) 94 | else: 95 | prompt = dedent( 96 | f"""\ 97 | {{{{instruction}}}} 98 | Which {text_name[0]} is more {adj_lookup[aspect]}? 99 | 100 | {text_name[1]} A: {{{{ output_1 }}}} 101 | 102 | {text_name[1]} B: {{{{ output_2 }}}} 103 | 104 | Question: If the {text_name[1]} A is more {adj_lookup[aspect]}, please return "A". \ 105 | If the {text_name[1]} B is more {adj_lookup[aspect]}, please return "B". You must only return the choice. 106 | Answer: """ 107 | ) 108 | 109 | ###################################################### Baseline score prompts ###################################################### 110 | elif prompt_name == "score": 111 | if with_input: 112 | prompt = dedent( 113 | f"""\ 114 | {{{{instruction}}}} 115 | 116 | {input_name[1]}: {{{{ input }}}} 117 | 118 | {text_name[1]}: {{{{ output }}}} 119 | 120 | Please rate on a scale from 1 to 5, where 1 represents very low {adj_lookup[aspect]}, \ 121 | and 5 indicates excellent {adj_lookup[aspect]}. You must only return an int score. 122 | Score: """ 123 | ) 124 | else: 125 | prompt = dedent( 126 | f"""\ 127 | {{{{instruction}}}} 128 | 129 | Evaluate the following {text_name[1]}. 130 | {text_name[1]}: {{{{ output }}}} 131 | 132 | Question: Please rate on a scale from 1 to 5, where 1 represents very low {adj_lookup[aspect]}, \ 133 | and 5 indicates excellent {adj_lookup[aspect]}. You must only return the int score. 134 | Score: """ 135 | ) 136 | 137 | return prompt 138 | 139 | 140 | def get_aspect_instruction( 141 | aspect, eval_method="pairwise comparison", dataset="SummEval" 142 | ): 143 | 144 | if dataset in ["SummEval", "newsroom"]: 145 | task = "summarization" 146 | description = "summarization" 147 | text_name = "summary" 148 | text_names = "summaries" 149 | 150 | elif dataset in ["sfhot", "sfres"]: 151 | task = "d2t" 152 | description = "data-to-text generation" 153 | text_name = "text" 154 | text_names = "texts" 155 | 156 | elif dataset in ["hanna"]: 157 | task = "d2t" 158 | description = "creative story writing" 159 | text_name = "story" 160 | text_names = "stories" 161 | 162 | else: 163 | print("Dataset not support") 164 | assert False 165 | 166 | instructions = { 167 | "coherence": { 168 | "score": f"Please evaluate the coherence of the following {text_name}. ", 169 | "pairwise comparison": f"Compare the coherence of the two following {text_names}. " 170 | f"Consider aspects such as clarity and logical flow. " 171 | f"A {text_name} is coherent if it accurately captures the key information from the article, " 172 | "and presents them in a clear manner.", 173 | }, 174 | "fluency": { 175 | "score": f"Please evaluate the fluency of the following {text_name}. 
", 176 | "pairwise comparison": f"Evaluate and compare the fluency of the two following {text_names}. ", 177 | # f'A fluent {text_name} should use clear language that avoids redundancy and errors. ' 178 | # f'A fluent {text_name} should use appropriate transition words, connectors, and avoid abrupt. ' 179 | # f'A fluent {text_name} should use correct spelling, punctuation, and capitalization throughout the summary, ' 180 | # 'and follow the conventions of standard written English.', 181 | }, 182 | "relevance": { 183 | "score": f"Please evaluate the relevance of the following {text_name}. " 184 | f"A {text_name} is relevant if it captures the main points from the article, without leaving out any crucial details or adding any unnecessary or inaccurate ones. " 185 | f"A {text_name} is more relevant if it uses the same or similar terms and expressions as the article. " 186 | f"A {text_name} is less relevant if it omits some of the key facts from the article, or if it introduces irrelevant information that is not supported by the article.", 187 | "pairwise comparison": f"Evaluate and compare the relevance level of two {text_names}. " 188 | f"A {text_name} is relevant if it captures the main points from the article, without leaving out any crucial details or adding any unnecessary or inaccurate ones. " 189 | f"A {text_name} is more relevant if it uses the same or similar terms and expressions as the article. " 190 | f"A {text_name} is less relevant if it omits some of the key facts from the article, or if it introduces irrelevant information that is not supported by the article.", 191 | }, 192 | "informativeness": { 193 | "score": f"Please evaluate the informativeness of the following {text_name}. ", 194 | "pairwise comparison": f"Compare the performance of two {description} examples, especially focusing on informativeness. " 195 | f"Evaluate how each {text_name} converts their input text to natural language text, without omitting, adding, or distorting any facts.", 196 | }, 197 | "consistency": { 198 | "score": f"Please evaluate the consistency of the following {text_name}. " 199 | f"A {text_name} is consistent with the article if it faithfully reflects the main points, facts, and tone of the article. " 200 | f"A {text_name} is inconsistent if it introduces any errors, contradictions, or distortions of the original article.", 201 | "pairwise comparison": f"Evaluate and compare how two {text_names} consistently follow the source text. " 202 | f"A {text_name} is consistent with the article if it faithfully reflects the main points, facts, and tone of the article. " 203 | f"A {text_name} is inconsistent if it introduces any errors, contradictions, or distortions of the original article.", 204 | }, 205 | "naturalness": { 206 | "score": "Please evaluate the informativeness of the following passages. " 207 | "Please rate on a scale from 1 to 5, where 1 represents very low informativeness, " 208 | "and 5 indicates excellent informativeness. Your response should be in the format of a list of float numbers. " 209 | 'For example: "[2, 4, 3]"', 210 | "pairwise comparison": f"Evaluate the naturalness of two {text_names}. " 211 | "A sentence is natural if it is fluent, coherent, grammatical, and human-like.", 212 | }, 213 | "overall": { 214 | "score": None, 215 | "pairwise comparison": f"Please evaluate the overall quality of the {description} " 216 | f"Consider the coherence, fluency, relevance, and informativeness of the {text_names}. " 217 | f'If you think {text_name} A is better, please return "A". 
If you think {text_name} B is better, please return "B".', 218 | }, 219 | "sensible": { 220 | "score": f"Please evaluate the sensibility of the following {text_name}. " 221 | f"A {text_name} is sensible if the events are consistent and align with the context they are set in. " 222 | "A sensible story has good believability. " 223 | f"A {text_name} is not sensible if there are contradictions.", 224 | "pairwise comparison": f"Please evaluate and compare the sensibility of the following {text_names}. " 225 | f"A {text_name} is sensible if the events within each {text_name} are consistent and align with the context they are set in. " 226 | "A sensible story has good believability. " 227 | f"A {text_name} is not sensible if there are contradictions.", 228 | }, 229 | "surprise": { 230 | "score": f"Assess the given story based on its capacity to generate surprise. ", 231 | "pairwise comparison": f"Please evaluate and compare two {text_names} in terms of their ability " 232 | "to evoke surprise and unexpected plot twists. Consider the effectiveness of building suspense " 233 | "and anticipation, and the manipulation of reader expectations. ", 234 | }, 235 | "complexity": { 236 | "score": "Please evaluate the narrative complexity of the following creative story. " 237 | "Consider the complexity from the aspect of structure, character development, thematic depth, and stylistic elements employed in the story. ", 238 | "pairwise comparison": "Please evaluate and compare the narrative complexity of the following creative stories. " 239 | "Consider the complexity from the aspects of structure, character development, thematic depth, and stylistic elements employed in each story. ", 240 | }, 241 | } 242 | 243 | if (aspect in instructions) and (eval_method in instructions[aspect]): 244 | return instructions[aspect][eval_method] 245 | else: 246 | print("Aspect or evaluation method not supported.") 247 | return None 248 | 249 | 250 | if __name__ == "__main__": 251 | prompt = get_aspect_instruction("overall", eval_method="pairwise", dataset="sfhot") 252 | print(prompt) 253 | 254 | prompt_instruction = get_prompt_template( 255 | "pairwise comparison", "any", aspect="overall", dataset="sfhot" 256 | ) 257 | 258 | print(prompt_instruction) 259 | -------------------------------------------------------------------------------- /pairs/openai_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import List, Dict 4 | from openai import OpenAI 5 | import openai 6 | import json 7 | from jinja2 import Environment 8 | from textwrap import dedent 9 | from tqdm import tqdm 10 | import math 11 | from concurrent.futures import ThreadPoolExecutor 12 | import concurrent 13 | import threading 14 | import datetime 15 | from .utils import CompareResultObject, calculate_entropy 16 | import numpy as np 17 | 18 | openai_api_key = os.environ.get("OPENAI_API_KEY") 19 | 20 | 21 | class Timer(object): 22 | def __init__(self): 23 | self.__start = time.time() 24 | 25 | def start(self): 26 | self.__start = time.time() 27 | 28 | def get_time(self, restart=True, format=False): 29 | end = time.time() 30 | span = end - self.__start 31 | if restart: 32 | self.__start = end 33 | if format: 34 | return self.format(span) 35 | else: 36 | return span 37 | 38 | def format(self, seconds): 39 | return datetime.timedelta(seconds=int(seconds)) 40 | 41 | def print(self, name): 42 | print(name, self.get_time()) 43 | 44 | 45 | class OpenAIChatModel: 46 | def __init__(self, 
params={}, api_key=None): 47 | self.api_key = api_key 48 | if "engine" not in params: 49 | params["engine"] = "gpt-3.5-turbo" 50 | if "temperature" not in params: 51 | params["temperature"] = 0 52 | if "max_tokens" not in params: 53 | params["max_tokens"] = 128 54 | if "logprobs" not in params: 55 | params["logprobs"] = True 56 | if "top_logprobs" not in params: 57 | params["top_logprobs"] = 5 58 | if "attempt_num" not in params: 59 | params["attempt_num"] = 10 60 | # if 'do_sample' not in params: 61 | # params['do_sample'] = False 62 | if "chat_system_instruction" not in params: 63 | params["chat_system_instruction"] = None 64 | 65 | self.params = params 66 | if not api_key: 67 | api_key = os.getenv("OPENAI_API_KEY") 68 | self.client = OpenAI(api_key=api_key) 69 | 70 | def compare(self, prompts, max_workers=4): 71 | result_list = self.multi_threading_openai_chat_completion( 72 | prompts, self.single_call_compare, max_workers=max_workers 73 | ) 74 | # result_list = [CompareResultObject.from_json(x['result']) for x in result_list] 75 | return result_list 76 | 77 | def rate_score(self, prompts, max_workers=4): 78 | result_list = self.multi_threading_openai_chat_completion( 79 | prompts, self.single_call_rate_score, max_workers=max_workers 80 | ) 81 | return result_list, None 82 | 83 | def call_openai_chat_completion(self, prompt): 84 | if self.params["chat_system_instruction"]: 85 | msg = [ 86 | {"role": "system", "content": self.params["chat_system_instruction"]} 87 | ] 88 | else: 89 | msg = [] 90 | msg.append({"role": "user", "content": prompt}) 91 | attempt = 0 92 | while True: 93 | try: 94 | response = self.client.chat.completions.create( 95 | model=self.params["engine"], 96 | messages=msg, 97 | temperature=self.params["temperature"], 98 | max_tokens=self.params["max_tokens"], 99 | logprobs=self.params["logprobs"], 100 | top_logprobs=( 101 | self.params["top_logprobs"] if self.params["logprobs"] else None 102 | ), 103 | ) 104 | return response 105 | 106 | except Exception as e: 107 | print(e) 108 | print(response) 109 | attempt += 1 110 | if attempt >= self.params["attempt_num"]: 111 | return None 112 | wait_sec = 1 113 | time.sleep(wait_sec) 114 | 115 | def extract_prob(self, response) -> CompareResultObject: 116 | """For OpenAI models""" 117 | prob_A, prob_B, prob_C = 0, 0, 0 118 | for token_object in response.choices[0].logprobs.content: 119 | logprobs = [] 120 | for token_candidate in token_object.top_logprobs: 121 | logprobs.append(token_candidate.logprob) 122 | if prob_A == 0 and token_candidate.token.strip() == "A": 123 | prob_A = np.exp(token_candidate.logprob) 124 | elif prob_B == 0 and token_candidate.token.strip() == "B": 125 | prob_B = np.exp(token_candidate.logprob) 126 | elif prob_C == 0 and token_candidate.token.strip() == "C": 127 | prob_C = np.exp(token_candidate.logprob) 128 | 129 | if prob_A != 0 or prob_B != 0 or prob_C != 0: 130 | comparison_result = CompareResultObject( 131 | raw_prob_A=prob_A, 132 | raw_prob_B=prob_B, 133 | raw_prob_C=prob_C, 134 | uncertainty=1, 135 | ) 136 | return comparison_result 137 | 138 | print("Fail case") 139 | print(response.choices[0]) 140 | return CompareResultObject(raw_prob_A=0.5, raw_prob_B=0.5, uncertainty=1) 141 | 142 | def extract_scores(self, response): 143 | response_string = response.choices[0].message.content 144 | scores = ["1", "2", "3", "4", "5"] 145 | for idx, s in enumerate(scores): 146 | if s in response_string: 147 | return idx + 1 148 | print("Fail case, return 3") 149 | return 3 150 | 151 | def 
multi_threading_openai_chat_completion(
152 |         self, prompts, single_thread_func_handler, max_workers=4
153 |     ):
154 |         inputs = [{"prompt": prompt} for prompt in prompts]
155 |         timer = Timer()
156 |         print(f"Using model: {self.params['engine']}")
157 |         print("Processing queries")
158 |         with ThreadPoolExecutor(max_workers=max_workers) as executor:
159 |             futures = list(
160 |                 tqdm(
161 |                     executor.map(lambda x: single_thread_func_handler(x), inputs),
162 |                     total=len(prompts),
163 |                 )
164 |             )
165 |         print(
166 |             "Average time after {0} samples: {1}".format(
167 |                 len(prompts), timer.get_time(restart=False) / len(prompts)
168 |             )
169 |         )
170 |         print("Processed queries")
171 | 
172 |         result_list = [input["result"] for input in inputs]
173 |         return result_list
174 | 
175 |     def single_call_compare(self, input):
176 |         response = self.call_openai_chat_completion(input["prompt"])
177 |         compare_result = self.extract_prob(response)
178 |         input["result"] = compare_result
179 | 
180 |     def single_call_rate_score(self, input):
181 |         response = self.call_openai_chat_completion(input["prompt"])
182 |         score = self.extract_scores(response)
183 |         input["result"] = score
184 | 
185 | 
186 | def call_openai_chat_completion(prompt, api_params, api_key=None):
187 |     if "engine" not in api_params:
188 |         api_params["engine"] = "gpt-3.5-turbo"
189 |     if "temperature" not in api_params:
190 |         api_params["temperature"] = 0.2
191 |     if "max_tokens" not in api_params:
192 |         api_params["max_tokens"] = 128
193 |     if "logprobs" not in api_params:
194 |         api_params["logprobs"] = False
195 |     if "top_logprobs" not in api_params:
196 |         api_params["top_logprobs"] = 5
197 |     if "attempt_num" not in api_params:
198 |         api_params["attempt_num"] = 10
199 |     if api_key:
200 |         client = OpenAI(api_key=api_key)
201 |     else:
202 |         client = OpenAI(api_key=openai_api_key)
203 | 
204 |     msg = [
205 |         {
206 |             "role": "system",
207 |             "content": "You are a helpful assistant. 
Please follow the user's instructions.", 208 | }, 209 | {"role": "user", "content": prompt}, 210 | ] 211 | attempt = 0 212 | while True: 213 | try: 214 | if "davinci" in api_params["engine"]: 215 | # text completion model 216 | response = client.completions.create( 217 | model=api_params["engine"], 218 | prompt=prompt, 219 | temperature=api_params["temperature"], 220 | max_tokens=api_params["max_tokens"], 221 | logprobs=( 222 | api_params["top_logprobs"] if api_params["logprobs"] else None 223 | ), 224 | echo=True, 225 | ) 226 | else: 227 | # chat model 228 | response = client.chat.completions.create( 229 | model=api_params["engine"], 230 | messages=msg, 231 | temperature=api_params["temperature"], 232 | max_tokens=api_params["max_tokens"], 233 | logprobs=api_params["logprobs"], 234 | top_logprobs=( 235 | api_params["top_logprobs"] if api_params["logprobs"] else None 236 | ), 237 | ) 238 | return response 239 | 240 | except Exception as e: 241 | print(e) 242 | print(response) 243 | print(response.choices[0].message.content.strip()) 244 | attempt += 1 245 | if attempt >= api_params["attempt_num"]: 246 | return None 247 | wait_sec = 0.1 248 | time.sleep(wait_sec) 249 | 250 | 251 | def extract_prob(response, api_params) -> CompareResultObject: 252 | """For OpenAI models""" 253 | if "instruct" in api_params["engine"]: 254 | # for text completion model 255 | for idx, token in enumerate(response.choices[0].logprobs.tokens): 256 | if token in ["A", "B"]: 257 | token_prob_candidate = response.choices[0].logprobs.top_logprobs[idx] 258 | prob_A = ( 259 | np.exp(token_prob_candidate["A"]) 260 | if "A" in token_prob_candidate 261 | else 0 262 | ) 263 | prob_B = ( 264 | np.exp(token_prob_candidate["B"]) 265 | if "B" in token_prob_candidate 266 | else 0 267 | ) 268 | break 269 | else: # for chat model 270 | # if params['eval_method'] == "pairwise with tie": 271 | 272 | prob_A, prob_B, prob_C = 0, 0, 0 273 | for token_object in response.choices[0].logprobs.content: 274 | logprobs = [] 275 | for token_candidate in token_object.top_logprobs: 276 | logprobs.append(token_candidate.logprob) 277 | if prob_A == 0 and token_candidate.token.strip() == "A": 278 | prob_A = np.exp(token_candidate.logprob) 279 | elif prob_B == 0 and token_candidate.token.strip() == "B": 280 | prob_B = np.exp(token_candidate.logprob) 281 | elif prob_C == 0 and token_candidate.token.strip() == "C": 282 | prob_C = np.exp(token_candidate.logprob) 283 | if prob_A != 0 or prob_B != 0 or prob_C != 0: 284 | comparison_result = CompareResultObject( 285 | raw_prob_A=prob_A, 286 | raw_prob_B=prob_B, 287 | raw_prob_C=prob_C, 288 | uncertainty=calculate_entropy(logprobs), 289 | ) 290 | return comparison_result 291 | # prob_A, prob_B, prob_C = prob_A/(prob_A+prob_B+prob_C), prob_B/(prob_A+prob_B+prob_C), prob_C/(prob_A+prob_B+prob_C) 292 | # = prob_A/(prob_A+prob_B), prob_B/(prob_A+prob_B) 293 | # return {'prob_A':prob_A, 'prob_B':prob_B, 'prob_C':prob_C, 'uncertainty':calculate_entropy(logprobs)} 294 | print("Fail case") 295 | print(response.choices[0]) 296 | return CompareResultObject(raw_prob_A=0.5, raw_prob_B=0.5, uncertainty=1) 297 | 298 | 299 | if __name__ == "__main__": 300 | 301 | example_prompt = """\ 302 | Evaluate and compare the coherence of the two following summary candidates for the given input source text. 303 | 304 | Input source text: Paul Merson has restarted his row with Andros Townsend after the Tottenham midfielder was brought on with only seven minutes remaining in his team's 0-0 draw with Burnley on Sunday. 
'Just been watching the game, did you miss the coach? #RubberDub #7minutes,' Merson put on Twitter. Merson initially angered Townsend for writing in his Sky Sports column that 'if Andros Townsend can get in (the England team) then it opens it up to anybody.' Paul Merson had another dig at Andros Townsend after his appearance for Tottenham against Burnley Townsend was brought on in the 83rd minute for Tottenham as they drew 0-0 against Burnley Andros Townsend scores England's equaliser in their 1-1 friendly draw with Italy in Turin on Tuesday night The former Arsenal man was proven wrong when Townsend hit a stunning equaliser for England against Italy and he duly admitted his mistake. 'It's not as though I was watching hoping he wouldn't score for England, I'm genuinely pleased for him and fair play to him – it was a great goal,' Merson said. 'It's just a matter of opinion, and my opinion was that he got pulled off after half an hour at Manchester United in front of Roy Hodgson, so he shouldn't have been in the squad. 'When I'm wrong, I hold my hands up. I don't have a problem with doing that - I'll always be the first to admit when I'm wrong.' Townsend hit back at Merson on Twitter after scoring for England against Italy Sky Sports pundit Merson (centre) criticised Townsend's call-up to the England squad last week Townsend hit back at Merson after netting for England in Turin on Wednesday, saying 'Not bad for a player that should be 'nowhere near the squad' ay @PaulMerse?' Any bad feeling between the pair seemed to have passed but Merson was unable to resist having another dig at Townsend after Tottenham drew at Turf Moor. 305 | 306 | Compare the following outputs: 307 | 308 | Summary candidate A: paul merson was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley . andros townsend scored the tottenham midfielder in the 89th minute . paul merson had another dig at andros townsend after his appearance . the midfielder had been brought on to the england squad last week . click here for all the latest arsenal news news . 309 | 310 | Summary candidate B: paul merson has restarted his row with andros townsend . the tottenham midfielder was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley . andros townsend scores england 's equaliser in their 1-1 friendly draw with italy in turin . 311 | 312 | Question: Which summary candidate has better coherence? If the candidate A is better, please return 'A'. If the candidate B is better, please return 'B'. You must return the choice only. 
313 | Answer: \ 314 | """ 315 | 316 | prompts = [example_prompt] * 3 317 | model = OpenAIChatModel({"engine": "gpt-3.5-turbo"}) 318 | result = model.compare(prompts) 319 | print(result) 320 | print(result.prob_A) 321 | print(result.uncertainty) 322 | -------------------------------------------------------------------------------- /pairs/sorting.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from .openai_api import call_openai_chat_completion, extract_prob 3 | from jinja2 import Environment 4 | from .prompts import get_prompt_template, get_aspect_instruction 5 | import numpy as np 6 | from collections import Counter 7 | from .utils import ( 8 | shuffle_lists, 9 | calculate_correlation, 10 | load_newsroom, 11 | load_summEval, 12 | calculate_uncertainty, 13 | load_hanna, 14 | ) 15 | from .utils import ( 16 | CompareResultObject, 17 | insert_index_to_anchors, 18 | get_calibration_shift, 19 | calculate_entropy, 20 | ) 21 | from .mistral import MistralModel, MistralModelLocal 22 | import json 23 | import copy 24 | import random 25 | 26 | 27 | def moving_average(sum, val, idx): 28 | scale = min(idx, 5) 29 | return sum * scale / (scale + 1) + val / (scale + 1) 30 | return sum * idx / min(idx + 1, 3) + val / min(idx + 1, 3) 31 | 32 | 33 | class BeamItem: 34 | def __init__(self, index_pathway=[], cum_prob=1, pointer_A=-1, pointer_B=-1): 35 | self.index_pathway = index_pathway 36 | self.cum_prob = cum_prob 37 | self.pointer_A = pointer_A 38 | self.pointer_B = pointer_B 39 | 40 | def __str__(self): 41 | return f"index_pathway: {self.index_pathway}, cum_prob: {self.cum_prob}" 42 | 43 | 44 | def is_better_than_prob(id1, id2, preference_matrix, permutate=False): 45 | if permutate: 46 | return (preference_matrix[id1][id2] + 1 - preference_matrix[id2][id1]) / 2 47 | else: 48 | return preference_matrix[id1][id2] 49 | 50 | 51 | def merge_sort_indices(preference_matrix, params, permutate=False): 52 | # if 'model' not in params: 53 | # if params['engine'] in ["mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2"]: 54 | # from mistral import MistralModelLocal 55 | # params['model'] = MistralModelLocal(params={'model': params['engine']}) 56 | 57 | # elif 'mistral' in params['engine'] or 'mixtral' in params['engine']: 58 | # from mistral import MistralModel 59 | # params['model'] = MistralModel(params={'model': params['engine']}) 60 | 61 | # elif 'llama' in params['engine']: 62 | # from llama2 import Llama2ModelLocal 63 | # params['model'] = Llama2ModelLocal(params={'model': params['engine']}) 64 | 65 | indices = list(range(len(preference_matrix))) 66 | random.shuffle(indices) 67 | indices = merge_sort( 68 | indices, 0, len(indices), preference_matrix, params, permutate=permutate 69 | ) 70 | return indices 71 | 72 | 73 | def merge_sort(indices, left, right, preference_matrix, params, permutate=False): 74 | if right - left > 1: 75 | mid = (left + right) // 2 76 | merge_sort(indices, left, mid, preference_matrix, params, permutate=permutate) 77 | merge_sort(indices, mid, right, preference_matrix, params, permutate=permutate) 78 | if params["confidence_beam"]: 79 | merge_with_confidence_beam( 80 | indices, 81 | left, 82 | mid, 83 | right, 84 | preference_matrix, 85 | params, 86 | permutate=permutate, 87 | ) 88 | else: 89 | merge( 90 | indices, 91 | left, 92 | mid, 93 | right, 94 | preference_matrix, 95 | params, 96 | permutate=permutate, 97 | ) 98 | return indices 99 | 100 | 101 | def merge(indices, left, mid, right, 
preference_matrix, params, permutate=False): 102 | left_copy = indices[left:mid] 103 | right_copy = indices[mid:right] 104 | 105 | i = 0 106 | j = 0 107 | k = left 108 | 109 | while i < len(left_copy) and j < len(right_copy): 110 | if "progress_bar" in params: 111 | params["progress_bar"].update(1) 112 | compare_result = is_better_than_prob( 113 | left_copy[i], 114 | right_copy[j], 115 | preference_matrix=preference_matrix, 116 | permutate=permutate, 117 | ) 118 | # params['compare_log'][(left_copy[i], right_copy[j])] = compare_result.to_json() #compare_result['prob_A'] 119 | params["api_call"] += 1 120 | # print(compare_result) 121 | if compare_result > 0.5: # ['prob_A']>compare_result['prob_B']: 122 | indices[k] = right_copy[j] 123 | j += 1 124 | else: 125 | indices[k] = left_copy[i] 126 | i += 1 127 | k += 1 128 | 129 | while i < len(left_copy): 130 | indices[k] = left_copy[i] 131 | i += 1 132 | k += 1 133 | 134 | while j < len(right_copy): 135 | indices[k] = right_copy[j] 136 | j += 1 137 | k += 1 138 | 139 | 140 | def get_likelihood_coefficient(N, p): 141 | x = [0, (N - 1) / 2, N - 1] 142 | y = [1, 1, 1] 143 | 144 | coefficients = np.polyfit(x, y, 2) # Fit a 3rd-degree polynomial curve 145 | func = np.poly1d(coefficients) 146 | return func(p) 147 | 148 | 149 | def merge_with_confidence_beam( 150 | indices, left, mid, right, preference_matrix, params, permutate=False 151 | ): 152 | def get_probA(i, j): 153 | if prob_A_matrix[i, j] == 0: 154 | compare_result = is_better_than_prob( 155 | left_copy[i], 156 | right_copy[j], 157 | preference_matrix=preference_matrix, 158 | permutate=permutate, 159 | ) 160 | prob_A_matrix[i, j] = compare_result # ['prob_A'] 161 | params["api_call"] += 1 162 | if "progress_bar" in params: 163 | params["progress_bar"].update(1) 164 | return prob_A_matrix[i, j] # , uncertainty_matrix[i,j] 165 | 166 | left_copy = indices[left:mid] 167 | right_copy = indices[mid:right] 168 | 169 | beam_size = params["beam_size"] 170 | prob_A_matrix = np.zeros( 171 | (len(left_copy), len(right_copy)) 172 | ) # prob_A_matrix[i, j] is the probability of A better than B 173 | # uncertainty_matrix = np.ones_like(prob_A_matrix) # uncertainty_matrix[i, j] is the uncertainty of A is better 174 | prob_A_matrix[0, 0] = get_probA(0, 0) 175 | prob_gap = params["prob_gap"] 176 | 177 | coef = get_likelihood_coefficient(right - left, 0) 178 | # coef=1 179 | if prob_A_matrix[0, 0] > 0.5 + prob_gap: 180 | beam = [ 181 | BeamItem( 182 | index_pathway=[("B", 0)], 183 | cum_prob=np.log(prob_A_matrix[0, 0] + 1e-9) * coef, 184 | pointer_B=0, 185 | ), 186 | ] 187 | elif prob_A_matrix[0, 0] < 0.5 - prob_gap: 188 | beam = [ 189 | BeamItem( 190 | index_pathway=[("A", 0)], 191 | cum_prob=np.log(1 - prob_A_matrix[0, 0] + 1e-9) * coef, 192 | pointer_A=0, 193 | ), 194 | ] 195 | else: 196 | beam = [ 197 | BeamItem( 198 | index_pathway=[("B", 0)], 199 | cum_prob=np.log(prob_A_matrix[0, 0] + 1e-9) * coef, 200 | pointer_B=0, 201 | ), 202 | BeamItem( 203 | index_pathway=[("A", 0)], 204 | cum_prob=np.log(1 - prob_A_matrix[0, 0] + 1e-9) * coef, 205 | pointer_A=0, 206 | ), 207 | ] 208 | # sort beam according to cum_prob 209 | beam.sort(key=lambda x: x.cum_prob, reverse=True) 210 | beam = beam[:beam_size] 211 | 212 | for i in range(len(left_copy) + len(right_copy) - 1): 213 | coef = np.round(get_likelihood_coefficient(right - left, i + 1), 5) 214 | # print(coef) 215 | new_beam = [] 216 | for beam_item in beam: 217 | for choice in ["A", "B"]: 218 | beam_item_copy = copy.deepcopy(beam_item) 219 | if ( 220 | 
beam_item_copy.pointer_A < len(left_copy) - 1 221 | and beam_item_copy.pointer_B < len(right_copy) - 1 222 | ) and not (i == len(left_copy) + len(right_copy) - 2): 223 | prob_A = get_probA( 224 | min(beam_item_copy.pointer_A + 1, len(left_copy) - 1), 225 | min(beam_item_copy.pointer_B + 1, len(right_copy) - 1), 226 | ) 227 | if (choice == "A" and prob_A > 0.5 + prob_gap) or ( 228 | choice == "B" and 1 - prob_A > 0.5 + prob_gap 229 | ): 230 | continue 231 | # beam_item_copy.cum_prob *= 1-prob_A if choice == 'A' else prob_A 232 | logprob = ( 233 | np.log(1 - prob_A + 1e-9) 234 | if choice == "A" 235 | else np.log(prob_A + 1e-9) 236 | ) 237 | beam_item_copy.cum_prob = moving_average( 238 | beam_item_copy.cum_prob, logprob * coef, i + 1 239 | ) 240 | 241 | beam_item_copy.pointer_A += 1 if choice == "A" else 0 242 | beam_item_copy.pointer_B += 1 if choice == "B" else 0 243 | 244 | if (beam_item_copy.pointer_A >= len(left_copy)) or ( 245 | beam_item_copy.pointer_B >= len(right_copy) 246 | ): 247 | continue 248 | 249 | current_step = ( 250 | choice, 251 | ( 252 | beam_item_copy.pointer_A 253 | if choice == "A" 254 | else beam_item_copy.pointer_B 255 | ), 256 | ) 257 | beam_item_copy.index_pathway.append(current_step) 258 | new_beam.append(beam_item_copy) 259 | 260 | # reduce beam 261 | new_beam.sort(key=lambda x: x.cum_prob, reverse=True) 262 | beam = new_beam[:beam_size] 263 | 264 | best_candidate = beam[0] 265 | sorted_index = [] 266 | for item in best_candidate.index_pathway: 267 | if item[0] == "A": 268 | sorted_index.append(left_copy[item[1]]) 269 | else: 270 | sorted_index.append(right_copy[item[1]]) 271 | indices[left:right] = sorted_index 272 | 273 | 274 | def binary_search_insert_index(input, output, params, anchors_idx, target_idx): 275 | left = 0 276 | right = len(anchors_idx) - 1 277 | 278 | while left <= right: 279 | mid = (left + right) // 2 280 | if "progress_bar" in params: 281 | params["progress_bar"].update(1) 282 | compare_result = is_better_than_prob( 283 | anchors_idx[mid], target_idx, input, output, params=params 284 | ) 285 | params["compare_log"][(anchors_idx[mid], target_idx)] = compare_result.to_json() 286 | params["api_call"] += 1 287 | 288 | if compare_result["prob_A"] > compare_result["prob_B"]: 289 | right = mid - 1 290 | else: 291 | left = mid + 1 292 | return left 293 | 294 | 295 | def merge_sort_with_scale(input, output, scores, params, sort_size=100): 296 | # step 1 sort initial subset 297 | sorted_anchor_indices = merge_sort_indices( 298 | input[:sort_size], output[:sort_size], params 299 | ) 300 | initial_scores = np.array(scores[:sort_size]) 301 | calculate_correlation(initial_scores[sorted_anchor_indices], list(range(sort_size))) 302 | params["progress_bar"].close() 303 | 304 | # step 2: Get anchor examples index 305 | # attemp 1, Use all initial indices as anchor 306 | 307 | # step 3: determine the rest of the data, binary search 308 | progress_bar = tqdm(total=len(input) - sort_size, desc="Processing") 309 | searech_result = [] 310 | for idx in range(sort_size, len(input)): 311 | insert_index = binary_search_insert_index( 312 | input, output, params, sorted_anchor_indices, idx 313 | ) 314 | searech_result.append(insert_index) 315 | progress_bar.update(1) 316 | progress_bar.close() 317 | 318 | # step 4: insert the rest of the data to the sorted indices 319 | sorted_full_indices = insert_index_to_anchors( 320 | sorted_anchor_indices, searech_result, sort_size 321 | ) 322 | 323 | return sorted_full_indices 324 | 325 | 326 | if __name__ == "__main__": 327 | 
import argparse 328 | 329 | parser = argparse.ArgumentParser() 330 | parser.add_argument("--dataset", type=str, default="SumEval") 331 | parser.add_argument("--save_path", type=str, default=None) 332 | parser.add_argument("--aspect", type=str, default="coherence") 333 | parser.add_argument("--eval_method", type=str, default="pairwise comparison") 334 | parser.add_argument("--scaling_anchor_size", type=int, default=0) 335 | parser.add_argument("--eval_size", type=int, default=300) 336 | parser.add_argument("--engine", type=str, default="gpt-3.5-turbo") 337 | parser.add_argument("--confidence_beam", action="store_true") 338 | parser.add_argument("--prob_gap", type=float, default=0.15) 339 | parser.add_argument("--beam_size", type=int, default=100) 340 | parser.add_argument("--with_input", action="store_true") 341 | parser.add_argument("--calibration", action="store_true") 342 | args = parser.parse_args() 343 | 344 | print("aspect:", args.aspect) 345 | print("engine:", args.engine) 346 | print("dataset:", args.dataset) 347 | print("confidence_beam:", args.confidence_beam) 348 | print("beam_size:", args.beam_size) 349 | print("calibration:", args.calibration) 350 | 351 | params = { 352 | "dataset": args.dataset, 353 | "engine": args.engine, 354 | "aspect": args.aspect, 355 | "eval_method": args.eval_method, 356 | "confidence_beam": args.confidence_beam, 357 | "beam_size": args.beam_size, 358 | "api_call": 0, 359 | "prob_gap": args.prob_gap, 360 | "with_input": args.with_input, 361 | "calibration": args.calibration, 362 | "compare_log": {}, 363 | } 364 | # Load the dataset 365 | if args.dataset == "SumEval": 366 | summ_eval_path = "data/SummEval/model_annotations.aligned.paired.jsonl" 367 | input, output, scores = load_summEval(summ_eval_path) 368 | elif args.dataset == "newsroom": 369 | newsroom_path = "data/newsroom/newsroom.json" 370 | input, output, scores = load_newsroom(newsroom_path) 371 | elif args.dataset == "hanna": 372 | hanna_path = "data/hanna/hanna_stories_annotations.csv" 373 | input, output, scores = load_hanna(hanna_path) 374 | else: 375 | print("Dataset not supported.") 376 | assert False 377 | 378 | scores = scores[args.aspect] 379 | intput, output, scores = shuffle_lists(input, output, scores) 380 | 381 | input, output, scores = ( 382 | input[: args.eval_size], 383 | output[: args.eval_size], 384 | scores[: args.eval_size], 385 | ) 386 | 387 | # Initialize the progress bar 388 | if params["confidence_beam"]: 389 | params["progress_bar"] = tqdm(total=int(len(input) ** 2), desc="Processing") 390 | else: 391 | params["progress_bar"] = tqdm( 392 | total=int(len(input) * np.log2(len(input))), desc="Processing" 393 | ) 394 | 395 | # Run the sorting algorithm 396 | if args.scaling_anchor_size == 0: 397 | ranking_indices = merge_sort_indices(input, output, params) 398 | else: 399 | ranking_indices = merge_sort_with_scale( 400 | input, output, scores, params, sort_size=args.scaling_anchor_size 401 | ) 402 | 403 | human_scores = np.array(scores) 404 | 405 | print( 406 | "dataset:", 407 | args.dataset, 408 | "aspect:", 409 | args.aspect, 410 | "eval_method:", 411 | args.eval_method, 412 | "beam_search", 413 | args.confidence_beam, 414 | "engine:", 415 | args.engine, 416 | "api_call:", 417 | params["api_call"], 418 | "beam_size:", 419 | params["beam_size"], 420 | "prob_gap:", 421 | params["prob_gap"], 422 | "scaling_anchor_size:", 423 | args.scaling_anchor_size, 424 | "eval_size:", 425 | len(scores), 426 | "with_input", 427 | args.with_input, 428 | ) 429 | 
calculate_correlation(human_scores[ranking_indices], list(range(len(human_scores)))) 430 | 431 | score_cnter = Counter(human_scores) 432 | modified_scores = [] 433 | for s in sorted(score_cnter.keys()): 434 | modified_scores += [s] * Counter(human_scores)[s] 435 | 436 | calculate_correlation( 437 | predicted_score=human_scores[ranking_indices], reference_score=modified_scores 438 | ) 439 | 440 | params["progress_bar"].close() 441 | 442 | # Save the result 443 | if args.save_path is not None: 444 | results = { 445 | "aspect": args.aspect, 446 | "confidence_beam": args.confidence_beam, 447 | "beam_size": params["beam_size"], 448 | "engine": args.engine, 449 | "dataset": args.dataset, 450 | "human_scores": scores, 451 | "gpt_ranking": ranking_indices, 452 | "compare_log": { 453 | str(key): val for key, val in params["compare_log"].items() 454 | }, 455 | } 456 | 457 | with open(args.save_path, "a") as f: 458 | json.dump(results, f) 459 | f.write("\n") 460 | -------------------------------------------------------------------------------- /pairs/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import Counter 3 | from scipy.stats import norm, rankdata 4 | import random 5 | import scipy 6 | from sklearn.metrics import mean_absolute_error 7 | 8 | # from cdf_transform import transform_cdf_matching 9 | import json 10 | import pandas as pd 11 | import math 12 | 13 | 14 | class CompareResultObject: 15 | def __init__( 16 | self, 17 | raw_prob_A=0, 18 | raw_prob_B=0, 19 | raw_prob_C=0, 20 | uncertainty=1, 21 | logit_A=0, 22 | logit_B=0, 23 | logit_C=0, 24 | ): 25 | self.raw_prob_A = raw_prob_A 26 | self.raw_prob_B = raw_prob_B 27 | self.raw_prob_C = raw_prob_C 28 | prob_sum = raw_prob_A + raw_prob_B + raw_prob_C 29 | self.prob_A = raw_prob_A / prob_sum 30 | self.prob_B = raw_prob_B / prob_sum 31 | self.prob_C = raw_prob_C / prob_sum 32 | self.uncertainty = uncertainty 33 | self.logit_A = logit_A 34 | self.logit_B = logit_B 35 | self.logit_C = logit_C 36 | 37 | def calibraet_shift(self, shifts): 38 | shifted_prob_A = self.raw_prob_A / np.exp(shifts["A"]) 39 | shifted_prob_B = self.raw_prob_B / np.exp(shifts["B"]) 40 | shifted_prob_C = self.raw_prob_C / np.exp(shifts["C"]) 41 | prob_sum = shifted_prob_A + shifted_prob_B + shifted_prob_C 42 | self.prob_A = shifted_prob_A / prob_sum 43 | self.prob_B = shifted_prob_B / prob_sum 44 | self.prob_C = shifted_prob_C / prob_sum 45 | 46 | def __str__(self) -> str: 47 | string = f"prob_A: {round(self.prob_A,2)}, prob_B: {round(self.prob_B,2)}, prob_C: {round(self.prob_C,2)}, uncertainty: {round(self.uncertainty,3)} \n" 48 | string += f"raw_prob_A: {round(self.raw_prob_A,2)}, raw_prob_B: {round(self.raw_prob_B,2)}, raw_prob_C: {round(self.raw_prob_C,2)}" 49 | return string 50 | 51 | def __getitem__(self, key): 52 | return getattr(self, key, None) 53 | 54 | def __setitem__(self, key, value): 55 | setattr(self, key, value) 56 | 57 | def to_json(self): 58 | return { 59 | "prob_A": float(self.prob_A), 60 | "prob_B": float(self.prob_B), 61 | "prob_C": float(self.prob_C), 62 | "uncertainty": float(self.uncertainty), 63 | "raw_prob_A": float(self.raw_prob_A), 64 | "raw_prob_B": float(self.raw_prob_B), 65 | "raw_prob_C": float(self.raw_prob_C), 66 | } 67 | 68 | 69 | def get_calibration_shift(model_name, dataset, aspect): 70 | calibration_shift_file = f"./calibration_shift.json" 71 | with open(calibration_shift_file, "r") as file: 72 | calibration_shift = json.load(file) 73 | file.close() 74 
| shifts = calibration_shift[model_name][dataset][aspect] 75 | return { 76 | "A": shifts["logprobA"] if "logprobA" in shifts else 0, 77 | "B": shifts["logprobB"] if "logprobB" in shifts else 0, 78 | "C": shifts["logprobC"] if "logprobC" in shifts else 0, 79 | } 80 | 81 | 82 | def shuffle_lists(*args): 83 | """Shuffle multiple lists together and return the shuffled lists.""" 84 | # Check if all lists are of the same length 85 | if len(set(map(len, args))) != 1: 86 | raise ValueError("All lists must be of the same length") 87 | 88 | # Combine the lists element-wise 89 | combined_lists = list(zip(*args)) 90 | random.shuffle(combined_lists) 91 | 92 | # Unzip the combined list into separate lists 93 | shuffled_lists = zip(*combined_lists) 94 | return [list(lst) for lst in shuffled_lists] 95 | 96 | 97 | def get_score_dist(scores): 98 | cnter = Counter(scores) 99 | return [cnter[i] / cnter.total() for i in range(1, 6)] 100 | 101 | 102 | def float_to_int(num): 103 | return int(round(float(num), 0)) 104 | 105 | 106 | def calculate_uncertainty(probablities): 107 | probablities = np.array(probablities) 108 | entropy = -np.sum(probablities * np.log(probablities)) 109 | return entropy 110 | 111 | 112 | def calculate_entropy(logprobs): 113 | """ 114 | logprobs: a list of logprobs 115 | """ 116 | return -np.sum(np.exp(logprobs) * logprobs) 117 | 118 | 119 | def insert_index_to_anchors(original_list, insert_elements, index_offset=0): 120 | """ 121 | original_list: Anchor list 122 | insert_elements: List of elements to be inserted 123 | The goal is to insert the index of the insert_elements at the position of the value of the insert_elements 124 | to the original_list. 125 | For example: 126 | original_list = ['a', 'b', 'c','d','e','f'] 127 | insert_elements = [4,2,1,3,5,3,1] 128 | index_offset = 1 129 | Result List: ['a', 3, 7, 'b', 2, 'c', 4, 6, 'd', 1, 'e', 5, 'f'] 130 | """ 131 | original_list = original_list.copy() 132 | insert_positions = np.sort(insert_elements)[::-1] 133 | insert_val = np.argsort(insert_elements)[::-1] 134 | 135 | for index, val in zip(insert_positions, insert_val): 136 | original_list.insert(index, val + index_offset) 137 | return [int(num) for num in original_list] 138 | 139 | 140 | ############################################################################ 141 | ###### Load datasets 142 | ############################################################################ 143 | def load_jsonl(file_path): 144 | with open(file_path) as f: 145 | lines = f.readlines() 146 | return [json.loads(line) for line in lines] 147 | 148 | 149 | def load_json(file_path): 150 | with open(file_path) as f: 151 | return json.load(f) 152 | 153 | 154 | def load_summEval(path, flat_output=True, truncate_num_for_eval=None): 155 | data_summ_eval = load_jsonl(path) 156 | if truncate_num_for_eval: 157 | data_summ_eval = data_summ_eval[: 16 * truncate_num_for_eval] 158 | 159 | input = [] 160 | for i in range(len(data_summ_eval)): 161 | input.append(data_summ_eval[i]["text"]) 162 | 163 | output = [] 164 | for i in range(len(data_summ_eval)): 165 | output.append(data_summ_eval[i]["decoded"]) 166 | 167 | # coherence 168 | coherence_scores = [] 169 | for i in range(len(data_summ_eval)): 170 | coherence = [ 171 | anootation["coherence"] 172 | for anootation in data_summ_eval[i]["expert_annotations"] 173 | ] 174 | coherence_scores.append(round(sum(coherence) / len(coherence), 1)) 175 | # turker_annotations 176 | # fluency 177 | fluency_scores = [] 178 | for i in range(len(data_summ_eval)): 179 | fluency = [ 180 | 
anootation["fluency"] 181 | for anootation in data_summ_eval[i]["expert_annotations"] 182 | ] 183 | fluency_scores.append(round(sum(fluency) / len(fluency), 1)) 184 | 185 | # relevance 186 | relevance_scores = [] 187 | for i in range(len(data_summ_eval)): 188 | relevance = [ 189 | anootation["relevance"] 190 | for anootation in data_summ_eval[i]["expert_annotations"] 191 | ] 192 | relevance_scores.append(round(sum(relevance) / len(relevance), 1)) 193 | 194 | # consistency 195 | consistency_scores = [] 196 | for i in range(len(data_summ_eval)): 197 | consistency = [ 198 | anootation["consistency"] 199 | for anootation in data_summ_eval[i]["expert_annotations"] 200 | ] 201 | consistency_scores.append(round(sum(consistency) / len(consistency), 1)) 202 | 203 | if flat_output: 204 | return ( 205 | input, 206 | output, 207 | { 208 | "coherence": coherence_scores, 209 | "fluency": fluency_scores, 210 | "relevance": relevance_scores, 211 | "consistency": consistency_scores, 212 | }, 213 | ) 214 | else: 215 | candidate_num = 16 216 | ( 217 | input_doc, 218 | output_doc, 219 | coherence_doc, 220 | fluency_doc, 221 | relevance_doc, 222 | consistency_doc, 223 | ) = ([], [], [], [], [], []) 224 | for i in range(0, len(input), candidate_num): 225 | input_doc.append(input[i : i + candidate_num]) 226 | output_doc.append(output[i : i + candidate_num]) 227 | coherence_doc.append(coherence_scores[i : i + candidate_num]) 228 | fluency_doc.append(fluency_scores[i : i + candidate_num]) 229 | relevance_doc.append(relevance_scores[i : i + candidate_num]) 230 | consistency_doc.append(consistency_scores[i : i + candidate_num]) 231 | 232 | return ( 233 | input_doc, 234 | output_doc, 235 | { 236 | "coherence": coherence_doc, 237 | "fluency": fluency_doc, 238 | "relevance": relevance_doc, 239 | "consistency": consistency_doc, 240 | }, 241 | ) 242 | 243 | 244 | def load_newsroom(path, flat_output=True, truncate_num_for_eval=None): 245 | with open(path, "r") as file: 246 | newsroom = json.load(file) 247 | file.close() 248 | 249 | data = newsroom 250 | if truncate_num_for_eval: 251 | data = data[: 7 * truncate_num_for_eval] 252 | input = [dp["source"].replace("

", " ") for dp in data] 253 | output = [dp["system_output"] for dp in data] 254 | coherence = [round(dp["scores"]["coherence"], 1) for dp in data] 255 | fluency = [round(dp["scores"]["fluency"], 1) for dp in data] 256 | informativeness = [round(dp["scores"]["informativeness"], 1) for dp in data] 257 | relevance = [round(dp["scores"]["relevance"], 1) for dp in data] 258 | if flat_output: 259 | return ( 260 | input, 261 | output, 262 | { 263 | "coherence": coherence, 264 | "fluency": fluency, 265 | "informativeness": informativeness, 266 | "relevance": relevance, 267 | }, 268 | ) 269 | else: 270 | candidate_num = 7 271 | ( 272 | input_doc, 273 | output_doc, 274 | coherence_doc, 275 | fluency_doc, 276 | informativeness_doc, 277 | relevance_doc, 278 | ) = ([], [], [], [], [], []) 279 | for i in range(0, len(input), candidate_num): 280 | input_doc.append(input[i : i + candidate_num]) 281 | output_doc.append(output[i : i + candidate_num]) 282 | coherence_doc.append(coherence[i : i + candidate_num]) 283 | fluency_doc.append(fluency[i : i + candidate_num]) 284 | informativeness_doc.append(informativeness[i : i + candidate_num]) 285 | relevance_doc.append(relevance[i : i + candidate_num]) 286 | 287 | return ( 288 | input_doc, 289 | output_doc, 290 | { 291 | "coherence": coherence_doc, 292 | "fluency": fluency_doc, 293 | "informativeness": informativeness_doc, 294 | "relevance": relevance_doc, 295 | }, 296 | ) 297 | 298 | 299 | def load_TopicalChat(path, truncate_num_for_eval=None): 300 | data = load_json(path) 301 | if truncate_num_for_eval: 302 | data = data[: 5 * truncate_num_for_eval] 303 | input_doc = [] 304 | output_doc = [] 305 | overall_doc = [] 306 | natural_doc = [] 307 | engaging_doc = [] 308 | for i in range(len(data)): 309 | input = [] 310 | facts = [] 311 | output = [] 312 | natural = [] 313 | overall = [] 314 | engaging = [] 315 | for r in data[i]["responses"]: 316 | # Process input string to conversational format 317 | input_string = data[i]["context"] 318 | input_list = input_string.split("\n") 319 | if input_list[-1] == "": 320 | input_list = input_list[:-1] 321 | 322 | input_list_with_user = [] 323 | for idx, line in enumerate(input_list): 324 | if idx % 2 == 0: 325 | input_list_with_user.append("Person 1: " + line) 326 | next_round_person = "Person 2: " 327 | if idx % 2 == 1: 328 | input_list_with_user.append("Person 2: " + line) 329 | next_round_person = "Person 1: " 330 | 331 | input.append("\n".join(input_list_with_user)) 332 | facts.append(data[i]["fact"]) 333 | output.append(next_round_person + r["response"].strip()) 334 | natural.append(round(sum(r["Natural"]) / len(r["Natural"]), 0)) 335 | overall.append(round(sum(r["Overall"]) / len(r["Overall"]), 0)) 336 | engaging.append(round(sum(r["Engaging"]) / len(r["Engaging"]), 0)) 337 | input_doc.append(input) 338 | output_doc.append(output) 339 | overall_doc.append(overall) 340 | natural_doc.append(natural) 341 | engaging_doc.append(engaging) 342 | 343 | return ( 344 | input_doc, 345 | output_doc, 346 | {"overall": overall_doc, "natural": natural_doc, "engaging": engaging_doc}, 347 | ) 348 | 349 | 350 | def load_sf_data(file_path): 351 | """ 352 | Load SFHOT/SFRES data 353 | """ 354 | data = load_json(file_path) 355 | input = [dp["source"] for dp in data] 356 | output = [dp["system_output"] for dp in data] 357 | naturalness = [dp["scores"]["naturalness"] for dp in data] 358 | informativeness = [dp["scores"]["informativeness"] for dp in data] 359 | overall = [dp["scores"]["overall"] for dp in data] 360 | return ( 361 | input, 362 
| output, 363 | { 364 | "naturalness": naturalness, 365 | "informativeness": informativeness, 366 | "overall": overall, 367 | }, 368 | ) 369 | 370 | 371 | def load_hanna(file_path): 372 | """ 373 | Load Hanna data 374 | """ 375 | try: 376 | dataset = pd.read_csv(file_path) 377 | except: 378 | file_path = "data/hanna_stories_annotations.csv" 379 | dataset = pd.read_csv(file_path) 380 | 381 | processed_df = {} 382 | for i in range(dataset.shape[0]): 383 | idx = dataset["Story ID"][i] 384 | if str(idx) not in processed_df: 385 | processed_df[str(idx)] = { 386 | "input": dataset["Prompt"][i], 387 | "output": dataset["Prompt"][i] + " " + dataset["Story"][i], 388 | "relevance": [dataset["Relevance"][i]], 389 | "coherence": [dataset["Coherence"][i]], 390 | "empathy": [dataset["Empathy"][i]], 391 | "surprise": [dataset["Surprise"][i]], 392 | "engagement": [dataset["Engagement"][i]], 393 | "complexity": [dataset["Complexity"][i]], 394 | } 395 | else: 396 | processed_df[str(idx)]["relevance"].append(dataset["Relevance"][i]) 397 | processed_df[str(idx)]["coherence"].append(dataset["Coherence"][i]) 398 | processed_df[str(idx)]["empathy"].append(dataset["Empathy"][i]) 399 | processed_df[str(idx)]["surprise"].append(dataset["Surprise"][i]) 400 | processed_df[str(idx)]["engagement"].append(dataset["Engagement"][i]) 401 | processed_df[str(idx)]["complexity"].append(dataset["Complexity"][i]) 402 | 403 | input = [dp["input"] for dp in list(processed_df.values())] 404 | output = [dp["output"] for dp in list(processed_df.values())] 405 | relevance = [ 406 | round(np.mean(dp["relevance"]), 1) for dp in list(processed_df.values()) 407 | ] 408 | coherence = [ 409 | round(np.mean(dp["coherence"]), 1) for dp in list(processed_df.values()) 410 | ] 411 | empathy = [round(np.mean(dp["empathy"]), 1) for dp in list(processed_df.values())] 412 | surprise = [round(np.mean(dp["surprise"]), 1) for dp in list(processed_df.values())] 413 | engagement = [ 414 | round(np.mean(dp["engagement"]), 1) for dp in list(processed_df.values()) 415 | ] 416 | complexity = [ 417 | round(np.mean(dp["complexity"]), 1) for dp in list(processed_df.values()) 418 | ] 419 | scores = { 420 | "relevance": relevance, # how well the story matches its prompt 421 | "sensible": coherence, # how much the story makes sense 422 | "empathy": empathy, # how well the reader understood the character’s emotions, derived from the importance of emotional commentary 423 | "surprise": surprise, # how surprising the end of the story was, derived from the importance of schema violation, or unexpectedness 424 | "engagement": engagement, # how much the reader engaged with the story; 425 | "complexity": complexity, # how elaborate the story is; derived from the importance of detailed descriptions and sophisticated problem-solving 426 | } 427 | return input, output, scores 428 | 429 | 430 | ############################################################################ 431 | ###### Correlation Analysis 432 | ############################################################################ 433 | 434 | 435 | def calculate_correlation(reference_score, predicted_score): 436 | spearman_corr, _ = scipy.stats.spearmanr(reference_score, predicted_score) 437 | 438 | if math.isnan(spearman_corr): 439 | # print(reference_score, predicted_score) 440 | # print(sum(reference_score), sum(predicted_score)) 441 | spearman_corr = ( 442 | 1 443 | if all(element == reference_score[0] for element in reference_score) 444 | else 0 445 | ) 446 | # print('Spearmans correlation: %.3f' % spearman_corr) 
447 |     kendall_tau, _ = scipy.stats.kendalltau(reference_score, predicted_score)
448 |     # print('Kendall tau: %.3f' % kendall_tau)
449 |     # mae = mean_absolute_error(reference_score, predicted_score)
450 |     # print('MAE: %.3f' % mae)
451 |     return spearman_corr, kendall_tau  # , mae
452 | 
453 | 
454 | def correlation_analysis(results):
455 |     print("Uncalibrated scores:")
456 |     spearman_original, _ = calculate_correlation(
457 |         results["human_scores"], results["pred_scores"]
458 |     )
459 |     print("------------------")
460 | 
461 |     print("Uncalibrated G-Eval:")
462 |     weights = np.array([1, 2, 3, 4, 5])
463 |     weighted_gpt_scores = np.exp(results["pred_logprob"]).T @ weights
464 |     weighted_gpt_scores = np.round(weighted_gpt_scores, decimals=1)
465 |     spearman_geval, _ = calculate_correlation(
466 |         results["human_scores"], weighted_gpt_scores
467 |     )
468 |     print("------------------")
469 | 
470 |     return spearman_original, spearman_geval
471 | 
--------------------------------------------------------------------------------
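
For reference, here is a minimal sketch (not part of the repository) of how the pieces above fit together: it assumes the dependencies from `install.sh` are installed, the repo root is on `PYTHONPATH`, the SummEval annotations sit at `data/SummEval/model_annotations.aligned.paired.jsonl` (the path used in `pairs/sorting.py`), and `OPENAI_API_KEY` is exported. The template from `get_prompt_template` is rendered with Jinja2 and scored with `OpenAIChatModel.compare`, which returns one `CompareResultObject` per prompt; the choice of the first two candidate summaries is arbitrary and only illustrative.

```
# Minimal usage sketch (assumptions: repo root on PYTHONPATH, OPENAI_API_KEY set,
# SummEval annotations available at the path below).
from jinja2 import Environment

from pairs.openai_api import OpenAIChatModel
from pairs.prompts import get_aspect_instruction, get_prompt_template
from pairs.utils import load_summEval

# Load source texts and candidate summaries as flat lists (16 candidates per document).
source, summaries, scores = load_summEval(
    "data/SummEval/model_annotations.aligned.paired.jsonl", flat_output=True
)

# Build a pairwise-comparison prompt for coherence on SummEval.
template = get_prompt_template("pairwise comparison", aspect="coherence", dataset="SummEval")
instruction = get_aspect_instruction("coherence", eval_method="pairwise comparison")
prompt = Environment().from_string(template).render(
    instruction=instruction, output_1=summaries[0], output_2=summaries[1]
)

# Compare the two candidates; compare() returns a list of CompareResultObject.
model = OpenAIChatModel(params={"engine": "gpt-3.5-turbo"})
result = model.compare([prompt])[0]
print(result.prob_A, result.prob_B, result.uncertainty)
```

The same pattern applies to the other datasets and aspects handled in `pairs/prompts.py`; the engine name and file path above are only the defaults taken from the scripts in this repository.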