├── figs └── long_context_figure.JPG ├── requirements.txt ├── api_config.py ├── run_write_score.py ├── summeval_prompts ├── rel_detailed.txt ├── faith_detailed.txt └── con_detailed.txt ├── run.sh ├── README.md ├── inference.py └── utils.py /figs/long_context_figure.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmis-lab/ETHIC/HEAD/figs/long_context_figure.JPG -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | vllm==0.6.1 2 | transformers==4.44.2 3 | datasets 4 | google-generativeai==0.8.3 5 | openai==1.40.2 -------------------------------------------------------------------------------- /api_config.py: -------------------------------------------------------------------------------- 1 | CONFIG = { 2 | "openai": [ 3 | "", 4 | ], 5 | "google": [ 6 | "" 7 | ], 8 | } -------------------------------------------------------------------------------- /run_write_score.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from utils import write_score_file 3 | 4 | def main(): 5 | if len(sys.argv) < 3: 6 | print("Usage: python run_write_score.py ") 7 | sys.exit(1) 8 | 9 | task = sys.argv[1] 10 | save_path = sys.argv[2] 11 | 12 | write_score_file(task, save_path) 13 | 14 | if __name__ == "__main__": 15 | main() 16 | -------------------------------------------------------------------------------- /summeval_prompts/rel_detailed.txt: -------------------------------------------------------------------------------- 1 | # Instruction: 2 | Below is an instruction for evaluating the relevance of the generated summary to the source document. Relevance measures 3 | whether a summary contains the main ideas of the source. The goal is to score relevance on a scale of 1-5, with 1 being 4 | not relevant at all, and 5 being highly relevant. 5 | 6 | # Evaluation Criteria: 7 | 1. Not Relevant: The summary doesn’t capture any of the main ideas of the source. 8 | 2. Barely Relevant: The summary captures very few of the main ideas of the source. 9 | 3. Somewhat Relevant: The summary captures some, but not all, of the main ideas of the source. 10 | 4. Mostly Relevant: The summary captures most of the main ideas of the source. 11 | 5. Highly Relevant: The summary captures all the main ideas of the source perfectly. 12 | 13 | # Evaluation Steps: 14 | 1. Thoroughly read the source document. 15 | 2. Carefully read the generated summary and compare it with the source document. 16 | 3. Compare the main ideas captured in the summary to the main ideas from the source document. 17 | 4. Rate the relevance of the summary based on how well it captures the main ideas from the source document using the 1-5 18 | scale mentioned in Evaluation Criteria. 
19 | 20 | # Source Document: 21 | 22 | {{Document}} 23 | 24 | # Generated Summary: 25 | 26 | {{Summary}} 27 | 28 | # Evaluation Form (scores ONLY): -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 2 | export VLLM_WORKER_MULTIPROC_METHOD=spawn 3 | 4 | task=Attributing 5 | model_name_or_path=meta-llama/Meta-Llama-3.1-70B-Instruct 6 | # model_name_or_path=meta-llama/Meta-Llama-3.1-8B-Instruct 7 | # model_name_or_path=gemini-1.5-pro 8 | # model_name_or_path=THUDM/glm-4-9b-chat 9 | # model_name_or_path=microsoft/Phi-3.5-mini-instruct 10 | # model_name_or_path=Qwen/Qwen2.5-7B-Instruct 11 | 12 | # domain=Medicine 13 | # use_yarn=True 14 | # under_32k_only=True 15 | # over_32k_only=True 16 | 17 | use_yarn=${use_yarn:-False} 18 | under_32k_only=${under_32k_only:-False} 19 | over_32k_only=${over_32k_only:-False} 20 | 21 | if [ "$under_32k_only" = "True" ] && [ "$over_32k_only" = "True" ]; then 22 | echo "Error: Both under_32k_only and over_32k_only cannot be True simultaneously." 23 | exit 1 24 | fi 25 | 26 | cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python inference.py \ 27 | --task $task \ 28 | --model_name_or_path $model_name_or_path" 29 | 30 | if [ "$use_yarn" = "True" ]; then 31 | cmd="$cmd --use_yarn" 32 | fi 33 | if [ "$under_32k_only" = "True" ]; then 34 | cmd="$cmd --under_32k_only" 35 | fi 36 | if [ "$over_32k_only" = "True" ]; then 37 | cmd="$cmd --over_32k_only" 38 | fi 39 | if [ -n "$domain" ]; then 40 | cmd="$cmd --domain $domain" 41 | fi 42 | if [ -n "$cache_dir" ]; then 43 | cmd="$cmd --cache_dir $cache_dir" 44 | fi 45 | 46 | cmd="$cmd --command \"$cmd\"" 47 | eval $cmd -------------------------------------------------------------------------------- /summeval_prompts/faith_detailed.txt: -------------------------------------------------------------------------------- 1 | # Instruction: 2 | Below is an instruction for evaluating the faithfulness of the generated summary to the source document. Faithfulness is 3 | the absence of factual errors in the summary, where a factual error is a statement that contradicts the source document or 4 | is not directly stated, heavily implied, or logically entailed by the source document. The goal is to score faithfulness 5 | on a scale of 1-5, with 1 being unfaithful (all information is wrong) and 5 being extremely faithful (no factual errors, 6 | directly correlate to the document). 7 | 8 | # Evaluation Criteria: 9 | 1. Unfaithful: The summary contains no factual information from the document. 10 | 2. Somewhat Unfaithful: The summary contains some factual information but several are wrong or misleading. 11 | 3. Neutral: The summary is half correct and half incorrect in terms of factual information. 12 | 4. Somewhat Faithful: The summary contains more factual information than errors but still has noticeable mistakes. 13 | 5. Extremely Faithful: The summary contains all factual information from the document with no errors. 14 | 15 | # Evaluation Steps: 16 | 1. Thoroughly read the source document. 17 | 2. Carefully read the generated summary and compare it with the source document. 18 | 3. Carefully read the summary and compare the facts presented with the facts in the source document. 19 | 4. Rate the faithfulness of the generated summary based on how faithfully the summary reflects the information in the 20 | source document using the 1-5 scale mentioned in Evaluation Criteria. 
21 | 22 | # Source Document: 23 | 24 | {{Document}} 25 | 26 | # Generated Summary: 27 | 28 | {{Summary}} 29 | 30 | # Evaluation Form (scores ONLY): -------------------------------------------------------------------------------- /summeval_prompts/con_detailed.txt: -------------------------------------------------------------------------------- 1 | # Instruction: 2 | Below is an instruction for evaluating the consistency of the generated summary to the source document. Consistency measures 3 | whether a candidate summary is factually consistent with the source. The goal is to score consistency on a scale of 1-5, 4 | with 1 being completely inconsistent and 5 being completely consistent. 5 | Please consider the following seven types of errors while performing the evaluation: i) predicate in summary inconsistent 6 | with source, ii) primary arguments or its attributes are wrong, iii) predicate’s circumstantial information is wrong, iv) 7 | co-reference error, v) multiple sentences linked incorrectly, vi) out of document error and vii) unreadable sentence(s) due 8 | to grammatical errors. 9 | 10 | # Evaluation Criteria: 11 | 1. Completely Inconsistent - The summary contains multiple factual errors or inaccuracies in relation to the source 12 | document. 13 | 2. Mostly Inconsistent - The summary contains several factual errors but retains some accurate information from the 14 | source. 15 | 3. Somewhat Consistent - The summary contains a mix of accurate and inaccurate information. Factual errors are present 16 | but not overwhelming. 17 | 4. Mostly Consistent - The summary is largely accurate, with few factual errors or inaccuracies. 18 | 5. Completely Consistent - The summary accurately represents all the information presented in the source document without 19 | any factual error. 20 | 21 | # Evaluation Steps: 22 | 1. Thoroughly read the source document. 23 | 2. Carefully read the generated summary and compare it with the source document. 24 | 3. Rate the consistency of the generated summary based on the provided types of errors using the 1-5 scale mentioned in 25 | Evaluation Criteria. 26 | 27 | # Source Document: 28 | 29 | {{Document}} 30 | 31 | # Generated Summary: 32 | 33 | {{Summary}} 34 | 35 | # Evaluation Form (scores ONLY): -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ETHIC: Evaluating Large Language Models on Long-Context Tasks with High Information Coverage 2 | 3 |
<div align="center">
4 | 📃 Paper | 🤗 Dataset 5 |
</div>
6 | 7 | ## 📋 Introduction 8 | **ETHIC** is a long-context benchmark designed to assess whether LLMs can fully utilize the provided information. ETHIC comprises tasks with high **Information Coverage (IC)** scores (~91%), i.e. the proportion of input context necessary for answering queries. 9 | 10 | ![](figs/long_context_figure.JPG) 11 | 12 | ## ⚒️ Setup 13 | We recommend using the following versions for compatibility. 14 | * PyTorch 2.4.0 15 | * Cuda 12.1 16 | ```shell 17 | # create a new environment 18 | conda create -n ethic python==3.9.19 19 | conda activate ethic 20 | 21 | # install required packages 22 | pip install -r requirements.txt 23 | ``` 24 | ## ⏩ Quickstart 25 | To use our dataset directly, simply download it using 🤗 Datasets: 26 | 27 | ```python 28 | from datasets import load_dataset 29 | 30 | task = "Recalling" # Choose from "Recalling", "Summarizing", "Organizing", "Attributing" 31 | dataset = load_dataset("dmis-lab/ETHIC", task)["test"] 32 | ``` 33 | 34 | For model inference and evaluation, prepare your OpenAI API key (or other keys for authorization) in _api_config.py_, as we utilize `gpt-4o` in the _Summarizing_ task. Also, Qwen2.5 recommends utilizing YaRN for inputs exceeding 32,768 tokens. Make sure to run inference twice: {_use_yarn_=True, _over_32k_only_=True} and {_use_yarn_=False, _under_32k_only_=True}. 35 | ```shell 36 | # run.sh 37 | 38 | CUDA_VISIBLE_DEVICES=1 39 | export VLLM_WORKER_MULTIPROC_METHOD=spawn 40 | 41 | task=Attributing # Recalling, Summarizing, Organizing, Attributing 42 | model_name_or_path=meta-llama/Meta-Llama-3.1-8B-Instruct 43 | 44 | # use_yarn=True 45 | # under_32k_only=True 46 | # over_32k_only=True 47 | # domain=Medicine 48 | 49 | cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python inference.py \ 50 | --task $task \ 51 | --model_name_or_path $model_name_or_path" 52 | 53 | if [ "$use_yarn" = "True" ]; then 54 | cmd="$cmd --use_yarn" 55 | fi 56 | if [ "$under_32k_only" = "True" ]; then 57 | cmd="$cmd --under_32k_only" 58 | fi 59 | if [ "$over_32k_only" = "True" ]; then 60 | cmd="$cmd --over_32k_only" 61 | fi 62 | if [ -n "$domain" ]; then 63 | cmd="$cmd --domain $domain" 64 | fi 65 | if [ -n "$cache_dir" ]; then 66 | cmd="$cmd --cache_dir $cache_dir" 67 | fi 68 | 69 | eval $cmd 70 | ``` 71 | 72 | 73 | ## Citation 74 | ``` 75 | @article{lee2024ethic, 76 | title={ETHIC: Evaluating Large Language Models on Long-Context Tasks with High Information Coverage}, 77 | author={Lee, Taewhoo and Yoon, Chanwoong and Jang, Kyochul and Lee, Donghyeon and Song, Minju and Kim, Hyunjae and Kang, Jaewoo}, 78 | journal={arXiv preprint arXiv:2410.16848}, 79 | year={2024} 80 | } 81 | ``` 82 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import re 4 | import sys 5 | from pathlib import Path 6 | from transformers import AutoTokenizer 7 | from datasets import load_dataset 8 | from tqdm import tqdm 9 | import argparse 10 | from vllm import LLM, SamplingParams 11 | import google.generativeai as genai 12 | from google.generativeai.types import HarmCategory, HarmBlockThreshold 13 | from openai import OpenAI 14 | import tiktoken 15 | 16 | from utils import get_logger, get_model_prompts, calculate_score, create_batch_for_summarizing, run_batch_for_summarizing, parse_score_for_summarizing, write_score_file, count_tokens_for_gpt 17 | from api_config import CONFIG 18 | 19 | def main(args): 20 | 21 | # set 
logger 22 | model_name = os.path.basename(args.model_name_or_path) 23 | 24 | path_to_logdir = os.path.join(args.log_path, model_name, args.task) 25 | os.makedirs(path_to_logdir, exist_ok=True) 26 | logger = get_logger(logger_name=__name__, path_to_logdir=path_to_logdir) 27 | 28 | logger.info(f"Running command: {args.command}") 29 | used_gpus = next((part.split("=")[1] for part in args.command.split() if part.startswith("CUDA_VISIBLE_DEVICES=")), "") 30 | gpu_count = len(used_gpus.split(",")) if used_gpus else 0 31 | 32 | dataset = load_dataset("dmis-lab/ETHIC", args.task, cache_dir=args.cache_dir)["test"] 33 | 34 | logger.info(f"Loaded dataset for task {args.task}") 35 | 36 | save_path = os.path.join(args.save_path, model_name, args.task) 37 | os.makedirs(save_path, exist_ok=True) 38 | 39 | if args.domain: 40 | os.makedirs(os.path.join(save_path, args.domain), exist_ok=True) 41 | else: 42 | os.makedirs(os.path.join(save_path, "Books"), exist_ok=True) 43 | os.makedirs(os.path.join(save_path, "Debates"), exist_ok=True) 44 | os.makedirs(os.path.join(save_path, "Medicine"), exist_ok=True) 45 | os.makedirs(os.path.join(save_path, "Law"), exist_ok=True) 46 | 47 | prompt = get_model_prompts(args.model_name_or_path) 48 | 49 | # load model 50 | if "gpt" in args.model_name_or_path: 51 | client = OpenAI(api_key=CONFIG["openai"][0]) 52 | elif "gemini" in args.model_name_or_path: 53 | genai.configure(api_key=CONFIG["google"][0]) 54 | model = genai.GenerativeModel(args.model_name_or_path) 55 | elif args.use_yarn: 56 | 57 | logger.info(f"Loading model with yarn.") 58 | 59 | model = LLM(model=args.model_name_or_path, download_dir=args.cache_dir, rope_scaling={"factor":4.0, "original_max_position_embeddings": 32768, "type": "yarn"}, trust_remote_code=True, tensor_parallel_size=gpu_count) 60 | sampling_params = SamplingParams(temperature=0, top_p=1.0, max_tokens=4096) 61 | else: # vllm 62 | model = LLM(model=args.model_name_or_path, download_dir=args.cache_dir, trust_remote_code=True, tensor_parallel_size=gpu_count) 63 | sampling_params = SamplingParams(temperature=0, top_p=1.0, max_tokens=4096) 64 | 65 | logger.info(f"Loaded model, saving model predictions to {save_path}") 66 | 67 | if args.under_32k_only: 68 | logger.info(f"Predictions for samples less than 32768 tokens only") 69 | if args.over_32k_only: 70 | logger.info(f"Predictions for samples more than 32768 tokens only") 71 | if args.domain: 72 | logger.info(f"Predictions for domain {args.domain} only") 73 | 74 | dataset_tqdm = tqdm(dataset, file=open(os.devnull, "w")) 75 | for sample in dataset_tqdm: 76 | 77 | id_ = sample["ID"] 78 | answer = sample["Answer"] 79 | system_msg = sample["System_msg"] 80 | user_msg = sample["User_msg"] 81 | domain = sample["Domain"] 82 | 83 | if args.domain and args.domain != domain: 84 | logger.info(f"skipping domain {domain}") 85 | continue 86 | 87 | logger.info(f"{str(dataset_tqdm)} Domain: {domain}, ID: {id_}") 88 | 89 | if "gemini" in args.model_name_or_path: 90 | full_prompt = prompt.format(system_msg=system_msg, user_msg=user_msg) 91 | full_prompt_length = model.count_tokens(full_prompt).total_tokens 92 | if args.under_32k_only and full_prompt_length > 32768: 93 | logger.info(f"skipping: {full_prompt_length} > 32768") 94 | continue 95 | if args.over_32k_only and full_prompt_length <= 32768: 96 | logger.info(f"skipping: {full_prompt_length} <= 32768") 97 | continue 98 | response = model.generate_content( 99 | full_prompt, 100 | generation_config=genai.types.GenerationConfig( 101 | candidate_count=1, 102 | 
max_output_tokens=4096, 103 | temperature=0.0 104 | ), 105 | safety_settings={ 106 | HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, 107 | HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, 108 | HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, 109 | HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, 110 | } 111 | ) 112 | try: 113 | prediction = response.text 114 | except ValueError: # gemini models occasionally refuse to answer 115 | logger.warning("Prediction FAILED") 116 | prediction = "FAILED" 117 | elif "gpt" in args.model_name_or_path: 118 | messages = [ 119 | {"role" : "system", "content": system_msg}, 120 | {"role":"user", "content":user_msg} 121 | ] 122 | full_prompt_length = count_tokens_for_gpt(messages, args.model_name_or_path) 123 | if args.under_32k_only and full_prompt_length > 32768: 124 | logger.info(f"skipping: {full_prompt_length} > 32768") 125 | continue 126 | if args.over_32k_only and full_prompt_length <= 32768: 127 | logger.info(f"skipping: {full_prompt_length} <= 32768") 128 | continue 129 | completion = client.chat.completions.create( 130 | model=args.model_name_or_path, 131 | messages=messages, 132 | temperature=0, 133 | top_p=1.0, 134 | max_tokens=4096, 135 | ) 136 | prediction = completion.choices[0].message 137 | elif "Qwen" in args.model_name_or_path or "glm" in args.model_name_or_path: 138 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) 139 | messages=[ 140 | {"role" : "system", "content": system_msg}, 141 | {"role":"user", "content":user_msg} 142 | ] 143 | full_prompt = tokenizer.apply_chat_template( 144 | messages, 145 | tokenize=False, 146 | add_generation_prompt=True 147 | ) 148 | 149 | full_prompt_length = len(tokenizer.encode(full_prompt)) 150 | if args.under_32k_only and full_prompt_length > 32768: 151 | logger.info(f"skipping: {full_prompt_length} > 32768") 152 | continue 153 | if args.over_32k_only and full_prompt_length <= 32768: 154 | logger.info(f"skipping: {full_prompt_length} <= 32768") 155 | continue 156 | outputs = model.generate([full_prompt], sampling_params) 157 | for output in outputs: 158 | prediction = output.outputs[0].text 159 | else: 160 | full_prompt = prompt.format(system_msg=system_msg, user_msg=user_msg) 161 | outputs = model.generate([full_prompt], sampling_params) 162 | for output in outputs: 163 | prediction = output.outputs[0].text 164 | 165 | result_dict, score = calculate_score(args.task, user_msg, prediction, answer) 166 | 167 | with open(os.path.join(save_path, domain, f"{id_}.json"), "w") as wf: 168 | json.dump(result_dict, wf) 169 | 170 | # for Summarizing task, score using batch inference 171 | if args.task == "Summarizing": 172 | 173 | logger.info("Preparing for summary scoring (batch inference)") 174 | 175 | path_list = [str(f) for f in Path(save_path).rglob("*.json") if f.parent.name in ["Books", "Debates", "Law", "Medicine"]] 176 | batch_for_summarizing = create_batch_for_summarizing(path_list) 177 | 178 | batch_input_path = os.path.join(os.path.dirname(save_path), "batch_inference", "summarizing_input.jsonl") 179 | if os.path.exists(batch_input_path): 180 | logger.error(f"Batch file for {model_name} already exists!") 181 | raise ValueError() 182 | 183 | os.makedirs(os.path.dirname(batch_input_path), exist_ok=True) 184 | 185 | with open(batch_input_path, "a") as wf: 186 | for line in batch_for_summarizing: 187 | wf.write(json.dumps(line) + "\n") 188 | 189 | logger.info("Running 
batch inference") 190 | 191 | try: 192 | batch_output_path = run_batch_for_summarizing(batch_input_path) 193 | except ValueError: 194 | logger.error("Batch inference FAILED") 195 | sys.exit(1) 196 | 197 | logger.info("Batch inference COMPLETE") 198 | 199 | score_dict = parse_score_for_summarizing(batch_output_path) 200 | for domain_filename in score_dict: 201 | domain = domain_filename[:domain_filename.find("_")] 202 | filename = domain_filename[domain_filename.find("_")+1:] 203 | 204 | filepath = os.path.join(save_path, domain, f"{filename}.json") 205 | with open(filepath) as rf: 206 | orig_dict = json.load(rf) 207 | 208 | prediction = orig_dict["prediction"] 209 | input_sections = orig_dict["input_sections"] 210 | score = score_dict[domain_filename]["weighted"] 211 | 212 | with open(filepath, "w") as wf: 213 | json.dump({ 214 | "prediction": prediction, 215 | "input_sections": input_sections, 216 | "score": score 217 | }, wf) 218 | 219 | # write score files ONLY WHEN the entire test set is completed 220 | if not args.under_32k_only and not args.over_32k_only and not args.domain: 221 | write_score_file(args.task, save_path) 222 | 223 | logger.info("All done!") 224 | 225 | if __name__ == "__main__": 226 | parser = argparse.ArgumentParser() 227 | parser.add_argument("--task", type=str, required=True, help="Choose from [\"Recalling\", \"Summarizing\", \"Organizing\", \"Attributing\"]") 228 | parser.add_argument("--model_name_or_path", type=str, required=True) 229 | parser.add_argument("--cache_dir", type=str, required=False) 230 | parser.add_argument("--domain", type=str, required=False) 231 | parser.add_argument("--use_yarn", action="store_true") 232 | parser.add_argument("--under_32k_only", action="store_true") 233 | parser.add_argument("--over_32k_only", action="store_true") 234 | parser.add_argument("--save_path", type=str, default=os.path.join(os.path.abspath(os.path.dirname(__file__)), "results")) 235 | parser.add_argument("--log_path", type=str, default=os.path.join(os.path.abspath(os.path.dirname(__file__)), "logs")) 236 | parser.add_argument("--command", type=str, help="The command that was run") 237 | 238 | args = parser.parse_args() 239 | main(args) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | import glob 5 | import time 6 | from openai import OpenAI 7 | import numpy as np 8 | import logging 9 | from datetime import datetime 10 | from api_config import CONFIG 11 | import tiktoken 12 | 13 | def get_logger(logger_name, path_to_logdir): 14 | 15 | logger = logging.getLogger(logger_name) 16 | logger.setLevel(logging.INFO) 17 | 18 | current_datetime = datetime.now() 19 | formatted_datetime = current_datetime.strftime("%y%m%d-%H%M") 20 | path_to_logfile = os.path.join(path_to_logdir, f"{formatted_datetime}.log") 21 | 22 | if not logger.hasHandlers(): 23 | file_handler = logging.FileHandler( 24 | path_to_logfile, 25 | mode="a", 26 | encoding="utf-8" 27 | ) 28 | formatter = logging.Formatter( 29 | "[%(asctime)s] %(levelname)s: %(message)s", 30 | datefmt="%Y-%m-%d %H:%M", 31 | ) 32 | 33 | file_handler.setFormatter(formatter) 34 | logger.addHandler(file_handler) 35 | 36 | return logger 37 | 38 | def count_tokens_for_gpt(messages, model): 39 | 40 | """Return the number of tokens used by a list of messages.""" 41 | try: 42 | encoding = tiktoken.encoding_for_model(model) 43 | except KeyError: 44 | print("Warning: model 
not found. Using o200k_base encoding.") 45 | encoding = tiktoken.get_encoding("o200k_base") 46 | if model in { 47 | "gpt-3.5-turbo-0125", 48 | "gpt-4-0314", 49 | "gpt-4-32k-0314", 50 | "gpt-4-0613", 51 | "gpt-4-32k-0613", 52 | "gpt-4o-mini-2024-07-18", 53 | "gpt-4o-2024-08-06" 54 | }: 55 | tokens_per_message = 3 56 | tokens_per_name = 1 57 | elif "gpt-3.5-turbo" in model: 58 | print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.") 59 | return count_tokens_for_gpt(messages, model="gpt-3.5-turbo-0125") 60 | elif "gpt-4o-mini" in model: 61 | print("Warning: gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-mini-2024-07-18.") 62 | return count_tokens_for_gpt(messages, model="gpt-4o-mini-2024-07-18") 63 | elif "gpt-4o" in model: 64 | print("Warning: gpt-4o and gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-2024-08-06.") 65 | return count_tokens_for_gpt(messages, model="gpt-4o-2024-08-06") 66 | elif "gpt-4" in model: 67 | print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") 68 | return count_tokens_for_gpt(messages, model="gpt-4-0613") 69 | else: 70 | raise NotImplementedError( 71 | f"""count_tokens_for_gpt() is not implemented for model {model}.""" 72 | ) 73 | num_tokens = 0 74 | for message in messages: 75 | num_tokens += tokens_per_message 76 | for key, value in message.items(): 77 | num_tokens += len(encoding.encode(value)) 78 | if key == "name": 79 | num_tokens += tokens_per_name 80 | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> 81 | return num_tokens 82 | 83 | def calculate_f1_score(model_answer, label_list): 84 | 85 | model_list = re.split(r"[;,]\s*", model_answer) 86 | 87 | model_list = sorted(set([pred.lower().strip(".").strip() for pred in model_list])) 88 | label_list = sorted([label.lower().strip(".").strip() for label in label_list]) 89 | 90 | num_labels = len(label_list) 91 | tp = 0 92 | for pred in model_list: 93 | for label in label_list[:]: 94 | if pred == label: 95 | tp += 1 96 | break 97 | 98 | fp = len(model_list) - tp 99 | fn = num_labels - tp 100 | 101 | precision = tp / (tp + fp) if (tp + fp) > 0 else 0 102 | recall = tp / (tp + fn) if (tp + fn) > 0 else 0 103 | f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 104 | 105 | # return f1_score 106 | return precision, recall, f1_score 107 | 108 | def calculate_lcs(prediction, answer): 109 | 110 | m = len(prediction) 111 | n = len(answer) 112 | 113 | L = [[0] * (n + 1) for _ in range(m + 1)] 114 | 115 | for i in range(m + 1): 116 | for j in range(n + 1): 117 | if i == 0 or j == 0: 118 | L[i][j] = 0 119 | elif prediction[i-1] == answer[j-1]: 120 | L[i][j] = L[i-1][j-1] + 1 121 | else: 122 | L[i][j] = max(L[i-1][j], L[i][j-1]) 123 | 124 | index = L[m][n] 125 | 126 | lcs_sequence = [""] * index 127 | 128 | i, j = m, n 129 | while i > 0 and j > 0: 130 | if prediction[i-1] == answer[j-1]: 131 | lcs_sequence[index-1] = prediction[i-1] 132 | i -= 1 133 | j -= 1 134 | index -= 1 135 | elif L[i-1][j] > L[i][j-1]: 136 | i -= 1 137 | else: 138 | j -= 1 139 | 140 | lcs_length = len(lcs_sequence) 141 | lcs_score = lcs_length / len(answer) 142 | 143 | return lcs_sequence, lcs_score 144 | 145 | def create_batch_for_summarizing(path_list): 146 | 147 | with open("./summeval_prompts/con_detailed.txt") as rf: 148 | prompt_con = rf.read() 149 | with open("./summeval_prompts/faith_detailed.txt") as rf: 150 | prompt_faith = rf.read() 151 | with 
open("./summeval_prompts/rel_detailed.txt") as rf: 152 | prompt_rel = rf.read() 153 | criteria_dict = {prompt_con: 'con', prompt_faith: 'faith', prompt_rel: 'rel'} 154 | 155 | batch_list = [] 156 | for path in path_list: 157 | with open(path) as rf: 158 | pred_dict = json.load(rf) 159 | 160 | # 1. prepare section-wise context / prediction 161 | domain = os.path.basename(os.path.dirname(path)) 162 | filename = os.path.basename(path).replace(".json", "") 163 | 164 | if domain == "Law": 165 | section_pattern = re.compile("()") 166 | else: 167 | section_pattern = re.compile("(
)") 168 | 169 | section_context_dict = dict() 170 | section_pred_dict = dict() 171 | orig_sections = section_pattern.split(pred_dict["input_sections"]) 172 | summ_sections = section_pattern.split(pred_dict["prediction"]) 173 | 174 | for i in range(1, len(orig_sections), 2): 175 | section_context_dict[orig_sections[i]] = orig_sections[i+1].strip() 176 | 177 | for i in range(1, len(summ_sections), 2): 178 | section_pred_dict[summ_sections[i]] = summ_sections[i+1].strip() 179 | 180 | # 2. create batch using 3 different criteria per section 181 | for prompt in [prompt_con, prompt_faith, prompt_rel]: 182 | for section in section_context_dict: 183 | if section not in section_pred_dict: # model did not create summary for the section 184 | continue 185 | prompt_with_content = prompt.replace('{{Document}}', section_context_dict[section]).replace('{{Summary}}', section_pred_dict[section]) 186 | batch = { 187 | 'custom_id': f"{domain}_{filename}_{section}_{criteria_dict[prompt]}", 188 | 'method': 'POST', 189 | 'url': "/v1/chat/completions", 190 | 'body': { 191 | 'model': 'gpt-4o-2024-08-06', 192 | 'messages': [{"role": "system", "content": prompt_with_content}], 193 | 'temperature': 0, 194 | 'max_tokens': 5, 195 | 'top_p': 1, 196 | 'frequency_penalty': 0, 197 | 'presence_penalty': 0, 198 | 'stop': None, 199 | 'logprobs': True, 200 | 'top_logprobs': 10, 201 | 'n': 1 202 | } 203 | } 204 | 205 | batch_list.append(batch) 206 | 207 | return batch_list 208 | 209 | def run_batch_for_summarizing(batch_input_path): 210 | 211 | batch_output_path = os.path.join(os.path.dirname(batch_input_path), "summarizing_output.jsonl") 212 | 213 | client = OpenAI(api_key=CONFIG["openai"][0]) 214 | batch_input_file = client.files.create( 215 | file=open(batch_input_path, "rb"), 216 | purpose="batch" 217 | ) 218 | 219 | batch_job = client.batches.create( 220 | input_file_id=batch_input_file.id, 221 | endpoint="/v1/chat/completions", 222 | completion_window="24h" 223 | ) 224 | time.sleep(10) 225 | 226 | # retrieve batch information 227 | retrieved_batch_job = client.batches.retrieve(batch_job.id) 228 | 229 | while True: 230 | time.sleep(30) # wait for 30 seconds for another status request 231 | retrieved_batch_job = client.batches.retrieve(batch_job.id) 232 | if retrieved_batch_job.status == 'completed' or retrieved_batch_job.status == 'failed': 233 | break 234 | 235 | if retrieved_batch_job.status == 'failed': 236 | raise ValueError() 237 | 238 | result_file_id = retrieved_batch_job.output_file_id 239 | result = client.files.content(result_file_id).text 240 | 241 | time.sleep(10) 242 | 243 | with open(batch_output_path, "w") as wf: 244 | wf.write(result) 245 | 246 | return batch_output_path 247 | 248 | def parse_score_for_summarizing(batch_output_path): 249 | 250 | batch_outputs = [] 251 | with open(batch_output_path) as rf: 252 | for line in rf: 253 | batch_outputs.append(json.loads(line)) 254 | 255 | samples = dict() 256 | for batch_output in batch_outputs: 257 | custom_id = batch_output["custom_id"] # {domain}_{filename}_{section}_{criteria} 258 | domain = custom_id.split("_")[0] 259 | section_format_text = "_ 5: 287 | continue 288 | 289 | logprob = tokens.get('logprob', float('-inf')) 290 | prob = np.exp(logprob) 291 | scores_dict[score] += prob 292 | 293 | for score, prob in scores_dict.items(): 294 | samples[sample_id][f'weighted_{criteria}'] += score * prob 295 | 296 | samples[sample_id]['count'] += 1 297 | 298 | # Average scores 299 | for sample in samples: 300 | samples[sample]['count'] /= 3 301 | 302 | for score in 
samples[sample]: 303 | samples[sample][score] /= samples[sample]['count'] 304 | 305 | samples[sample]['weighted'] = sum(samples[sample][feature] for feature in samples[sample] if 'weighted' in feature) / 3 306 | samples[sample]['top'] = sum(samples[sample][feature] for feature in samples[sample] if 'top' in feature) / 3 307 | 308 | return samples 309 | 310 | def calculate_score(task, user_msg, prediction, answer): 311 | 312 | result_dict = dict() 313 | result_dict["prediction"] = prediction 314 | result_dict["answer"] = answer 315 | 316 | if task == "Recalling": 317 | if prediction == "FAILED": 318 | result_dict["precision"], result_dict["recall"], result_dict["f1_score"] = 0, 0, 0 319 | else: 320 | result_dict["precision"], result_dict["recall"], result_dict["f1_score"] = calculate_f1_score(prediction, answer) 321 | score = result_dict["f1_score"] 322 | elif task == "Summarizing": 323 | input_sections_or_segments = re.search("### Context:\n(.+?)\n\nNow, respond to the instruction", user_msg, re.DOTALL).group(1) 324 | result_dict["input_sections"] = input_sections_or_segments 325 | score = 0 # score will be calculated separately 326 | elif task == "Organizing": 327 | if prediction == "FAILED": 328 | result_dict["lcs"], result_dict["lcs_score"] = 0 329 | else: 330 | pred_in_list = re.findall("\d+", prediction) 331 | answer_in_list = re.findall("\d+", answer) 332 | result_dict["lcs"], result_dict["lcs_score"] = calculate_lcs(pred_in_list, answer_in_list) 333 | score = result_dict["lcs_score"] 334 | elif task == "Attributing": 335 | if prediction == "FAILED": 336 | result_dict["precision"], result_dict["recall"], result_dict["f1_score"] = 0, 0, 0 337 | else: 338 | match = re.search(r"(Related Segments|Core IDs):\s*(.+)", prediction) 339 | if match: # model has followed format instruction 340 | target_span = match.group(2) 341 | else: 342 | target_span = prediction 343 | 344 | pred_numbers = ", ".join(set(re.findall("\d+", target_span))) 345 | answer_numbers = [re.search("\d+", ans).group() if re.search("\d+", ans) else "None" for ans in answer] 346 | 347 | if pred_numbers == []: 348 | pred_numbers = "None" 349 | 350 | result_dict["precision"], result_dict["recall"], result_dict["f1_score"] = calculate_f1_score(pred_numbers, answer_numbers) 351 | score = result_dict["f1_score"] 352 | 353 | return result_dict, score 354 | 355 | def write_score_file(task, save_path): 356 | 357 | if task == "Recalling": 358 | score_per_domain = { 359 | "Books":[], 360 | "Debates":[], 361 | "Medicine":[], 362 | "Law":[] 363 | } 364 | metric = "f1_score" 365 | elif task == "Summarizing": 366 | score_per_domain = { 367 | "Books":[], 368 | "Debates":[], 369 | "Medicine":[], 370 | "Law":[] 371 | } 372 | metric = "score" 373 | elif task == "Organizing": 374 | score_per_domain = { 375 | "Books":[], 376 | "Debates":[], 377 | } 378 | metric = "lcs_score" 379 | elif task == "Attributing": 380 | score_per_domain = { 381 | "Medicine":[], 382 | "Law":[] 383 | } 384 | metric = "f1_score" 385 | 386 | scores = [] 387 | for domain in score_per_domain.keys(): 388 | pred_paths = glob.glob(os.path.join(save_path, domain, "*.json")) 389 | for pred_path in pred_paths: 390 | with open(pred_path) as rf: 391 | pred_dict = json.load(rf) 392 | score = pred_dict[metric] 393 | scores.append(score) 394 | score_per_domain[domain].append(score) 395 | 396 | # write score file (overall / per domain) 397 | avg_score = sum(scores) / len(scores) 398 | avg_score_per_domain = {key: sum(value) / len(value) for key, value in score_per_domain.items()} 
399 | 400 | with open(os.path.join(save_path, "final_score.txt"), "w") as wf: 401 | wf.write(str(avg_score)) 402 | with open(os.path.join(save_path, "domain_score.json"), "w") as wf: 403 | json.dump(avg_score_per_domain, wf) 404 | 405 | def get_model_prompts(model_name_or_path): 406 | 407 | if "gemini" in model_name_or_path: 408 | prompt = "{system_msg}\n\n{user_msg}" 409 | elif "Llama-3.1" in model_name_or_path: 410 | prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_msg}<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>" 411 | elif "Phi" in model_name_or_path: 412 | prompt = "<|system|>\n{system_msg}<|end|>\n<|user|>\n{user_msg}<|end|>\n<|assistant|>" 413 | else: # gpt, qwen, glm receives "messages" list as input 414 | prompt = "" 415 | return prompt --------------------------------------------------------------------------------
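Note on split runs: when inference is restricted with `under_32k_only`, `over_32k_only`, or `domain`, inference.py intentionally skips writing the score files, so the two Qwen2.5 runs described in the README leave only per-sample JSONs under `results/`. The sketch below shows one way to aggregate them afterwards with run_write_score.py; it is only an illustration and assumes run.sh has been edited to select `Qwen/Qwen2.5-7B-Instruct` and the `Recalling` task, and that inference.py's default `--save_path` (`./results`) is used.

```shell
# Two partial runs for Qwen2.5 (task and model_name_or_path are set inside run.sh;
# Qwen/Qwen2.5-7B-Instruct and Recalling are assumed here for illustration):
use_yarn=False under_32k_only=True bash run.sh   # samples of at most 32,768 tokens, without YaRN
use_yarn=True over_32k_only=True bash run.sh     # samples above 32,768 tokens, with YaRN

# Aggregate the per-sample JSONs into final_score.txt and domain_score.json.
# The path follows inference.py's default --save_path layout: results/<model_name>/<task>.
python run_write_score.py Recalling results/Qwen2.5-7B-Instruct/Recalling
```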