├── .gitignore ├── Main_Fig_Horizontal.png ├── Main_Figure.png ├── Main_Figure_Horizontal.png ├── README.md ├── dataset.py ├── datasets_classes ├── __init__.py ├── base.py ├── dataset_loader.py └── qa │ ├── MuSiQue.py │ ├── __init__.py │ └── __pycache__ │ ├── MuSiQue.cpython-310.pyc │ └── __init__.cpython-310.pyc ├── evaluation_classes ├── MusiQue_metrics │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── answer.cpython-310.pyc │ │ └── metric.cpython-310.pyc │ ├── answer.py │ ├── group.py │ ├── group_answer_sufficiency.py │ ├── group_support_sufficiency.py │ ├── metric.py │ └── support.py ├── __init__.py ├── coreference.py ├── eval_base_class.py ├── fuse_reviews.py ├── question_answering.py └── summarization.py ├── files ├── configuration │ ├── predict.json │ └── predict_c.json └── prompts │ ├── ECB.json │ ├── FuseReviews.json │ ├── MultiNews.json │ ├── MusiQue.json │ ├── OpenASP.json │ └── SciCo.json ├── model_wrappers ├── __init__.py ├── hf_pipline_wrap.py └── model_wrap.py ├── requirements.txt └── scripts ├── Lama_models.py ├── __init__.py ├── check_pkl.py ├── compare_results_jsons.py ├── create_various_sets.py ├── eval_all.sh ├── evaluate_dataset.py ├── evaluation ├── ASD.py ├── __init__.py ├── cal_f1.py ├── cal_f1_fittokens.py ├── cal_f1_re.py ├── compare_models.py ├── sample_length_function.py └── statistical_analysis.py ├── generate_benchmark.py ├── graph.py ├── results_processing └── . … ├── run_all.sh └── run_model_predictions.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/"__pycache__/" 2 | __pycache__/ 3 | "__pycache__/" 4 | -------------------------------------------------------------------------------- /Main_Fig_Horizontal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/Main_Fig_Horizontal.png -------------------------------------------------------------------------------- /Main_Figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/Main_Figure.png -------------------------------------------------------------------------------- /Main_Figure_Horizontal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/Main_Figure_Horizontal.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

More Documents, Same Length:
Isolating the Challenge of Multiple Documents in RAG

3 |

[Link to arXiv](https://arxiv.org/abs/2503.04388)

4 |
5 | 6 | This repository contains the code and datasets for our paper on the effect of document multiplicity in Retrieval-Augmented Generation (RAG) systems when the context size is held fixed. 7 | For detailed methodology, experiments, and analysis, please refer to the full paper 📰 8 | 9 | ## :bulb: High-Level Conclusions 10 | Our results show that adding more retrieved documents can hurt performance (up to a 10% drop in fixed-context setups), making document-rich retrieval tasks harder. 11 | Llama-3.1 and Gemma-2 declined, Qwen-2 stayed steady, and smaller LLMs (7–9B) followed the trend less strongly. This suggests that systems need to balance relevance and variety to reduce conflicts between retrieved documents, and that future models might improve by filtering out contradictory details while still drawing on the full range of documents. 12 | 13 | ## 🔬 Our Methodology:
15 | Starting with a Wikipedia-derived dataset, we created different sets that contain the same number of tokens but fewer documents, by adjusting the length of the key documents for each question. 16 | Our sets use the same multi-hop questions and supporting documents with key information (pink), while varying the distractor documents (blue). 17 | We began with 20 documents, then omitted redundant ones while lengthening the remaining ones to match the original total size.
19 | *(Figure: illustration of the dataset construction, with supporting documents containing key information in pink and distractor documents in blue.)*
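To make the set construction concrete, here is a minimal conceptual sketch of the idea. It is not the actual implementation in [`scripts/create_various_sets.py`](scripts/create_various_sets.py); the tokenizer choice and the `extend_doc` hook (e.g. pulling additional text from the source Wikipedia article) are assumptions used only for illustration.

```python
from transformers import AutoTokenizer

# Illustrative only: measures the token budget of a document set.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

def num_tokens(text: str) -> int:
    return len(tokenizer.encode(text))

def build_variant(supporting_docs, distractor_docs, target_num_docs, extend_doc):
    """Sketch: keep every supporting (key) document, keep only as many
    distractors as needed to reach `target_num_docs`, then lengthen the kept
    documents via the caller-supplied `extend_doc` until the variant holds
    roughly the same number of tokens as the original 20-document set."""
    budget = sum(num_tokens(d) for d in supporting_docs + distractor_docs)
    num_kept_distractors = max(0, target_num_docs - len(supporting_docs))
    kept = list(supporting_docs) + list(distractor_docs[:num_kept_distractors])

    i, prev_total = 0, -1
    total = sum(num_tokens(d) for d in kept)
    while total < budget and total != prev_total:  # stop if no more text can be added
        kept[i] = extend_doc(kept[i])              # lengthen one kept document
        i = (i + 1) % len(kept)
        prev_total, total = total, sum(num_tokens(d) for d in kept)
    return kept
```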
30 | 31 | 32 | ## :desktop_computer: Reproduction Instructions: 33 | 34 | ### Download the different benchmark datasets 35 | Our custom benchmark datasets include a control set, the original dataset, and variants with replaced distractors for varying document multiplicity. 36 | You can download them from [here](https://drive.google.com/file/d/1z6L0Xl0zhRoOOpwD5WuQI9ukSaEgCraM/view?usp=drive_link) or from [Hugging Face](https://huggingface.co/datasets/Shahar6000/MoreDocsSameLen). 37 | 38 | Alternatively, regenerate them using [`scripts/create_various_sets.py`](scripts/create_various_sets.py). 39 | 40 | ### Prepare the environment 41 | 42 | To set up the running environment, run the following commands: 43 | ``` 44 | gh repo clone shaharl6000/MoreDocsSameLen 45 | cd MoreDocsSameLen 46 | export PYTHONPATH=./ 47 | python3.11 -m venv venv 48 | source venv/bin/activate 49 | pip install -r requirements.txt 50 | 51 | ``` 52 | 53 | ### Run predictions 54 | To run inference on a chosen benchmark dataset, define a config file for it under the configuration folder, e.g. [`files/configuration/predict.json`](files/configuration/predict.json). 55 | 56 | The `predict.json` file contains the path to the benchmark generated in the previous step, the batch size, and the decoding temperature for the LLMs. 57 | 58 | We supply two options for running the code: with small models (the code runs locally) and with large models (the code runs on the Together platform). 59 | 60 | To run predictions with the small models, run the following command: 61 | ```bash 62 | python scripts/run_model_predictions.py --config <path_to_config> --model_name <model_name> 63 | ``` 64 | 65 | For the large models, add [`together_api_key.py`](together_api_key.py) under the root path and define in it: API_KEY = XXXXX 66 | 67 | Then run the following command: 68 | 69 | ```bash 70 | python scripts/run_model_predictions.py --config <path_to_config> --model_name <model_name> --run_together 71 | ``` 72 | 73 | ### Evaluate the predictions 74 | 75 | To evaluate the predictions, you can use [`scripts/evaluate_dataset.py`](scripts/evaluate_dataset.py) by providing 76 | the path to the predictions from the previous step and an output path where all results will be saved.
77 | 78 | ```bash 79 | python scripts/evaluate_dataset.py --predictions_dir --output_path --ds_name MusiQue 80 | ``` 81 | 82 | ## :newspaper: Citation 83 | 84 | If you use this code or the datasets in your research, please cite: 85 | 86 | ``` 87 | @misc{levy2025documentslengthisolatingchallenge, 88 | title={More Documents, Same Length: Isolating the Challenge of Multiple Documents in RAG}, 89 | author={Shahar Levy and Nir Mazor and Lihi Shalmon and Michael Hassid and Gabriel Stanovsky}, 90 | year={2025}, 91 | eprint={2503.04388}, 92 | archivePrefix={arXiv}, 93 | primaryClass={cs.CL}, 94 | url={https://arxiv.org/abs/2503.04388}, 95 | } 96 | ``` 97 | 98 | 99 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/dataset.py -------------------------------------------------------------------------------- /datasets_classes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/datasets_classes/__init__.py -------------------------------------------------------------------------------- /datasets_classes/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import pandas as pd 4 | from datasets import load_dataset 5 | import json 6 | from transformers import AutoTokenizer 7 | import numpy as np 8 | import random 9 | from together import Together 10 | from tqdm import tqdm 11 | import time 12 | from datasets_classes.dataset_loader import MyDataset 13 | from torch.utils.data import Dataset as TorchDataset, DataLoader 14 | 15 | 16 | def collate_fn( batch): 17 | """ 18 | Custom collate function to handle batches of dictionaries. 19 | 20 | Args: 21 | batch: List of dictionaries. Each element in the batch is a dict. 22 | 23 | Returns: 24 | A list of dictionaries, with no concatenation of the values. 25 | """ 26 | return batch 27 | 28 | class Dataset: 29 | common_data = {"num_demos": 3, "max_num_samples": -1, "random": random.Random(42)} 30 | 31 | @staticmethod 32 | def update_common_data(key, value): 33 | """ A static method to update the common data dictionary. 
""" 34 | Dataset.common_data[key] = value 35 | 36 | def __init__(self, name, dir_path, split_name): 37 | self.dir_path = dir_path 38 | self.name = name 39 | self.all_prompts = self._get_prompts_from_json() 40 | self.cur_prompt = None 41 | self.all_data = self.load() 42 | self.max_num_docs = self.get_max_num_docs() 43 | self.shuffled_doc_ids = self.common_data["random"].sample(range(self.max_num_docs), k=self.max_num_docs) 44 | self.all_samples = self.pre_process() 45 | self._remove_long_samples(lim=-1) 46 | self.split_name = split_name 47 | if self.common_data["max_num_samples"] > -1: 48 | self._random_sampling() 49 | 50 | def get_max_num_docs(self): 51 | return max([self.all_data[k]["documents"].apply(len).max() for k in self.all_data]) 52 | 53 | def _random_sampling(self): 54 | max_num_samples = self.common_data["max_num_samples"] 55 | if max_num_samples == -1: 56 | return 57 | for k in self.all_samples: 58 | if len(self.all_samples[k]) <= max_num_samples: 59 | continue 60 | random_generator = self.common_data["random"] 61 | self.all_samples[k] = random_generator.sample(self.all_samples[k], k=max_num_samples) 62 | 63 | def _remove_long_samples(self, lim): 64 | if lim == -1: 65 | return 66 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B") 67 | for k in self.all_samples: 68 | encoded_batch = [tokenizer.encode(sample['final_msgs'][1]['content']) for sample in self.all_samples[k]] 69 | lengths = [len(encoded) for encoded in encoded_batch] 70 | print(f"Max length in {k}: {max(lengths)}") 71 | print(f"Min length in {k}: {min(lengths)}") 72 | print(f"Mean length in {k}: {np.mean(lengths)}") 73 | longer_than_limit = [i for i, l in enumerate(lengths) if l > lim] 74 | if len(longer_than_limit) > 0: 75 | print(f"Removing {len(longer_than_limit)} out of {len(self.all_samples[k])} samples that are longer than {lim} tokens") 76 | self.all_samples[k] = [sample for i, sample in enumerate(self.all_samples[k]) if i not in longer_than_limit] 77 | print(f"Number of samples in {k}: {len(self.all_samples[k])}") 78 | 79 | def load(self): 80 | """ read the dataset files and return the data in a dictionary """ 81 | raise NotImplementedError("This method needs to be implemented by subclasses") 82 | 83 | def pre_process(self): 84 | """ pre-process the data and return a dictionary of samples that expects 'final_msgs' and 'target' keys """ 85 | raise NotImplementedError("This method needs to be implemented by subclasses") 86 | 87 | def _get_prompts_from_json(self): 88 | json_path = os.path.join("../files/prompts", f"{self.name}.json") 89 | with open(json_path, 'r', encoding='utf-8') as file: 90 | prompts = json.load(file) 91 | return prompts 92 | 93 | def get_prompt(self): 94 | random_generator = self.common_data['random'] 95 | if self.cur_prompt is None: 96 | random_prompt = random_generator.choice(self.all_prompts['prompts']) 97 | selected_demonstrations = random_generator.sample(self.all_prompts['demonstrations'], k=self.common_data['num_demos']) # no replacements sampling 98 | processed_demos = [] 99 | for i, sample in enumerate(selected_demonstrations, start=1): 100 | documents = sample.pop("documents") 101 | demonstration = self.get_sample(random_prompt["instructions"], documents, **sample) 102 | processed_demos.append(demonstration) 103 | random_prompt["few_shots"] = "\n\n".join(processed_demos) 104 | self.cur_prompt = random_prompt 105 | return self.cur_prompt 106 | 107 | def get_shuffled_documents(self, documents): 108 | tmp_documents = [None] * self.max_num_docs 109 | 
tmp_documents[:len(documents)] = documents 110 | shuffled = np.array(tmp_documents)[self.shuffled_doc_ids] 111 | new_documents = shuffled[~pd.isnull(shuffled)].tolist() 112 | return new_documents 113 | 114 | def get_sample(self, instructions, documents, **kwargs): 115 | documents = self.get_shuffled_documents(documents) 116 | doc_strings = "\n".join([f"Document {j}: ```{doc}```" for j, doc in enumerate(documents, start=1)]) 117 | target = kwargs.get("target", "") 118 | sample = (f"*Instructions*: {instructions}\n" 119 | f"*The documents*:\n{doc_strings}\n" 120 | f"*Answer*: {target}") 121 | return sample 122 | 123 | 124 | def predict(self, model, out_path, num_truncation_tokens): 125 | split_name = self.split_name 126 | all_msgs = [sample['final_msgs'] for sample in self.all_samples[split_name]] 127 | all_responses = model.batch(all_msgs, num_truncation_tokens) 128 | 129 | # all_responses = [] 130 | # print(len(all_msgs)) 131 | # if len(all_msgs) == 200: 132 | # jumps = 1 133 | # else: 134 | # jumps = 1 135 | # for i in tqdm(range(0, len(all_msgs), jumps)): 136 | # end = min((i + jumps), len(all_msgs)) 137 | # responses = model.batch(all_msgs[i:end], num_truncation_tokens) 138 | # all_responses = all_responses + responses 139 | # def predict(self, model, out_path, num_truncation_tokens): 140 | # split_name = self.split_name 141 | # all_msgs = [sample['final_msgs'] for sample in self.all_samples[split_name]] 142 | # print("len_msg", len(all_msgs)) 143 | # print("len_msg", type(all_msgs)) 144 | # print("len_msg", all_msgs[0][0].keys()) 145 | # dataset = MyDataset(all_msgs) 146 | # if len(all_msgs) == 200: 147 | # batch_size = 1 148 | # else: 149 | # batch_size = 2 150 | # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8, collate_fn = collate_fn) 151 | # 152 | # 153 | # all_responses = [] 154 | # print(len(all_msgs)) 155 | # 156 | # for samples in tqdm(dataloader): 157 | # # end = min((i + jumps), len(all_msgs)) 158 | # # print("len_sampels",len(samples)) 159 | # # print("len_sampels",type(samples)) 160 | # # print("len_sampels",samples[0][0].keys()) 161 | # # print(samples[0]) 162 | # responses = model.batch(samples, num_truncation_tokens) 163 | # all_responses = all_responses + responses 164 | 165 | 166 | # responses = model.batch(all_msgs[2368:], num_truncation_tokens) 167 | # all_responses = all_responses + responses 168 | 169 | 170 | 171 | # all_responses = [] 172 | # for i in tqdm(range(0, 5, 1)): 173 | # responses = model.batch(all_msgs[i:i + 1], num_truncation_tokens) 174 | # all_responses = all_responses + responses 175 | # responses = model.batch(all_msgs[5:10], num_truncation_tokens) 176 | # all_responses = all_responses + responses 177 | for i, response in enumerate(all_responses): 178 | self.all_samples[split_name][i]['prediction'] = response 179 | 180 | out_df = pd.DataFrame(self.all_samples[split_name]) 181 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 182 | out_df.to_json(out_path, orient='records', indent=2) 183 | return out_df 184 | 185 | 186 | def predict_togehter(self, model_name, out_path, num_truncation_tokens): 187 | split_name = self.split_name 188 | all_msgs = [sample['final_msgs'] for sample in self.all_samples[split_name]] 189 | client = Together() 190 | 191 | # all_responses = model.batch(all_msgs, num_truncation_tokens) 192 | all_responses = [] 193 | print(len(all_msgs)) 194 | end = 0 195 | counter = 0 196 | flattened_all_msgs = [d for sublist in all_msgs for d in sublist] 197 | if len(all_msgs) == 200: 198 | jumps = 1 
199 | else: 200 | jumps = 1 201 | start_time = time.time() 202 | counter = 0 203 | for i in tqdm(range(0, len(all_msgs), jumps)): 204 | # if counter > 5: 205 | # break 206 | # else: 207 | # counter += 1 208 | end = min((i + jumps), len(all_msgs)) 209 | stream = client.chat.completions.create( 210 | model=model_name, 211 | temperature=0.8, 212 | max_tokens=512, 213 | # model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", 214 | messages=flattened_all_msgs[i:end], 215 | stream=True, 216 | ) 217 | answer = '' 218 | for chunk in stream: 219 | answer = answer + chunk.choices[0].delta.content 220 | 221 | # print(answer) 222 | all_responses = all_responses + [answer] 223 | 224 | for i, response in enumerate(all_responses): 225 | self.all_samples[split_name][i]['prediction'] = response 226 | 227 | out_df = pd.DataFrame(self.all_samples[split_name]) 228 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 229 | out_df.to_json(out_path, orient='records', indent=2) 230 | return out_df 231 | 232 | 233 | def get_sample2msg(self, src_docs, **kwargs): 234 | prompt = self.get_prompt() 235 | sample_content = self.get_sample(prompt["instructions"], src_docs, **kwargs) 236 | demonstrations = prompt["few_shots"] 237 | user_message = f"{demonstrations}\n\n{sample_content}" 238 | messages = [{"role": "user", "content": user_message}] 239 | return messages 240 | 241 | def _read_files(self, suffix): 242 | f_names = os.path.join(self.dir_path, f'*.{suffix}') 243 | data_file_paths = { 244 | os.path.basename(fname).split('.')[0]: fname 245 | for fname in glob(f_names) 246 | } 247 | data_files_text = {} 248 | for file_name in data_file_paths: 249 | with open(data_file_paths[file_name], 'r') as f: 250 | data_files_text[file_name] = f.read() 251 | return data_files_text, data_file_paths 252 | 253 | @staticmethod 254 | def _load_from_hf(dataset_name, **kwargs): 255 | data_from_hf = load_dataset(dataset_name, **kwargs) 256 | data_dfs = {} 257 | for category, data in data_from_hf.items(): 258 | data_dfs[category] = pd.DataFrame(data) 259 | return data_dfs 260 | -------------------------------------------------------------------------------- /datasets_classes/dataset_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | class MyDataset(Dataset): 4 | def __init__(self, samples): 5 | """ 6 | A custom PyTorch Dataset to handle the samples and tokenize them. 7 | """ 8 | self.samples = samples 9 | def __len__(self): 10 | return len(self.samples) 11 | 12 | def __getitem__(self, idx): 13 | """ 14 | Tokenize and return a single sample with truncation. 
15 | """ 16 | sample = self.samples[idx] 17 | return sample # return the tokens and the original sample for future processing 18 | 19 | -------------------------------------------------------------------------------- /datasets_classes/qa/MuSiQue.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pandas as pd 4 | from datasets_classes.base import Dataset 5 | 6 | 7 | class MusiQue(Dataset): 8 | 9 | def get_max_num_docs(self): 10 | return 20 11 | 12 | def get_sample(self, instructions, documents, **kwargs): 13 | documents = self.get_shuffled_documents(documents) 14 | question = kwargs.get("question", "") 15 | doc_strings = "\n".join([f"Document {j}: ```{doc}```" for j, doc in enumerate(documents, start=1)]) 16 | output_format = kwargs.get("output_format", "") 17 | target = kwargs.get("answer", "") 18 | sample = (f"*Instructions*:{instructions} {output_format}\n" 19 | f"*Question*: {question}? " 20 | f"*The documents*:\n{doc_strings}\n" 21 | f"*Answer*: {target}") 22 | return sample 23 | 24 | def load(self): 25 | data_files_text, data_file_paths = self._read_files('jsonl') 26 | df_musique = {} 27 | for split_name, split_data in data_files_text.items(): 28 | lines = split_data.strip().split('\n') 29 | df_musique[split_name] = pd.DataFrame([json.loads(line) for line in lines]) 30 | return df_musique 31 | 32 | def pre_process(self): 33 | all_samples = {} 34 | for split_name, split_data in self.all_data.items(): 35 | results_data = [] 36 | for sample_id, row in split_data.iterrows(): 37 | src_docs = [p['paragraph_text'] for p in row['paragraphs']] 38 | msgs = self.get_sample2msg(src_docs, question=row['question']) 39 | answer = row.get('answer', '') 40 | answerable = row.get('answerable', '') 41 | target = {"is_answerable": answerable, "answer_content": answer} 42 | results_data.append({ 43 | 'id': row['id'], 44 | 'final_msgs': msgs, 45 | 'target': target 46 | }) 47 | all_samples[split_name] = results_data 48 | return all_samples 49 | -------------------------------------------------------------------------------- /datasets_classes/qa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/datasets_classes/qa/__init__.py -------------------------------------------------------------------------------- /datasets_classes/qa/__pycache__/MuSiQue.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/datasets_classes/qa/__pycache__/MuSiQue.cpython-310.pyc -------------------------------------------------------------------------------- /datasets_classes/qa/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/datasets_classes/qa/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation_classes/MusiQue_metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/evaluation_classes/MusiQue_metrics/__init__.py 
-------------------------------------------------------------------------------- /evaluation_classes/MusiQue_metrics/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/evaluation_classes/MusiQue_metrics/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation_classes/MusiQue_metrics/__pycache__/answer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/evaluation_classes/MusiQue_metrics/__pycache__/answer.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation_classes/MusiQue_metrics/__pycache__/metric.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/evaluation_classes/MusiQue_metrics/__pycache__/metric.cpython-310.pyc -------------------------------------------------------------------------------- /evaluation_classes/MusiQue_metrics/answer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Answer metric -- mostly taken directly from squad_tools of allennlp. 3 | """ 4 | import re 5 | import string 6 | import collections 7 | from typing import Tuple, List 8 | import ast 9 | from evaluation_classes.MusiQue_metrics.metric import Metric 10 | import json 11 | 12 | def normalize_answer(s): 13 | """Lower text and remove punctuation, articles and extra whitespace.""" 14 | 15 | def remove_articles(text): 16 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 17 | return re.sub(regex, " ", text) 18 | 19 | def white_space_fix(text): 20 | return " ".join(text.split()) 21 | 22 | def remove_punc(text): 23 | exclude = set(string.punctuation) 24 | return "".join(ch for ch in text if ch not in exclude) 25 | 26 | def lower(text): 27 | return text.lower() 28 | 29 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 30 | 31 | 32 | def get_tokens(s): 33 | if not s: 34 | return [] 35 | return normalize_answer(s).split() 36 | 37 | 38 | def compute_exact(a_gold, a_pred): 39 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 40 | 41 | 42 | def compute_f1(a_gold, a_pred): 43 | gold_toks = get_tokens(a_gold) 44 | pred_toks = get_tokens(a_pred) 45 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 46 | num_same = sum(common.values()) 47 | if len(gold_toks) == 0 or len(pred_toks) == 0: 48 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 49 | return int(gold_toks == pred_toks) 50 | if num_same == 0: 51 | return 0 52 | precision = 1.0 * num_same / len(pred_toks) 53 | recall = 1.0 * num_same / len(gold_toks) 54 | f1 = (2 * precision * recall) / (precision + recall) 55 | return f1 56 | 57 | def extract_content_from_gt(input_string): 58 | input_dict = ast.literal_eval(input_string) 59 | 60 | # Extracting the value of 'answer_content' 61 | answer_content = input_dict['answer_content'] 62 | 63 | return answer_content 64 | def extract_content_pred(input_string): 65 | match = re.search(r"'answer_content':\s*'(.*?)'", input_string) 66 | 67 | if match: 68 | answer_content = match.group(1) 69 | else: 70 | answer_content = 
input_string 71 | 72 | return answer_content 73 | def extract_format(input_string): 74 | input_string = re.sub(r"'is_answerable':\s*\w+,\s*", "", input_string) 75 | 76 | # Using regular expression to remove 'answer_content' and its value 77 | input_string = re.sub(r"'answer_content':\s*'[^']*'", "", input_string) 78 | 79 | # Removing any leading or trailing commas and whitespace 80 | cleaned_string = input_string.strip().strip(',') 81 | 82 | return cleaned_string 83 | 84 | 85 | def compute_f1_with_content(a_gold, a_pred): 86 | # a = {'is_answerable': True, 'answer_content': 'Arna Selznick'} 87 | a_gold_content = a_gold 88 | a_pred_content = a_pred 89 | try: 90 | temp_gt = a_gold 91 | temp_pred = a_pred 92 | 93 | temp_gt = temp_gt.replace("'", '"').replace("True", "true").replace("False", "false") 94 | temp_pred = temp_pred.replace("'", '"').replace("True", "true").replace("False", "false") 95 | 96 | temp_gt = json.loads(temp_gt) 97 | temp_pred = json.loads(temp_pred) 98 | 99 | a_gold_content = str(temp_gt["answer_content"]) 100 | a_pred_content = str(temp_pred["answer_content"]) 101 | except: 102 | a_gold_content = a_gold_content.replace("{", "").replace("is_answerable","").replace('answer_content',"").replace('answer_content',"").replace('True',"").replace(',', "") 103 | a_gold_content = a_gold_content.replace('False',"").replace('}',"").replace('answer_content',"").replace(':',"") 104 | 105 | a_pred_content = a_pred_content.replace("{", "").replace("is_answerable", "").replace('answer_content', "").replace( 106 | 'answer_content', "").replace('True', "") 107 | a_pred_content = a_pred_content.replace('False', "").replace('}', "").replace('answer_content', "").replace(':', "").replace(',', "") 108 | 109 | a_gold_content = " ".join([word for word in a_gold_content.split() if word != "''"]) 110 | a_pred_content = " ".join([word for word in a_pred_content.split() if word != "''"]) 111 | 112 | gold_toks = get_tokens(a_gold_content) 113 | pred_toks = get_tokens(a_pred_content) 114 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 115 | num_same = sum(common.values()) 116 | if len(gold_toks) == 0 or len(pred_toks) == 0: 117 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 118 | return int(gold_toks == pred_toks) 119 | if num_same == 0: 120 | return 0 121 | precision = 1.0 * num_same / len(pred_toks) 122 | recall = 1.0 * num_same / len(gold_toks) 123 | f1 = (2 * precision * recall) / (precision + recall) 124 | return f1 125 | 126 | 127 | def compute_f1_with_format(a_gold, a_pred): 128 | # a = {'is_answerable': True, 'answer_content': 'Arna Selznick'} 129 | content_gt = a_gold.replace("{", "").replace("is_answerable", "").replace('answer_content', "").replace( 130 | 'answer_content', "").replace('True', "") 131 | content_gt = content_gt.replace('False', "").replace('}', "").replace('answer_content', "").replace(':', "") 132 | 133 | content_pred = a_pred.replace("{", "").replace("is_answerable", "").replace('answer_content', "").replace( 134 | 'answer_content', "").replace('True', "") 135 | content_pred = content_pred.replace('False', "").replace('}', "").replace('answer_content', "").replace(':', "") 136 | 137 | content_pred = " ".join([word for word in content_pred.split() if word != "''"]) 138 | content_gt = " ".join([word for word in content_gt.split() if word != "''"]) 139 | 140 | a_gold_u = " ".join([word for word in a_gold.split() if word not in content_gt.split()]) 141 | a_pred_u = " ".join([word for word in a_pred.split() if word not in 
content_pred.split()]) 142 | 143 | for www in content_gt.split(" "): 144 | a_gold_u = a_gold_u.replace(www, "") 145 | 146 | for www in content_pred.split(" "): 147 | a_pred_u = a_pred_u.replace(www, "") 148 | 149 | a_gold_u = a_gold_u.replace('True', "").replace('False', "") 150 | a_pred_u = a_pred_u.replace('True', "").replace('False', "") 151 | 152 | a_gold_u = " ".join([word for word in a_gold_u.split() if word != "''"]) 153 | a_pred_u = " ".join([word for word in a_pred_u.split() if word != "''"]) 154 | 155 | gold_toks = get_tokens(a_gold_u) 156 | pred_toks = get_tokens(a_pred_u) 157 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 158 | num_same = sum(common.values()) 159 | if len(gold_toks) == 0 or len(pred_toks) == 0: 160 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 161 | return int(gold_toks == pred_toks) 162 | if num_same == 0: 163 | return 0 164 | precision = 1.0 * num_same / len(pred_toks) 165 | recall = 1.0 * num_same / len(gold_toks) 166 | f1 = (2 * precision * recall) / (precision + recall) 167 | return f1 168 | 169 | 170 | def compute_f1_is_answerable(a_gold, a_pred): 171 | # a = {'is_answerable': True, 'answer_content': 'Arna Selznick'} 172 | content_gt = a_gold.replace("{", "").replace("is_answerable", "").replace('answer_content', "").replace( 173 | 'answer_content', "").replace('True', "") 174 | content_gt = content_gt.replace('False', "").replace('}', "").replace('answer_content', "").replace(':', "") 175 | 176 | content_pred = a_pred.replace("{", "").replace("is_answerable", "").replace('answer_content', "").replace( 177 | 'answer_content', "").replace('True', "") 178 | content_pred = content_pred.replace('False', "").replace('}', "").replace('answer_content', "").replace(':', "") 179 | 180 | content_gt = " ".join([word for word in content_gt.split() if word != "''"]) 181 | content_pred = " ".join([word for word in content_pred.split() if word != "''"]) 182 | 183 | a_gold_u = " ".join([word for word in a_gold.split() if word not in content_gt.split()]) 184 | a_pred_u = " ".join([word for word in a_pred.split() if word not in content_pred.split()]) 185 | 186 | for www in content_gt.split(" "): 187 | a_gold_u = a_gold_u.replace(www, "") 188 | 189 | for www in content_pred.split(" "): 190 | a_pred_u = a_pred_u.replace(www, "") 191 | 192 | a_gold_u = a_gold_u.replace("{", "").replace("is_answerable", "").replace( 193 | 'answer_content', "").replace('}', "").replace(':', "").replace(',', "") 194 | a_pred_u = a_pred_u.replace("{", "").replace("is_answerable", "").replace( 195 | 'answer_content', "").replace('}', "").replace(':', "").replace(',', "") 196 | 197 | a_gold_u = " ".join([word for word in a_gold_u.split() if word != "''"]) 198 | a_pred_u = " ".join([word for word in a_pred_u.split() if word != "''"]) 199 | 200 | gold_toks = get_tokens(a_gold_u) 201 | pred_toks = get_tokens(a_pred_u) 202 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 203 | num_same = sum(common.values()) 204 | if len(gold_toks) == 0 or len(pred_toks) == 0: 205 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 206 | return int(gold_toks == pred_toks) 207 | if num_same == 0: 208 | return 0 209 | precision = 1.0 * num_same / len(pred_toks) 210 | recall = 1.0 * num_same / len(gold_toks) 211 | f1 = (2 * precision * recall) / (precision + recall) 212 | return f1 213 | 214 | 215 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 216 | scores_for_ground_truths = [] 217 | for ground_truth in 
ground_truths: 218 | score = metric_fn(prediction, ground_truth) 219 | scores_for_ground_truths.append(score) 220 | return max(scores_for_ground_truths) 221 | 222 | 223 | class AnswerMetric(Metric): 224 | def __init__(self) -> None: 225 | # self._total_em = 0.0 226 | # self._total_f1 = 0.0 227 | self._count = 0 228 | self._em = [] 229 | self._f1 = [] 230 | self._f1_content = [] 231 | self._f1_format = [] 232 | self._f1_is_answerable = [] 233 | def __call__( 234 | self, 235 | predicted_answer: str, 236 | ground_truth_answers: List[str], 237 | ): 238 | 239 | exact_scores = metric_max_over_ground_truths( 240 | compute_exact, predicted_answer, ground_truth_answers 241 | ) 242 | self._em.append(exact_scores) 243 | f1_scores = metric_max_over_ground_truths( 244 | compute_f1, predicted_answer, ground_truth_answers 245 | ) 246 | self._f1.append(f1_scores) 247 | f1_scores_with_format = metric_max_over_ground_truths( 248 | compute_f1_with_format, predicted_answer, ground_truth_answers 249 | ) 250 | 251 | self._f1_format.append(f1_scores_with_format) 252 | 253 | f1_scores_with_content = metric_max_over_ground_truths( 254 | compute_f1_with_content, predicted_answer, ground_truth_answers 255 | ) 256 | self._f1_content.append(f1_scores_with_content) 257 | 258 | f1_scores_is_answerable = metric_max_over_ground_truths( 259 | compute_f1_is_answerable, predicted_answer, ground_truth_answers 260 | ) 261 | self._f1_is_answerable.append(f1_scores_is_answerable) 262 | 263 | 264 | 265 | # self._total_em += int(exact_scores) 266 | # self._total_f1 += f1_scores 267 | self._count += 1 268 | 269 | def get_metric(self, reset: bool = False) -> Tuple[float, float, list, list]: 270 | exact_match = sum(self._em) / self._count if self._count > 0 else 0 271 | f1_score = sum(self._f1) / self._count if self._count > 0 else 0 272 | f1_score_with_content = sum(self._f1_content) / self._count if self._count > 0 else 0 273 | f1_score_with_format = sum(self._f1_format) / self._count if self._count > 0 else 0 274 | f1_score_is_answerable = sum(self._f1_is_answerable) / self._count if self._count > 0 else 0 275 | if reset: 276 | self.reset() 277 | return exact_match, f1_score,f1_score_with_content, f1_score_with_format, f1_score_is_answerable, self._em, self._f1, self._f1_content, self._f1_format, self._f1_is_answerable 278 | 279 | def reset(self): 280 | # self._total_em = 0.0 281 | self._em = [] 282 | # self._total_f1 = 0.0 283 | self._f1 = [] 284 | self._count = 0 285 | 286 | 287 | if __name__ == "__main__": 288 | def compute_f1(a_gold, a_pred): 289 | gold_toks = get_tokens(a_gold) 290 | pred_toks = get_tokens(a_pred) 291 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 292 | num_same = sum(common.values()) 293 | if len(gold_toks) == 0 or len(pred_toks) == 0: 294 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 295 | return int(gold_toks == pred_toks) 296 | if num_same == 0: 297 | return 0 298 | precision = 1.0 * num_same / len(pred_toks) 299 | recall = 1.0 * num_same / len(gold_toks) 300 | f1 = (2 * precision * recall) / (precision + recall) 301 | return f1 302 | 303 | a_pred = "Tian Yunzhang" 304 | a_gold = "Tianjin" 305 | print(compute_f1(a_gold= a_gold, a_pred = a_pred)) -------------------------------------------------------------------------------- /evaluation_classes/MusiQue_metrics/group.py: -------------------------------------------------------------------------------- 1 | """ 2 | Abstract class to group MusiQue_metrics together. 
3 | """ 4 | from typing import Dict 5 | 6 | from MusiQue_metrics.metric import Metric 7 | 8 | 9 | class GroupMetric(Metric): 10 | """ 11 | Abstract class to group MusiQue_metrics together. 12 | """ 13 | 14 | def __init__(self) -> None: 15 | self.reset() 16 | 17 | def compute_question_scores(self, group) -> Dict[str, float]: 18 | raise NotImplementedError 19 | 20 | def reset(self) -> None: 21 | raise NotImplementedError 22 | 23 | def get_metric(self, reset: bool = False) -> Dict[str, float]: 24 | 25 | total_scores = {"f1": 0.0, "em": 0.0, "suff": 0.0} 26 | for question_id, question_group in self.prediction_store.items(): 27 | question_scores = self.compute_question_scores(question_group) 28 | # self.score_store[question_id] = question_scores 29 | for key, value in question_scores.items(): 30 | total_scores[key] += value 31 | dataset_scores = { 32 | name: total_score / len(self.prediction_store) 33 | if len(self.prediction_store) > 0 34 | else 0.0 35 | for name, total_score in total_scores.items() 36 | } 37 | 38 | if reset: 39 | self.reset() 40 | 41 | return dataset_scores 42 | -------------------------------------------------------------------------------- /evaluation_classes/MusiQue_metrics/group_answer_sufficiency.py: -------------------------------------------------------------------------------- 1 | """ 2 | Joint/grouped score of Answer and Sufficiency. 3 | """ 4 | from typing import List, Dict, Union 5 | from dataclasses import dataclass, field 6 | from collections import defaultdict 7 | from copy import deepcopy 8 | 9 | from MusiQue_metrics.group import GroupMetric 10 | from MusiQue_metrics.answer import AnswerMetric 11 | 12 | 13 | @dataclass 14 | class GoldPredictionInstance: 15 | gold_answers: str = None 16 | predicted_answer: str = None 17 | 18 | gold_sufficiencies: List = field(default_factory=lambda: deepcopy([])) 19 | predicted_sufficiencies: List = field(default_factory=lambda: deepcopy([])) 20 | 21 | 22 | class GroupAnswerSufficiencyMetric(GroupMetric): 23 | def __init__(self) -> None: 24 | self.prediction_store = defaultdict(GoldPredictionInstance) 25 | self.answer_metric = AnswerMetric() 26 | 27 | def compute_question_scores( 28 | self, group: GoldPredictionInstance 29 | ) -> Dict[str, float]: 30 | 31 | # Call it only when reset=True 32 | assert group.gold_answers is not None 33 | assert group.predicted_answer is not None 34 | assert len(group.predicted_sufficiencies) == 2 35 | 36 | assert isinstance(group.gold_answers, list) 37 | self.answer_metric(group.predicted_answer, group.gold_answers) 38 | ans_em, ans_f1 = self.answer_metric.get_metric(reset=True) 39 | 40 | sufficiency_score = group.predicted_sufficiencies == group.gold_sufficiencies 41 | ans_f1 = ans_f1 if sufficiency_score else 0.0 42 | ans_em = ans_em if sufficiency_score else 0.0 43 | sufficiency_score = float(sufficiency_score) 44 | 45 | question_scores = {"f1": ans_f1, "em": ans_em, "suff": sufficiency_score} 46 | return question_scores 47 | 48 | def __call__( 49 | self, 50 | predicted_answer: str, 51 | gold_answers: str, 52 | predicted_sufficiency: int, 53 | gold_sufficiency: int, 54 | question_id: Union[int, str], 55 | ) -> None: 56 | 57 | question_id = str(question_id) 58 | 59 | if gold_sufficiency == 1: 60 | self.prediction_store[question_id].predicted_answer = predicted_answer 61 | self.prediction_store[question_id].gold_answers = gold_answers 62 | 63 | self.prediction_store[question_id].predicted_sufficiencies.append( 64 | predicted_sufficiency 65 | ) 66 | 
self.prediction_store[question_id].gold_sufficiencies.append(gold_sufficiency) 67 | 68 | def reset(self): 69 | self.prediction_store = defaultdict(GoldPredictionInstance) 70 | -------------------------------------------------------------------------------- /evaluation_classes/MusiQue_metrics/group_support_sufficiency.py: -------------------------------------------------------------------------------- 1 | """ 2 | Joint/grouped score of Support and Sufficiency. 3 | """ 4 | from typing import List, Dict, Union 5 | from dataclasses import dataclass, field 6 | from collections import defaultdict 7 | from copy import deepcopy 8 | 9 | from MusiQue_metrics.group import GroupMetric 10 | from MusiQue_metrics.support import SupportMetric 11 | 12 | 13 | @dataclass 14 | class GoldPredictionInstance: 15 | gold_supporting_facts: List = field(default_factory=lambda: deepcopy([])) 16 | predicted_supporting_facts: List = field(default_factory=lambda: deepcopy([])) 17 | 18 | gold_sufficiencies: List = field(default_factory=lambda: deepcopy([])) 19 | predicted_sufficiencies: List = field(default_factory=lambda: deepcopy([])) 20 | 21 | 22 | class GroupSupportSufficiencyMetric(GroupMetric): 23 | def __init__(self) -> None: 24 | self.prediction_store = defaultdict(GoldPredictionInstance) 25 | self.support_metric = SupportMetric() 26 | 27 | def compute_question_scores( 28 | self, group: GoldPredictionInstance 29 | ) -> Dict[str, float]: 30 | 31 | # Call it only when reset=True 32 | assert group.gold_supporting_facts is not None 33 | assert group.predicted_supporting_facts is not None 34 | assert len(group.predicted_sufficiencies) == 2 35 | 36 | self.support_metric( 37 | group.predicted_supporting_facts, group.gold_supporting_facts 38 | ) 39 | sp_em, sp_f1 = self.support_metric.get_metric(reset=True) 40 | 41 | sufficiency_score = group.predicted_sufficiencies == group.gold_sufficiencies 42 | sp_f1 = sp_f1 if sufficiency_score else 0.0 43 | sp_em = sp_em if sufficiency_score else 0.0 44 | sufficiency_score = float(sufficiency_score) 45 | 46 | question_scores = {"f1": sp_f1, "em": sp_em, "suff": sufficiency_score} 47 | return question_scores 48 | 49 | def __call__( 50 | self, 51 | predicted_supporting_facts: List, 52 | gold_supporting_facts: List, 53 | predicted_sufficiency: int, 54 | gold_sufficiency: int, 55 | question_id: Union[int, str], 56 | ) -> None: 57 | 58 | question_id = str(question_id) 59 | 60 | if gold_sufficiency == 1: 61 | self.prediction_store[ 62 | question_id 63 | ].gold_supporting_facts = gold_supporting_facts 64 | self.prediction_store[ 65 | question_id 66 | ].predicted_supporting_facts = predicted_supporting_facts 67 | 68 | self.prediction_store[question_id].predicted_sufficiencies.append( 69 | predicted_sufficiency 70 | ) 71 | self.prediction_store[question_id].gold_sufficiencies.append(gold_sufficiency) 72 | 73 | def reset(self): 74 | self.prediction_store = defaultdict(GoldPredictionInstance) 75 | -------------------------------------------------------------------------------- /evaluation_classes/MusiQue_metrics/metric.py: -------------------------------------------------------------------------------- 1 | """ 2 | An abstract class representing a metric which can be accumulated. 3 | """ 4 | from typing import Any, Dict 5 | 6 | 7 | class Metric: 8 | """ 9 | An abstract class representing a metric which can be accumulated. 
10 | """ 11 | 12 | def __call__(self, predictions: Any, gold_labels: Any): 13 | raise NotImplementedError 14 | 15 | def get_metric(self, reset: bool) -> Dict[str, Any]: 16 | """ 17 | Compute and return the metric. Optionally also call `self.reset`. 18 | """ 19 | raise NotImplementedError 20 | 21 | def reset(self) -> None: 22 | """ 23 | Reset any accumulators or internal state. 24 | """ 25 | raise NotImplementedError 26 | -------------------------------------------------------------------------------- /evaluation_classes/MusiQue_metrics/support.py: -------------------------------------------------------------------------------- 1 | """ 2 | Support metric -- mostly taken directly from hotpotqa 3 | """ 4 | from typing import Tuple, List 5 | 6 | from mdqa.MusiQue_metrics.metric import Metric 7 | 8 | 9 | class SupportMetric(Metric): 10 | """ 11 | SupportMetric: Em and F1 (Similar to HotpotQA Sp metric) 12 | """ 13 | 14 | def __init__(self) -> None: 15 | self._total_em = 0.0 16 | self._total_f1 = 0.0 17 | self._total_precision = 0.0 18 | self._total_recall = 0.0 19 | self._count = 0 20 | 21 | def __call__(self, predicted_support_idxs: List[int], gold_support_idxs: List[int]): 22 | 23 | # Taken from hotpot_eval 24 | cur_sp_pred = set(map(int, predicted_support_idxs)) 25 | gold_sp_pred = set(map(int, gold_support_idxs)) 26 | tp, fp, fn = 0, 0, 0 27 | for e in cur_sp_pred: 28 | if e in gold_sp_pred: 29 | tp += 1 30 | else: 31 | fp += 1 32 | for e in gold_sp_pred: 33 | if e not in cur_sp_pred: 34 | fn += 1 35 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 36 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 37 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 38 | em = 1.0 if fp + fn == 0 else 0.0 39 | 40 | # In case everything is empty, set both f1, em to be 1.0. 41 | # Without this change, em gets 1 and f1 gets 0 42 | if not cur_sp_pred and not gold_sp_pred: 43 | f1, em = 1.0, 1.0 44 | f1, em = 1.0, 1.0 45 | 46 | self._total_em += float(em) 47 | self._total_f1 += f1 48 | self._total_precision += prec 49 | self._total_recall += recall 50 | self._count += 1 51 | 52 | def get_metric(self, reset: bool = False) -> Tuple[float, float]: 53 | """ 54 | Returns 55 | ------- 56 | Average exact match and F1 score (in that order). 
57 | """ 58 | exact_match = self._total_em / self._count if self._count > 0 else 0 59 | f1_score = self._total_f1 / self._count if self._count > 0 else 0 60 | # precision_score = self._total_precision / self._count if self._count > 0 else 0 61 | # recall_score = self._total_recall / self._count if self._count > 0 else 0 62 | 63 | if reset: 64 | self.reset() 65 | return exact_match, f1_score 66 | 67 | def reset(self): 68 | self._total_em = 0.0 69 | self._total_f1 = 0.0 70 | self._total_precision = 0.0 71 | self._total_recall = 0.0 72 | self._count = 0 73 | -------------------------------------------------------------------------------- /evaluation_classes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/evaluation_classes/__init__.py -------------------------------------------------------------------------------- /evaluation_classes/coreference.py: -------------------------------------------------------------------------------- 1 | 2 | import subprocess 3 | import numpy as np 4 | from evaluation_classes.eval_base_class import Eval 5 | import re 6 | 7 | 8 | class Coref(Eval): 9 | 10 | def __init__(self, id_key, predictions_dir, out_path): 11 | self.correlations = [] 12 | super().__init__(id_key, predictions_dir, out_path) 13 | 14 | def _postprocess(self, predictions): 15 | all_processed = {} 16 | correct_format = [] 17 | for sample in predictions: 18 | sample_id = sample[self.id_key] 19 | only_response = self.get_only_response(sample) 20 | all_lists = [] 21 | suffix_truncate = 0 22 | for it in re.finditer(r'\[[\d, ]+\]', only_response): 23 | reg = it.span() 24 | all_lists.append(eval(only_response[reg[0]:reg[1]])) 25 | suffix_truncate = max(reg[1], suffix_truncate) 26 | only_response = only_response[suffix_truncate:] 27 | last_open_list = re.search(r'\[[\d, ]+', only_response) 28 | if last_open_list is not None: 29 | reg = last_open_list.regs[0] 30 | all_lists.append(eval(only_response[reg[0]:reg[1]]+"]")) 31 | correct_format.append(len(all_lists)>0) 32 | mentions_to_clusters = {} 33 | without_singletons = [cluster for cluster in all_lists if len(cluster) > 1] 34 | for i, cluster in enumerate(without_singletons): 35 | for mention in cluster: 36 | mentions_to_clusters[mention] = i 37 | all_processed[sample_id] = mentions_to_clusters 38 | self.correlations.append([np.mean(correct_format)]) 39 | return all_processed 40 | 41 | def process_conll_out(self, out): 42 | all_scores = {} 43 | metrics = out.split("METRIC")[1:] 44 | splitted_metrics = [met.split('\n') for met in metrics] 45 | names = [re.sub('\W', '', met[0]) for met in splitted_metrics] 46 | coref_lines = [[line for line in met if 'Coreference' in line] for met in splitted_metrics] 47 | for i, line in enumerate(coref_lines): 48 | if len(line) > 1: 49 | index = splitted_metrics[i].index(line[0]) 50 | for row in splitted_metrics[i][index+1:-1]: 51 | if row.startswith("---"): 52 | continue 53 | metric_name = row[:row.find(':')] 54 | scores = row[row.find(':')+2:].split("\t") 55 | for score in scores: 56 | name = re.search(r'[a-zA-Z]+1?: ', score).group(0)[:-2] 57 | res = eval(re.search(r'\d+(\.\d+)?%', score).group(0)[:-1]) 58 | all_scores[f"{names[i]}_{metric_name}_{name}"] = res 59 | else: 60 | scores = re.sub('Coreference: ', '', line[0]).split("\t") 61 | for score in scores: 62 | name = re.search(r'[a-zA-Z]+1?: ', score).group(0)[:-2] 63 | res = eval(re.search(r'\d+(\.\d+)?%', 
score).group(0)[:-1]) 64 | all_scores[f"{names[i]}_{name}"] = res 65 | conll_f1 = np.average([all_scores[f"{name}_F1"] for name in ["muc", "bcub", "ceafe"]]) 66 | self.correlations[-1].append(conll_f1) 67 | output = {"conll_F1": conll_f1} 68 | return output 69 | 70 | def _evaluate(self, predictions, model_name, sample_index): 71 | 72 | predictions_processed = self._postprocess(predictions) 73 | conll_pred = "/tmp/predictions.conll" 74 | out_file = open(conll_pred, "w") 75 | for sample_id in predictions_processed: 76 | pred = predictions_processed[sample_id] 77 | out_file.write(f"#begin document id={sample_id}\n") 78 | for mention_id in sorted(pred): 79 | out_file.write(f"{mention_id}\t({pred[mention_id]})\n") 80 | out_file.write("#end document\n") 81 | out_file.close() 82 | conll_gold = "/tmp/gold.conll" 83 | target_file = open(conll_gold, "w") 84 | for i, sample in enumerate(predictions, start=1): 85 | target = sample["targets"] 86 | sample_id = sample[self.id_key] 87 | mention_to_cluster = {} 88 | for cluster_id, cluster in enumerate(target): 89 | for mention_id in cluster: 90 | mention_to_cluster[mention_id] = cluster_id 91 | target_file.write(f"#begin document id={sample_id}\n") 92 | for mention_id in sorted(mention_to_cluster): 93 | target_file.write(f"{mention_id}\t({mention_to_cluster[mention_id]})\n") 94 | target_file.write("#end document\n") 95 | r = subprocess.run(["perl", "/Users/gililior/research/py_repos/reference-coreference-scorers/scorer.pl", "all", conll_gold, conll_pred], capture_output=True) 96 | processed = self.process_conll_out(r.stdout.decode()) 97 | return processed 98 | 99 | 100 | -------------------------------------------------------------------------------- /evaluation_classes/eval_base_class.py: -------------------------------------------------------------------------------- 1 | 2 | import tiktoken 3 | from glob import glob 4 | import json 5 | import os 6 | MODELS_NAME_MAPPING = { 7 | "Meta-Llama-3-8B-Instruct": "Llama3-8B", 8 | "Meta-Llama-3-70B-Instruct": "Llama3-70B", 9 | "gemma-1.1-7b-it": "Gemma1.1-7B", 10 | "gemma-1.1-2b-it": "Gemma1.1-2B", 11 | "Mistral-7B-Instruct-v0.2": "Mistral-7B", 12 | "Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B", 13 | "Mixtral-8x22B-Instruct-v0.1": "Mixtral-8x22B", 14 | 'Qwen2-72B-Instruct':"Qwen/Qwen2-72B-Instruct", 15 | "Qwen2.5-72B-Instruct-Turbo": "Qwen2.5-72B-Instruct-Turbo", 16 | "gemma-2-27b-it": "gemma-2-27b-it", 17 | "gemma-2b-it": "gemma-2b-it", 18 | "Qwen/Qwen2-7B": "Qwen/Qwen2-7B", 19 | "Qwen2-7B": "Qwen2-7B", 20 | "google/gemma-2-9b": "google/gemma-2-9b", 21 | "gemma-2-9b-it": "google/gemma-2-9b-it", 22 | "Qwen2-7B-Instruct": "Qwen/Qwen2-7B-Instruct", 23 | "meta-llama/Llama-3.1-8B": "meta-llama/Llama-3.1-8B", 24 | "Meta-Llama-3.1-70B-Instruct-Turbo": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" 25 | 26 | } 27 | 28 | import numpy as np 29 | 30 | class Eval: 31 | def __init__(self, id_key, predictions_dir, out_path): 32 | self.id_key = id_key 33 | self.predictions_dir = predictions_dir 34 | self.out_path = out_path 35 | 36 | def _evaluate(self, data, model_name, sample_index): 37 | raise NotImplementedError("This method needs to be implemented by subclasses") 38 | 39 | def eval_all_in_dir(self): 40 | print("results will be saved in:", self.out_path) 41 | all_results = {} 42 | sample_lengths = {} 43 | if os.path.exists(self.out_path): 44 | with open(self.out_path, 'rt') as f: 45 | existing = json.load(f) 46 | all_results = existing["models"] 47 | sample_lengths = existing["sample_lengths"] 48 | encoding = 
tiktoken.encoding_for_model("gpt-4") 49 | for f_name in glob(f'{self.predictions_dir}/*.json'): 50 | sample_name = f_name.replace(self.predictions_dir, "").replace(".json", "") 51 | model = sample_name[:-2].split(os.sep)[-1] 52 | model_name = MODELS_NAME_MAPPING[model] 53 | sample_index = int(sample_name[-1]) 54 | print("Evaluating model:", model_name, "sample:", sample_index) 55 | if model_name in all_results and sample_index in all_results[model_name]["run_index"]: 56 | print("skipping", model_name, sample_index) 57 | continue 58 | if model_name not in all_results: 59 | all_results[model_name] = {"scores": [], "run_index": [], "ids": []} 60 | all_results[model_name]["run_index"].append(sample_index) 61 | 62 | with open(f_name, 'rt') as f: 63 | predictions = json.load(f) 64 | 65 | current_ids = [] 66 | for pred in predictions: 67 | id_sample = str(pred[self.id_key]) 68 | current_ids.append(id_sample) 69 | if id_sample not in sample_lengths: 70 | length = len(encoding.encode(pred["final_msgs"][0]['content'])) 71 | sample_lengths[id_sample] = length 72 | all_results[model_name]["ids"].append(current_ids) 73 | results = self._evaluate(predictions, model_name, sample_index) 74 | all_results[model_name]["scores"].append(results) 75 | out_dict = {"sample_lengths": sample_lengths, "models": all_results} 76 | with open(self.out_path, 'wt') as f: 77 | json.dump(out_dict, f) 78 | # print(np.corrcoef(np.array(self.correlations).T)[0,1]) 79 | 80 | @staticmethod 81 | def get_only_response(prediction): 82 | pred = prediction["prediction"] 83 | if isinstance(pred, list): 84 | pred = str(pred[1]) 85 | 86 | if pred is None: 87 | return pred 88 | if '[/INST]' in pred: 89 | pred = pred[pred.find("[/INST]")+len("[/INST]"):].strip() 90 | elif "So, the answer is:" in pred: 91 | pred = pred[pred.rfind("So, the answer is:") + len("So, the answer is:"):].strip() 92 | if pred == "": 93 | # print(prediction["prediction"]) 94 | pred = "No answer" 95 | elif "Aspect-based summary:" in pred: 96 | pred = pred[pred.rfind("Aspect-based summary:") + len("Aspect-based summary:"):].strip() 97 | elif pred.rfind("*Answer*:") != -1: 98 | pred = pred[pred.rfind("*Answer*:") + len("*Answer*:"):].strip() 99 | 100 | return pred 101 | -------------------------------------------------------------------------------- /evaluation_classes/fuse_reviews.py: -------------------------------------------------------------------------------- 1 | 2 | # import fic_evaluation 3 | from evaluation_classes.eval_base_class import Eval 4 | import torch 5 | import gc 6 | 7 | 8 | class FuseReviews(Eval): 9 | def __init__(self, id_key, predictions_dir, out_path): 10 | super().__init__(id_key, predictions_dir, out_path) 11 | 12 | @staticmethod 13 | def get_faithfulness(predictions, only_responses): 14 | faithfulness_metric = fic_evaluation.HighlightsFaithfulnessEvaluator() 15 | faithfulness_results = faithfulness_metric.evaluate( 16 | concatenated_highlights=[sample["highlights_concat"] for sample in predictions], 17 | predictions=only_responses) 18 | return faithfulness_results 19 | 20 | @staticmethod 21 | def get_coverage(predictions, only_responses): 22 | coverage_metric = fic_evaluation.HighlightsCoverageEvaluator() 23 | coverage_results = coverage_metric.evaluate( 24 | review_side_alignments=[sample["review_side_alignments"] for sample in predictions], 25 | predictions=only_responses) 26 | return coverage_results 27 | 28 | def _evaluate(self, predictions, model_name, sample_index): 29 | only_responses = [self.get_only_response(s) for s in predictions] 
30 | try: 31 | faithfulness_results = self.get_faithfulness(predictions, only_responses) 32 | except Exception as e: 33 | # print(predictions) 34 | raise e 35 | coverage_results = self.get_coverage(predictions, only_responses) 36 | out = faithfulness_results | coverage_results 37 | return out 38 | 39 | -------------------------------------------------------------------------------- /evaluation_classes/question_answering.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from evaluation_classes.eval_base_class import Eval 4 | from evaluation_classes.MusiQue_metrics.answer import AnswerMetric 5 | import numpy as np 6 | import json 7 | 8 | class QA(Eval): 9 | 10 | def __init__(self, id_key, predictions_dir, out_path): 11 | self.correlations = [] 12 | super().__init__(id_key, predictions_dir, out_path) 13 | 14 | def _evaluate(self, predictions, model_name, sample_index): 15 | answer_metric = AnswerMetric() 16 | follow_format = [] 17 | for sample in predictions: 18 | gt = sample["target"] 19 | # print("pred", sample) 20 | pred = self.postprocess(sample) 21 | follow_format.append(type(pred) is dict) 22 | 23 | # gt = gt.replace("}",", 'is_answerable': True}" ) 24 | gt["is_answerable"] = True 25 | # gt, pred = self.parse(gt, pred) 26 | answer_metric(predicted_answer=str(gt), ground_truth_answers=[str(pred)]) 27 | 28 | metric = answer_metric.get_metric() 29 | self.correlations.append([np.mean(follow_format), np.mean(metric[3])]) 30 | 31 | metrics = {"all_f1": metric[6], 32 | "all_f1_content": metric[7], 33 | "all_f1_format": metric[8], 34 | "all_f1_is_answerable": metric[9], 35 | "mean_f1": metric[1], 36 | "mean_f1_content": metric[2], 37 | "mean_f1_format": metric[3], 38 | "mean_f1_is_answerable": metric[4], 39 | } 40 | print("mean_f1", metrics["mean_f1"]) 41 | print("mean_f1_content", metrics["mean_f1_content"]) 42 | print("mean_f1_format", metrics["mean_f1_format"]) 43 | print("mean_f1_is_answerable", metrics["mean_f1_is_answerable"]) 44 | # exit() 45 | return metrics 46 | 47 | def parse(self, ground_truth_answer, predicted_answer): 48 | gt = self._extract_answer_from_dict(ground_truth_answer) 49 | if type(predicted_answer) is dict: 50 | pred = self._extract_answer_from_dict(predicted_answer) 51 | else: 52 | pred = predicted_answer 53 | return gt, pred 54 | 55 | @staticmethod 56 | def _extract_answer_from_dict(answer_dict): 57 | if answer_dict["is_answerable"]: 58 | answer = answer_dict["answer_content"] 59 | else: 60 | answer = answer_dict["is_answerable"] 61 | return answer 62 | 63 | def postprocess(self, pred): 64 | # print(pred) 65 | only_response = self.get_only_response(pred) 66 | if only_response is None: 67 | return only_response 68 | answer_dict = re.search(r'\{.*\}', only_response) 69 | if answer_dict is None: 70 | return only_response 71 | 72 | str_dict = answer_dict.group(0) 73 | str_dict = str_dict.replace("'s", "\\'s").replace("'t", "\\'t").replace("s' ", "s\\' ") 74 | str_dict = str_dict.replace("\\\\_", "_").replace("\\_", "_") 75 | try: 76 | answer = eval(str_dict) 77 | except Exception as e: 78 | try: 79 | str_dict = str_dict.replace("}", "'}") 80 | answer = eval(str_dict) 81 | except Exception as e: 82 | answer = only_response 83 | return answer 84 | 85 | 86 | -------------------------------------------------------------------------------- /evaluation_classes/summarization.py: -------------------------------------------------------------------------------- 1 | 2 | from evaluation_classes.eval_base_class import Eval 3 | 
from rouge_score import rouge_scorer 4 | 5 | 6 | class Summarization(Eval): 7 | def __init__(self, id_key, predictions_dir, out_path): 8 | self.rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True) 9 | super().__init__(id_key, predictions_dir, out_path) 10 | 11 | def _evaluate(self, predictions, model_name, sample_index): 12 | predictions_only = [self.get_only_response(sample) for sample in predictions] 13 | target_only = [sample["target"] for sample in predictions] 14 | rouge_scores = [self.rouge.score(pred, target) for pred, target in zip(predictions_only, target_only)] 15 | metrics = rouge_scores[0].keys() 16 | results = {metric: [score[metric].fmeasure for score in rouge_scores] for metric in metrics} 17 | return results 18 | 19 | 20 | -------------------------------------------------------------------------------- /files/configuration/predict.json: -------------------------------------------------------------------------------- 1 | { 2 | "out_dir": "Prediction Path", 3 | "run_name": "base", 4 | "datasets_pickle_path": "Dataset Path", 5 | "override": false, 6 | "truncation_strategy": "max", 7 | "max_num_tokens": null, 8 | "batch_size": 2, 9 | "temperature": 0.8 10 | } 11 | -------------------------------------------------------------------------------- /files/configuration/predict_c.json: -------------------------------------------------------------------------------- 1 | { 2 | "out_dir": "/cs/labs/tomhope/nirm/MusiQue/out_new/", 3 | "run_name": "base", 4 | "datasets_pickle_path": "/cs/labs/tomhope/nirm/MusiQue/pkl_files/datasets_baselineSet.pkl", 5 | "override": false, 6 | "truncation_strategy": "max", 7 | "max_num_tokens": null, 8 | "batch_size": 2, 9 | "temperature": 0.8 10 | } 11 | -------------------------------------------------------------------------------- /files/prompts/ECB.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompts": [ 3 | { 4 | "id": 0, 5 | "instructions": "In this task you are presented with multiple documents that are related to the same topic. In each document, each mention of an entity or event is enclosed with square brackets, followed by the mention's unique id. That is: '[some mention](id)'. Your task is to identify and cluster together mentions that are coreferring with each other, i.e., mentions that mark the same entity or event. output the list of clusters, where each cluster is a list of ids of coreferring mentions. For example, if you think that mentions (1) and (2) refer to the same entity, and mentions (3) and (4) refer to the same entity, but these two entities are different, you should output: [[1, 2], [3, 4]].", 6 | "comments": "" 7 | }, 8 | { 9 | "id": 1, 10 | "instructions" : "Given a set of documents related to the same topic, each containing entities or events marked with unique identifiers in the format '[some mention](id)', your task is to analyze these documents and identify mentions that refer to the same entity or event. Once identified, group these mentions into clusters, where each cluster contains the ids of mentions that are coreferential. Output the clusters in the following format: a list of lists, where each inner list contains the ids of coreferring mentions. For instance, if mentions (1) and (2) are coreferential, and mentions (3) and (4) are coreferential, but refer to different entities, the output should be: [[1, 2], [3, 4]]. 
Please process the input and produce the correct clusters based on coreference.", 11 | "comments": "" 12 | }, 13 | { 14 | "id": 2, 15 | "instructions" : "Your task is to process a set of documents, each discussing a similar topic and containing marked mentions of entities or events in the format '[some mention](id)'. Identify mentions that refer to the same entity or event and group these ids into clusters. Each cluster should consist of ids that belong to coreferential mentions. Output these clusters in a nested list format, for example, if mentions tagged as (1) and (2) are coreferential, and (3) and (4) are as well, but separate from the first group, your output should look like: [[1, 2], [3, 4]].", 16 | "comments": "" 17 | }, 18 | { 19 | "id": 3, 20 | "instructions" : "Examine several documents linked by a common theme, noting that each mention of an entity or event within these documents is denoted by '[some mention](id)'. Your objective is to find and group ids of mentions that correspond to the same entity or event into clusters. The desired output format is a list of lists, where each sublist contains ids of mentions that are coreferential. For example, an appropriate output for coreferential mentions (1) and (2), and another pair (3) and (4), would be [[1, 2], [3, 4]].", 21 | "comments": "" 22 | }, 23 | { 24 | "id": 4, 25 | "instructions" : "Within multiple documents on a similar subject, each entity or event mention is bracketed and tagged with an id in the format '[some mention](id)'. Your role is to ascertain which mentions refer to the same entity or event, grouping these ids into clusters. Each cluster will be a list containing the ids of mentions that are coreferential. For instance, if you determine that (1) and (2), as well as (3) and (4), refer to distinct entities, you should produce the output as: [[1, 2], [3, 4]].", 26 | "comments": "" 27 | }, 28 | { 29 | "id": 6, 30 | "instructions" : "Analyze a series of documents that discuss the same topic where entities and events are noted as '[some mention](id)'. Group mentions that refer to the same entity or event into clusters by their ids. Each cluster should be presented as a list of these ids, such as [[1, 2], [3, 4]], indicating coreference between the mentions within each list.", 31 | "comments": "" 32 | }, 33 | { 34 | "id": 7, 35 | "instructions" : "Review a collection of documents with related topics, identifying and clustering ids of entity and event mentions, denoted as '[some mention](id)'. Mentions that are coreferential should be grouped together. Format your output as a series of lists, each containing ids of mentions that reference the same entity or event, e.g., [[1, 2], [3, 4]]", 36 | "comments": "" 37 | }, 38 | { 39 | "id": 8, 40 | "instructions" : "Given multiple documents related by a common theme, each containing marked mentions of entities or events as '[some mention](id)', identify which mentions refer to the same entity or event and cluster these mentions' ids together. Your objective is to find and group ids of mentions that correspond to the same entity or event into clusters. The desired output format is a list of lists, where each sublist contains ids of mentions that are coreferential. 
For example, an appropriate output for coreferential mentions (1) and (2), and another pair (3) and (4), would be [[1, 2], [3, 4]].", 41 | "comments": "" 42 | }, 43 | { 44 | "id": 9, 45 | "instructions" : "Given a collection of documents on the same topic, identify mentions of entities or events marked as '[some mention](id)'. Your task is to group ids of mentions that refer to the same entity or event into clusters. Output these clusters as lists of ids, each representing a set of coreferential mentions. For example, if you find that mentions (1) and (2) refer to the same entity, and (3) and (4) refer to another, output should be: [[1, 2], [3, 4]].", 46 | "comments": "" 47 | }, 48 | { 49 | "id": 10, 50 | "instructions" : "Analyze a set of documents that discuss a unified topic, where each entity or event is uniquely identified by '[some mention](id)'. Group together the ids of mentions that are coreferential. Present your findings as a list of lists, where each sublist contains ids of mentions referring to the same entity or event, e.g., [[1, 2], [3, 4]]. Ensure accuracy in identifying and clustering these mentions.", 51 | "comments": "" 52 | }, 53 | { 54 | "id": 11, 55 | "instructions" : "Review documents related by a common theme, each containing marked mentions of entities or events in the format '[some mention](id)'. Determine which mentions refer to the same entity or event and cluster their ids accordingly. Output these clusters as nested lists, where each list contains ids of coreferential mentions, such as [[1, 2], [3, 4]]. This task requires precise identification and clustering based on coreference.", 56 | "comments": "" 57 | }, 58 | { 59 | "id": 12, 60 | "instructions" : "Your challenge is to navigate through multiple documents related by a common topic, identifying mentions of entities or events enclosed as '[some mention](id)'. Group these mentions by ids when they refer to the same entity or event. Present your clusters as lists of ids that show coreference, for instance, [[1, 2], [3, 4]]. This task requires keen attention to detail and accuracy in identifying connections", 61 | "comments": "" 62 | }, 63 | { 64 | "id": 13, 65 | "instructions" : "Act as a data detective and delve into several documents, each marked by similar themes and containing entities or events tagged as '[some mention](id)'. Your mission is to uncover which mentions are coreferential and cluster their ids accordingly. Your findings should be reported as lists of ids grouped by coreference, like [[1, 2], [3, 4]].", 66 | "comments": "" 67 | }, 68 | { 69 | "id": 14, 70 | "instructions" : "Task: Analyze a collection of thematic documents to identify and cluster mentions of entities or events, formatted as '[some mention](id)', that refer to the same concept. Your output should clearly display clusters of ids representing coreferential mentions, such as [[1, 2], [3, 4]], demonstrating your ability to discern and link related information.", 71 | "comments": "" 72 | }, 73 | { 74 | "id": 15, 75 | "instructions" : "In this detailed analysis task, you are provided with multiple documents, each discussing a similar theme. Within these documents, mentions of entities or events are specifically highlighted and labeled with unique identifiers in the format 'some mention](id)'. Your primary objective is to meticulously identify which of these mentions refer to the same real-world entity or event, and then systematically group together their corresponding ids into clusters. 
Each cluster should exclusively contain ids of mentions that are coreferential. For clarity in your output, format your results into a list of these clusters, where each sublist represents a distinct group of coreferential mentions. For instance, if your analysis concludes that mentions (1) and (2) are about the same entity, and mentions (3) and (4) about another distinct entity, your output should be formatted as: [[1, 2], [3, 4]]. This task requires precise attention to detail and analytical rigor to ensure accuracy in the clustering process.", 76 | "comments": "" 77 | }, 78 | { 79 | "id": 16, 80 | "instructions" : "Examine several documents on a unified topic, noting each mention of an entity or event as '[some mention](id)'. Group mentions by ids that refer to the same entity or event. Output these clusters in a format of nested lists, with each list containing ids of coreferential mentions, like [[1, 2], [3, 4]].", 81 | "comments": "" 82 | }, 83 | { 84 | "id": 17, 85 | "instructions" : "In this comprehensive task, you are faced with an array of documents tied by a common thematic element. Each document features various mentions of entities or events, each enclosed within square brackets and followed by a unique identification number in the format '[some mention](id)'. It is your responsibility to sift through these mentions, discerning which ones are references to the same entity or event across different texts. Upon identifying these coreferential mentions, you are to organize and cluster their ids into coherent groups. The final output should be a structured list of these clusters, with each cluster formatted as a list containing ids of mentions that you have determined to be about the same entity or event. For example, if you deduce that mentions (1) and (2) discuss one entity, and mentions (3) and (4) another, your result should be presented as: [[1, 2], [3, 4]]. This task demands a high level of accuracy and a methodical approach to ensure that each cluster is correctly assembled based on coreference analysis.", 86 | "comments": "" 87 | }, 88 | { 89 | "id": 18, 90 | "instructions" : "Within a group of related documents, each mention of an entity or event is uniquely identified in the format '[some mention](id)'. Your objective is to examine these documents and cluster the ids of mentions that are coreferential, i.e., refer to the same entity or event. The output should be a list of lists, where each sublist contains the ids of mentions that are grouped together based on coreference. For example, if mentions (1) and (2) are about the same entity, and (3) and (4) are about another, then the output should look like: [[1, 2], [3, 4]].", 91 | "comments": "" 92 | }, 93 | { 94 | "id": 19, 95 | "instructions" : "Engage in a task where you need to sift through several documents that deal with similar topics, noting that each entity or event mention is tagged in the format '[some mention](id)'. Identify and group mentions that refer to the same entity or event into clusters by their ids. Produce an output that consists of lists, where each list is a cluster of ids representing coreferential mentions. For instance, if (1) and (2) are deemed coreferential, as are (3) and (4), then your output should be structured as: [[1, 2], [3, 4]].", 96 | "comments": "" 97 | }, 98 | { 99 | "id": 20, 100 | "instructions" : "Your task involves processing a set of documents on a common subject, where each document includes mentions of entities or events marked by '[some mention](id)'. 
Your role is to determine which mentions refer to the same entity or event and cluster their ids together. The expected output format is a series of nested lists, with each list containing ids of coreferential mentions. For example, if mentions (1) and (2) are about the same entity, and (3) and (4) are about another, organize your output as: [[1, 2], [3, 4]].", 101 | "comments": "" 102 | } 103 | ], 104 | "demonstrations": [ 105 | { 106 | "topic_id_train": 10, 107 | "documents": [ 108 | "Red Sox [extend](3) [offer](21) to Teixeira Boston reportedly [looks](57) to [ink](49) free - agent slugger for eight years", 109 | "Angels general manager Tony Reagins [confirmed](69) to the Los Angeles Times that he has [made](28) an eight-year [offer](18) to free agent first baseman Mark Teixeira .", 110 | "The Los Angeles Angels [made](42) an eight-year [offer](41) to first baseman Mark Teixeira during [winter meetings](19) in Las Vegas .", 111 | "Lynch [reported](0) the Red Sox team of owner John Henry , COO Larry Lucchino and General Manager Theo Epstein [offered](4) Teixeira , 28 , an eight - year deal worth $184 million or $23 million per season .", 112 | "According to this L.A. Times story ( which [credits](60) the team-owned flagship radio station ) , the Angels have [made](37) an eight-year [offer](45) to free agent first baseman Mark Teixeira .", 113 | "Red Sox owner John Henry has [said](8) , in so many words , that his team will not [sign](25) Mark Teixeira .", 114 | "The Angels have [offered](39) an eight-year contract to Mark Teixeira , general manager Tony Reagins [said](55) Friday night .", 115 | "Herald : Sox [Make](75) [Offer](7) To Teixeira ( 8 Years ? )", 116 | "Meanwhile , according to the Newsday , Teixeira still has an eight - year [offer](30) on the table from the Red Sox , [worth](51) roughly $22 million per season ." 117 | ], 118 | "target": [ 119 | [25], [75, 3], [28, 37, 42], [8], [57], [49], [0], [60], [4, 7, 30, 21], [19], [55, 69], [45, 39, 41, 18], [51] 120 | ] 121 | }, 122 | { 123 | "topic_id_train": 4, 124 | "documents": [ 125 | "Bettie Page , the 1950s pin-up model who helped [set the stage](36) for the 1960s sexual [revolution](21) , has [died](72) .", 126 | "Yesterday in Los Angeles , pin-up icon Bettie Page [succumbed](16) to complications from a [heart attack](61) suffered almost three weeks ago .", 127 | "Pinup icon Bettie Page [died](102) Thursday evening at a hospital in Los Angeles after having [suffered](88) a [heart attack](113) on Dec. 2 and [spending](99) time on life support , according to her official site .", 128 | "Esther Williams , who [swam](89) to [stardom](77) in the 1940's , [dies](94) at 91", 129 | "Esther Williams , swimming champion , actress , and pin - up icon , has [died](33) . She was 91 . Williams , 91 , [passed](14) away early Thursday in her [sleep](58) , [according to](25) her publicist ." 
130 | ], 131 | "target": [ 132 | [25], [94, 14, 33], [77], [36], [58], [99], [88], [21], [89], [113, 61], [16, 102, 72] 133 | ] 134 | }, 135 | { 136 | "topic_id_train": 14, 137 | "documents": [ 138 | "A HUGE [fire](45) has almost totally [destroyed](80) a Waitrose store in Banstead .", 139 | "Residents [evacuated](58) from their homes after a huge [fire](65) at a Waitrose store in Banstead on Friday night have been [allowed](72) to [return](3) to their homes .", 140 | "Wellington supermarket [fire](63) A large [fire](22) has [destroyed](0) much of a supermarket in Somerset .", 141 | "Fire and police units are still at the scene of a [fire](79) which [gutted](46) a Waitrose supermarket in Surrey .", 142 | "A Waitrose store was [reduced](70) to ruins after a [blaze](60) being [treated](33) by police as `` potentially suspicious '' .", 143 | "A [fire](8) that [ripped](41) through a Waitrose store in Surrey is being [treated](52) as `` potentially suspicious '' , police [said](16) .", 144 | "Police are [treating](71) a [fire](15) which has [destroyed](27) a Waitrose supermarket in Banstead in Surrey as `` potentially suspicious '' .", 145 | "A Waitrose supermarket in Surrey has been [destroyed](51) in a [fire](47) .", 146 | "Waitrose in Wellington [catches](69) [fire](40)" 147 | ], 148 | "target": [ 149 | [47, 15, 8, 65, 45, 79, 60], [16], [3], [58], [72], [51, 70, 80, 46, 41, 27], [69], [0], [71, 33, 52], 150 | [22, 63, 50] 151 | ] 152 | 153 | }, 154 | { 155 | "topic_id_train": 16, 156 | "documents": [ 157 | "Two possible gang members are under [arrest](79) for the assassination-style [murder](22) of a Los Angeles County sheriff 's deputy outside his house in a gang-plagued section of the city last summer , police [said](40) Saturday .", 158 | "Sheriff ’ s Deputy [Shot](6) and [Killed](69) June 24 , 2005 A 35 - year - old sheriff ’ s deputy [working](8) with an anti - gang unit was [killed](28) Friday when he [knocked](44) on the door of a home and someone [shot](60) him in the head , authorities [said](2) .", 159 | "Jose Luis Orozco [showed](10) almost no [emotion](32) throughout the three - week [trial](53) in which he was [convicted](35) of the 2005 [murder](48) of Los Angeles County Sheriff's Deputy Jerry Ortiz .", 160 | "Judge [sentences](70) Orozco to [death](27)", 161 | "Two possible gang members were [arrested](81) in the fatal [shooting](72) of a sheriff 's deputy who was [murdered](0) as he was [getting ready](46) to go to [work](20) , officials [announced](55) Saturday .", 162 | "The gang member [shot](13) Jerry Ortiz outside a Hawaiian Gardens building in 2005 ." 
163 | ], 164 | "target": [ 165 | [44], [27], [40], [48, 28, 69], [72], [20], [55], [46], [70, 35], [6, 13, 60], [8], [10], [32], [53], [22, 0], 166 | [81, 79], [2]] 167 | }, 168 | { 169 | "topic_id_train": 11, 170 | "documents": [ 171 | "Turkmenistan 's voters [are going to the polls](9) in parliamentary [elections](18) [portrayed](47) by the government as a [step](4) towards [democracy](19) in the gas-rich Central Asian nation .", 172 | "Electoral districts [established](26) in Turkmenistan due to parliamentary [elections](35)", 173 | "Voters in Turkmenistan [cast](52) ballots Sunday in a parliamentary [election](58) [hailed](17) by the government as an exercise in [democracy](49) but [dismissed](40) by critics as a sham .", 174 | "'The government of Turkmenistan [said](70) 90 per cent of eligible voters had [participated](1) in parliamentary [elections](14) Sunday , despite [boycotts](54) from opposition groups .'", 175 | "Turkeminstan's CEC [plans](13) to [hold](22) [elections](61) on December 15 2013 .", 176 | "Voters in Turkmenistan began to [cast](59) their ballots at 8 : 00 am local time ( 0300 GMT ) on Sunday in the country 's fourth parliamentary [election](3) , [said](31) reports from Ashgabat , Turkmenistan 's capital .", 177 | "The 4th parliamentary [election](48) of Turkmenistan [ended](56) Sunday evening as scheduled , [said](60) reports from Ashgabat , Turkmenistan 's capital ." 178 | ], 179 | "target": [ 180 | [60, 31], [61, 35], [19, 49], [58, 18, 3, 14, 48], [56], [17], [40], [26], [22], [13], [47], [4], [54], [70], 181 | [9, 1, 59, 52] 182 | ] 183 | } 184 | ] 185 | } 186 | -------------------------------------------------------------------------------- /files/prompts/MultiNews.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompts": [ 3 | { 4 | "id": 0, 5 | "instructions" : "In this task, you are presented with multiple news articles about related topics. Your job is to generate an extractive summary that integrates information from the provided articles. Your summary should be short and concise, that includes content only from the provided articles, avoiding any external data sources.", 6 | "comments": "" 7 | }, 8 | { 9 | "id": 1, 10 | "instructions" : "Please provide a brief, extractive summary by synthesizing only the key points from the articles provided. Focus on the main arguments and conclusions without incorporating any information from outside these texts. Keep your summary concise and directly related to the content of the documents.", 11 | "comments": "" 12 | }, 13 | { 14 | "id": 2, 15 | "instructions" : "Generate a concise extractive summary using only the information from the provided articles. Your summary should distill the most essential information, capturing the core insights without adding any external content. Aim for brevity and clarity in your summarization.", 16 | "comments": "" 17 | }, 18 | { 19 | "id": 3, 20 | "instructions" : "Please sift through the provided articles and distill their essence into a sharp, concise summary. Focus solely on the facts and key points within these texts, avoiding any embellishment or reference to external information. Your summary should read like a bullet-point list of the most critical insights.", 21 | "comments": "" 22 | }, 23 | { 24 | "id": 4, 25 | "instructions" : "You are presented with multiple news articles about related topics. 
Summarize the contents in a way that captures the key information in a narrative form, but strictly using the details mentioned in the provided documents. Keep it engaging yet brief.", 26 | "comments": "" 27 | }, 28 | { 29 | "id": 6, 30 | "instructions" : "Imagine you're preparing a brief for a decision-maker who has limited time. Summarize the provided documents by extracting only the most essential information. Present this in a clear, straightforward manner, focusing on the key facts and figures, and ensure that all content is directly sourced from the articles without external references.", 31 | "comments": "" 32 | }, 33 | { 34 | "id": 7, 35 | "instructions" : "Using only the details from the articles I've given you, craft a summary that distills the most important information. Avoid any interpretations or external data, and keep your summary short and direct. Emphasize the main arguments, data points, and conclusions from the documents.", 36 | "comments": "" 37 | }, 38 | { 39 | "id": 8, 40 | "instructions" : "Operate as an information synthesizer: Draw the essence from multiple articles, focusing solely on the information contained within them. Your summary should be a tight, focused digest of the articles, free from any influence of external data.", 41 | "comments": "" 42 | }, 43 | { 44 | "id": 9, 45 | "instructions" : "Scan through the provided articles and compile a summary that highlights only the most significant facts and figures, ensuring the exclusion of all external references. Aim for clarity and brevity.", 46 | "comments": "" 47 | }, 48 | { 49 | "id": 10, 50 | "instructions" : "Operate as an academic summarizer: Imagine you are creating a summary for an academic review. Extract and emphasize the most pertinent information, ensuring your summary remains true to the original texts and free of external content.", 51 | "comments": "" 52 | }, 53 | { 54 | "id": 11, 55 | "instructions" : "Condense the provided information into a compact summary that emphasizes the main points and crucial data from the documents. Exclude any external information to maintain the integrity of the sources.", 56 | "comments": "" 57 | }, 58 | { 59 | "id": 12, 60 | "instructions" : "From the provided articles, pull out the core messages and data points. Shape these into a brief, clear summary that directly reflects the content of the documents without any external additions.", 61 | "comments": "" 62 | }, 63 | { 64 | "id": 13, 65 | "instructions" : "Compile a concise summary from the news articles given, focusing only on the information contained within. Your summary should integrate the main points without adding any outside information.", 66 | "comments": "" 67 | }, 68 | { 69 | "id": 14, 70 | "instructions" : "Create a succinct extractive summary by focusing exclusively on the details provided in the articles. Avoid using any external sources and ensure the summary remains clear and to the point.", 71 | "comments": "" 72 | }, 73 | { 74 | "id": 15, 75 | "instructions" : "Produce a brief summary that distills the essential facts from the provided articles. Keep your summary strictly to the content presented in the documents, avoiding external influences.", 76 | "comments": "" 77 | }, 78 | { 79 | "id": 16, 80 | "instructions" : "Develop a concise extractive summary using only the information from the articles provided. 
Emphasize the main points and conclusions while avoiding the inclusion of any external data.", 81 | "comments": "" 82 | }, 83 | { 84 | "id": 17, 85 | "instructions" : "Prepare a short, integrated summary by synthesizing key points from the given news articles. Ensure that no external content is included and that the summary is clear and direct.", 86 | "comments": "" 87 | }, 88 | { 89 | "id": 18, 90 | "instructions" : "Your task is to distill the primary information from the provided articles into a concise summary. Make sure to exclude any external sources and focus strictly on the given texts.", 91 | "comments": "" 92 | }, 93 | { 94 | "id": 19, 95 | "instructions" : "Summarize the provided articles by extracting only the key information and conclusions. Your summary should be brief and must not incorporate any external data.", 96 | "comments": "" 97 | }, 98 | { 99 | "id": 20, 100 | "instructions" : "Generate a clear and brief extractive summary using just the information from the provided articles. Focus on distilling the essential points and data without referencing external content.", 101 | "comments": "" 102 | } 103 | ], 104 | "demonstrations": [ 105 | { 106 | "id_train": 33239, 107 | "documents": [ 108 | "The Iowa State Medical Examiner's announcement Thursday that Tibbetts death was a homicide resulting from multiple sharp force injuries is likely to fuel more outrage. We cannot allow these tragedies to continue. Former Republican House Speaker Newt Gingrich declared on Fox News that if Mollie Tibbetts is a household name by October, Democrats will be in deep trouble come Election Day. In the Tibbetts case, Cristhian Bahena Rivera, 24, was charged with first-degree murder Tuesday after the body of 20-year-old Tibbetts was found in a field east of her hometown, Brooklyn, Iowa. In a court filing ahead of Rivera's first court appearance on Wednesday, his lawyer, Allan Richards, sought a gag order claiming that the federal government was promoting the idea that his client was in the country illegally. Cristhian Rivera, 24, accused of killing University of Iowa student Mollie Tibbetts, is led from the courtroom after making his initial appearance on a charge of first-degree murder during at the Poweshiek County Courthouse in Montezuma, Iowa. And in an interview with NBC News on Thursday, Richards was evasive when asked for proof that Rivera was in the U.S. legally.", 109 | "Rivera quickly accepted, tweeting that Patrick claimed he was effectively an accomplice to horrifying murder of Molly Tibbetts because I beg compassion & mercy for undocumented immigrants-How dare he make so false an allegation? He is fear-mongering. Tibbetts, disappeared in July. Cristhian Bahena Rivera, has been charged for her murder. U.S. Immigration and Customs Enforcement has said that Rivera was in the country illegally, but his attorney disputes that. Rivera is a Mexican citizen." 110 | ], 111 | "target": "The Mollie Tibbetts murder continues to emerge as a flashpoint in the immigration debate, reports NBC News. The reason? US immigration officials say 24-year-old suspect Cristhian Bahena Rivera was in the country illegally, though his attorney continues to push back against that claim. Geraldo criticizes Fox: One of those criticizing Fox News for focusing so strongly on the immigration angle turns out to be one of its own hosts, Geraldo Rivera, reports the Washington Post. \"We, at this network, are putting that spin on this story,\" Counter-view: Texas Lt. Gov. 
Dan Patrick has been among those on the right blaming too-soft immigration policies for the murder. \"The CNNs, the MSNBCs, most of the print media in this country, and the Democrats, accomplices in the death of this young girl and the death of everyone else,\" he said on Fox." 112 | }, 113 | { 114 | "id_train": 8874, 115 | "documents": [ 116 | "In a once abandoned Hershey chocolate factory in the small town of Smiths Falls, Ontario, the largest legal marijuana producer in the world grows, trims, processes, packages, and ships weed across the Great White North. In the fall of 2016, Canopy, which trades on the Toronto Stock Exchange under the ticker WEED, became the first company in the marijuana industry to achieve elusive unicorn status. Canadians have enjoyed the ability to possess and grow small amounts of weed for medical use. In 2014, the government began licensing companies like Canopy to produce mass amounts of marijuana for patients suffering from serious diseases. Linton wanted to create a vertically integrated company — one that grows marijuana in addition to processing it for oils, gel capsules, and other products and packaging it for shipment — because it would give him better control over quality and bring down costs.", 117 | "On Bay Street, Toronto’s equivalent of Wall Street, you can now buy weed. Rather, make that WEED. Amid what it says is a growing acceptance of Canada’s burgeoning medical-marijuana industry, Canopy Growth Corp. switched to the new four-letter stock ticker on the Toronto Stock Exchange Wednesday. Canopy already exports marijuana products to Germany and Brazil. “We’re thrilled to be marketing WEED on Bay Street,” Chief Executive Officer Bruce Linton said in a statement.", 118 | "Canopy Growth Corp. jumped to a record high, and other Canadian cannabis growers gained, after unveiling a line of marijuana products for the domestic market in a partnership with rapper Snoop Dogg. Canopy became the first marijuana producer to trade on a major North American stock exchange when it graduated to the Toronto Stock Exchange in July. Canopy CEO Bruce Linton has ambitious expansion plans for the company as it grows internationally and Canada inches closer to full legalization of the drug that’s now allowed for medical use. The government of Prime Minister Justin Trudeau, who campaigned on a legalization pledge, is due to receive a report in November recommending how Canada may move forward." 119 | ], 120 | "target": "Factory in a small town in Ontario once spewed out Hershey's chocolate, and now it's home to the world's largest legal marijuana producer, which supplies pot across Canada and exports to Germany and Brazil. Canopy Growth—found on the Toronto Stock Exchange. The company supply pot to almost half of Canadian medical marijuana patients, and also grows, processes, and packages the product. With control over its supply chain, Canopy Growth is able to process everything from oils to gel capsules." 121 | }, 122 | { 123 | "id_train": 38373, 124 | "documents": [ 125 | "Darden Restaurants, owner of the Italian restaurant chain Olive Garden, says its unlimited breadsticks and salad dressings are working out just fine with consumers. The casual-dining restaurant operator is on the defensive after investor Starboard Value LP last week disclosed nearly 300-page slide presentation that detailed a potential turnaround plan for Darden’s Olive Garden chain. 
Among the complaints: Olive Garden was serving too many breadsticks, adding too much dressing to its salads, and not serving enough alcohol. Among the most egregious complaints: Starboard lamented Olive Garden wasn’t adding salt to the water used to cook the pasta, asking How does the largest Italian dining concept in the world not salt the water for pasta? Darden fired back this week, saying in a 24-page slide presentation that Starboard’s suggestions are not based on reality.", 126 | "Olive Garden is defending its practice of giving customers as many breadsticks as they want, saying the policy conveys Italian generosity. The remark is part of a response by the chain's parent company, Darden Restaurants Inc., to a nearly 300-page criticism released by hedge fund Starboard Value last week. Starboard took Olive Garden and its management to task for a variety of issues, including its liberal distribution of breadsticks, its failure to salt the water used to boil its pasta, and even the length of the asparagus it serves. Darden's 24-page response states that the company is already implementing a variety of strategies to improve Olive Garden's performance. Starboard is lobbying to gain control of Darden's board of directors at the company's annual meeting which is based in Orlando, Florida, has struggled to boost sales at Olive Garden with the growing popularity of chains such as Chipotle. Under pressure to boost results, Darden recently sold off Red Lobster, which was doing even worse than Olive Garden. As for its breadsticks, Starboard said last week that Olive Garden was being wasteful because servers weren't sticking to the policy of providing one breadstick per customer, plus an extra for the table. " 127 | ], 128 | "target": "Olive Garden has fired back in the battle of the breadsticks. The Florida-based chain's parent company says the unlimited breadstick policy—one of many things criticized by hedge fund Starboard Value—is an example of Italian generosity. In its response to Starboard's 300-page criticism, Darden Restaurants says the salads Starboard slams for being overfilled and overdressed are a big hit that inspires loyalty in customers. The 24-page report says a major menu revamp is underway and things like ordering via tabletop tablets are in the works, although it doesn't address some of Starboard's criticisms." 129 | }, 130 | { 131 | "id_train": 33055, 132 | "documents": [ 133 | "Nearly 300 aviation security officers are suing the city of Chicago and the state of Illinois, asserting that they were wrongfully stripped of their law enforcement status following the widely publicized incident involving Dr. David Dao being forcibly removed from a United Airlines flight in April 2017. These officers, recognized as law enforcement since 1993 and trained at the Chicago Police Academy or the Cook County Sheriff’s Training Academy, lost their privileges due to political pressure after the incident. The city’s Aviation Commissioner and the state board declared that aviation officers were no longer considered law enforcement at airports. The officers' lawsuit contends that removing their law enforcement status unfairly impacts their work history and future job opportunities.", 134 | "James Long, a former Chicago aviation security officer, is suing Chicago's Department of Aviation (CDA), its commissioner, and the city. He claims he was not properly trained on the use of force, which he argues contributed to the incident involving Dr. David Dao. Dr. 
Dao was forcibly removed from a United Airlines flight, resulting in significant injuries, and later settled with the airline. The incident led to a policy change by United Airlines regarding the allocation of seats to crew members. But for the CDA's negligence and failure to train Long how to respond to an escalating situation with an Airline Passenger, he would not have acted in the manner he did, which resulted in his termination, the complaint says, according to newspaper" 135 | ], 136 | "target": "The Chicago aviation security officer who dragged a physician off a United Airlines flight last year is suing the airline and the city, claiming that their negligence led to his firing. In the lawsuit, James Long argues that United should have known that calling security to remove David Dao, 69, from the overbooked flight would require the use of physical force. The complaint also states that although Long completed five months of training as a police recruit, the Chicago Department of Aviation failed to provide training on the level of force continuum. With that training. In a separate lawsuit, some 300 aviation security officers are suing Chicago and the state of Illinois, complaining that they were stripped of their law enforcement officer status after the Dao incident, the Chicago Tribune reports. The officers—who completed police or sheriff's department training and were sworn in as law enforcement officers—argue that while the state has the right to strip the agency of its policing powers, their work history as law enforcement officers shouldn't be expunged." 137 | }, 138 | { 139 | "id_train": 10423, 140 | "documents": [ 141 | "Local and international human rights groups have criticized Presidential Spokesperson Harry Roque for claiming they are manipulated by drug lords to oppose the Duterte administration's drug policies. Human Rights Watch described these allegations as government intimidation tactics meant to undermine Philippine human rights activists who challenge the administration's approach to the rule of law and potential crimes against humanity. Roque defended his statements, suggesting that such groups should not overreact and politicize the issue for media attention. Meanwhile, the local group Karapatan accused the administration of fabricating stories to blame human rights organizations for its failures in addressing the drug problem and for tarnishing the nation’s reputation. They suggest that these accusations serve to justify attacks on activists or to dodge accountability under domestic and international rights laws. Despite official figures indicating over 4000 deaths in drug operations, human rights organizations estimate the toll could be as high as 13,000.", 142 | "Philippine law enforcement agencies still have no proof yet on allegations made by two government officials that human rights groups may have become the “unwitting tools” of drug lords, police and drug enforcement agency officials said on Tuesday. Carreon said more than 123,000 drug suspects had been arrested, which he said showed anti-drugs operations were not about killings. He said law enforcement agencies welcomed any criticism from rights groups and allowed them to observe anti-drug operations to prove that everything was done according to the rule of law. 
Duterte’s spokesman and foreign secretary did not present any evidence when they told reporters drug lords were using human rights groups to undermine the policy, statements against which the Human Rights Watch protested and said were “shameful” and risked provoking violence. The anti-narcotics campaign has raised international alarm and drawn criticism from some U.N. representatives." 143 | ], 144 | "target": "Human rights groups in the Philippines are contending with government claims that they are inadvertently assisting drug lords by highlighting the increasing number of deaths in President Rodrigo Duterte's drug war. Following a particularly violent week where 13 people were killed in one night, presidential spokesman Harry Roque accused these groups of hindering the government's efforts. This aligns with earlier statements from Foreign Secretary Alan Cayetano, who criticized the motives of these groups. Despite official reports of around 4,100 deaths in police shootouts and 124000 arrests since June 2016, human rights organizations argue that the actual death toll is significantly higher and that the government's accusations aim to diminish their credibility and ignore global criticism of Duterte's policies. The International Criminal Court is investigating these policies as potential crimes against humanity. Human Rights Watch described Roque's comments as dangerously misleading, and local group Karapatan criticized the government for fabricating stories to malign them." 145 | } 146 | ] 147 | } 148 | -------------------------------------------------------------------------------- /files/prompts/OpenASP.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompts": [ 3 | { 4 | "id": 0, 5 | "instructions": "In this task you are required to generate an aspect-based summary of a set of documents related the same topic. Please write a short, concise aspect-based summary, only summarize content from the above documents, avoiding any external data sources.", 6 | "comments": "" 7 | }, 8 | { 9 | "id": 1, 10 | "instructions": "Your goal is to create a short, concise aspect-based summary of the given documents. Summarize the key points accurately, using only the information from these documents and excluding any external sources.", 11 | "comments": "" 12 | }, 13 | { 14 | "id": 2, 15 | "instructions": "Produce a brief, aspect-based summary of the collection of documents on the same topic. Ensure your summary is concise and derived only from the provided documents, avoiding any external data sources.", 16 | "comments": "" 17 | }, 18 | { 19 | "id": 3, 20 | "instructions": "Your task is to generate a detailed yet concise aspect-based summary from a collection of documents that focus on the same topic. Begin by thoroughly examining each document to understand the main aspects and themes. Then, synthesize this information into a coherent summary that highlights the significant points. Make sure your summary is short and derived exclusively from the content of the provided documents, without incorporating any external data.", 21 | "comments": "" 22 | }, 23 | { 24 | "id": 4, 25 | "instructions": "Given a set of documents related to a specific topic, generate a short, concise aspect-based summary. Ensure that the summary is based solely on the content of the documents provided", 26 | "comments": "" 27 | }, 28 | { 29 | "id": 5, 30 | "instructions": "You will receive several documents on the same topic. 
Your task is to write a brief aspect-based summary, using only the information from the provided documents and excluding any external sources.", 31 | "comments": "" 32 | }, 33 | { 34 | "id": 6, 35 | "instructions": "You are tasked with generating an aspect-based summary of several documents. Summarize the content briefly and accurately, using only the information from the documents give", 36 | "comments": "" 37 | }, 38 | { 39 | "id": 7, 40 | "instructions": "In this task, you are required to create an aspect-based summary of a set of documents all related to the same topic. Carefully read through each document and identify the key aspects discussed. Summarize these aspects in a concise manner, ensuring that your summary captures the essential points. It is crucial to base your summary solely on the provided documents, avoiding any external information or references. ", 41 | "comments": "" 42 | }, 43 | { 44 | "id": 8, 45 | "instructions": "You are tasked with producing an aspect-based summary for a series of documents related to the same topic. Start by analyzing each document to identify the critical aspects covered. Your goal is to condense this information into a clear and concise summary, ensuring that you accurately represent the main points. The summary should be brief and entirely based on the provided documents, with no inclusion of external sources or data.", 46 | "comments": "" 47 | }, 48 | { 49 | "id": 9, 50 | "instructions": "Generate a concise aspect-based summary of the given documents. Focus on summarizing the content based solely on the information from these documents, avoiding any external sources.", 51 | "comments": "" 52 | }, 53 | { 54 | "id": 10, 55 | "instructions": "Create a concise aspect-based summary for the provided set of documents. Focus on the main aspects and themes discussed in these documents, ensuring that your summary is based entirely on the content of the provided documents and excludes any external sources.", 56 | "comments": "" 57 | }, 58 | { 59 | "id": 11, 60 | "instructions": "Produce a short and precise aspect-based summary of the given documents. Identify the key aspects discussed in these documents and synthesize a concise summary based solely on the provided content.", 61 | "comments": "" 62 | }, 63 | { 64 | "id": 12, 65 | "instructions": "You will receive a collection of documents focused on the same topic. Your task is to create an aspect-based summary that highlights the key aspects discussed in these documents. Ensure your summary is brief and does not include any external information.", 66 | "comments": "" 67 | }, 68 | { 69 | "id": 13, 70 | "instructions": "You are provided with multiple documents related to a single topic. Your task is to generate an aspect-based summary that captures the main aspects discussed in these documents. Ensure your summary is concise and solely based on the provided texts.", 71 | "comments": "" 72 | }, 73 | { 74 | "id": 14, 75 | "instructions": "You are tasked with generating an aspect-based summary of several documents on the same topic. Carefully review each document, identify the main aspects, and write a brief summary that captures these aspects using only the provided documents.", 76 | "comments": "" 77 | }, 78 | { 79 | "id": 15, 80 | "instructions": "Your role is to create an educational summary for students using a collection of documents on the same topic. Focus on the main aspects that would help students understand the core concepts discussed in the documents. 
Write a short, clear aspect-based summary, relying exclusively on the provided texts.", 81 | "comments": "" 82 | }, 83 | { 84 | "id": 16, 85 | "instructions": "Imagine you are preparing a briefing for a busy executive who needs to understand the key aspects of several documents quickly. Summarize the most important points from these documents in a concise manner, ensuring your aspect-based summary is derived entirely from the content of the provided documents and avoids any external information.", 86 | "comments": "" 87 | }, 88 | { 89 | "id": 17, 90 | "instructions": "As an advanced AI tasked with summarizing documents, your goal is to generate an aspect-based summary. Think of yourself as a summarization expert, extracting the most critical aspects from the documents provided. Craft a concise summary that highlights these key aspects, ensuring it is based solely on the given documents.", 91 | "comments": "" 92 | }, 93 | { 94 | "id": 18, 95 | "instructions": "Imagine you are a journalist tasked with writing a summary article based on a series of documents related to a single topic. Identify the key aspects discussed in these documents and compose a brief, coherent summary that encapsulates the main points without introducing any external information.", 96 | "comments": "" 97 | }, 98 | { 99 | "id": 19, 100 | "instructions": "Your task is to act as a knowledge distiller, creating a concise aspect-based summary from a series of documents on the same topic. Focus on identifying and summarizing the critical aspects discussed in these documents, ensuring your summary is brief and based exclusively on the provided content.", 101 | "comments": "" 102 | }, 103 | { 104 | "id": 20, 105 | "instructions": "You are an AI assistant tasked with providing a summary for a set of documents related to a specific topic. Focus on the key aspects and themes discussed in these documents. Create a summary that captures these aspects in a concise manner, ensuring that your summary is based solely on the provided documents and excludes any external information.", 106 | "comments": "" 107 | } 108 | ], 109 | "demonstrations": [ 110 | { 111 | "id_train": "Ron Paul comments/KROG/mnews-train.135533.A", 112 | "index": 130, 113 | "aspect_label": "Ron Paul comments", 114 | "documents": [ 115 | "Texas Gov. Rick Perry, who is struggling in the polls, offered a lively retort when asked if his poor debate performances mean he's not ready to debate President Obama. Your Browser DoesNot Support IFrames. He compared himself to Tim Tebow, the Denver Broncos quarterback who is controversial because of his outward displays of Christian faith. \"I'm kinda getting to where I like these debates,\" Perry said. \"I hope Obama and I debate a lot. I'll get there early. We will get it on, and we will talk about our differences.\" \"There are a lot of folks who said Tim Tebow wouldn't be a very good NFL quarterback. That he doesn't have the right throwing mechanisms, or he's not playing the game right. He won two national championships and that looked pretty good, and we were the national champions in job creation back in Texas... I am the Tim Tebow of the Iowa caucuses.\" More on PostPolitics", 116 | "Republican presidential candidates at the Iowa debate. (JIM YOUNG/REUTERS) \"The courts have become grotesquely dictatorial, far too powerful and I think frankly arrogant in their misreading of the American people,\" the former House speaker said. 
\"I would, just like [former presidents] Jefferson, Jackson, Lincoln and FDR, I would be prepared to take on the judiciary [branch of government] if it did not restrict itself in what it was doing.\" Agreeing with Gingrich, Rep. Michele Bachmann (R-Minn.) said \"if we give to the courts the right to make law than the people will have lost their representation.\" More on PostPolitics" 117 | ], 118 | "target": "Michele Bachmann agrees, but a slightly stunned Ron Paul responds that a president stripping power from judges presents a \"real problem\" in the separation of powers between the executive branch and the judiciary.\nBachmann and Paul go for each other's jugular over Iran.\n\"You cannot solve these problems with war,\" Paul responds, his voice rising." 119 | }, 120 | { 121 | "id_train": "Gingrich comments/KROG/mnews-train.135533.A", 122 | "index": 275, 123 | "aspect_label": "Gingrich comments", 124 | "documents": [ 125 | "Texas Gov. Rick Perry, who is struggling in the polls, offered a lively retort when asked if his poor debate performances mean he's not ready to debate President Obama. Your Browser DoesNot Support IFrames. He compared himself to Tim Tebow, the Denver Broncos quarterback who is controversial because of his outward displays of Christian faith. \"I'm kinda getting to where I like these debates,\" Perry said. \"I hope Obama and I debate a lot. I'll get there early. We will get it on, and we will talk about our differences.\" \"There are a lot of folks who said Tim Tebow wouldn't be a very good NFL quarterback. That he doesn't have the right throwing mechanisms, or he's not playing the game right. He won two national championships and that looked pretty good, and we were the national champions in job creation back in Texas... I am the Tim Tebow of the Iowa caucuses.\" More on PostPolitics", 126 | "Republican presidential candidates at the Iowa debate. (JIM YOUNG/REUTERS) \"The courts have become grotesquely dictatorial, far too powerful and I think frankly arrogant in their misreading of the American people,\" the former House speaker said. \"I would, just like [former presidents] Jefferson, Jackson, Lincoln and FDR, I would be prepared to take on the judiciary [branch of government] if it did not restrict itself in what it was doing.\" Agreeing with Gingrich, Rep. Michele Bachmann (R-Minn.) said \"if we give to the courts the right to make law than the people will have lost their representation.\" More on PostPolitics" 127 | ], 128 | "target": "Rivals and moderators put barely-front-runner Newt Gingrich in the hot seat, and his response was to be more rambunctious than ever.\nSome highlights: The biggest applause of the night follows Gingrich's threat to strip federal judges of their power if he becomes president because the courts are \"grotesquely\" dictatorial.\nThe judiciary is \"far too powerful and arrogant in their misreading of the American people,\" he says.\nThat's \"just not true,\" Gingrich snaps, saying he \"never lobbied under any\" circumstance.\n\"People ought to have facts before they go out with allegations.\"\nGingrich slams President Obama as a \"Saul Alinsky radical\" as the candidates explain how they would get Congress to work with the White House." 
129 | }, 130 | { 131 | "id_train": "Award presented/KROG/mnews-train.132173.A", 132 | "index": 321, 133 | "aspect_label": "Debate skills", 134 | "documents": [ 135 | "Perry: Good enough debater to beat Obama Rick Perry reiterated his frustration Sunday with a presidential nominating process that puts too much emphasis on debate performance, the morning after his campaign confirmed that he will attend at least five more in his bid for the Republican nomination. The Texas governor said 18 debates is way too many because they take an incredible amount of time and preparation. When you take a look at the debates, I readily admit I'm not the best debater, he said on Fox News Sunday. With as many debates as we've got coming up, I may be a pretty good debater before it's all said and done. We've got a great debater, a smooth politician, in the White House right now. That's not what we want right now, Perry said. The governor said he's confident he can draw a clear bright line if he is debating President Barack Obama during the general election. I think I am going to be able to stand on that stage and draw a clear contrast with Barack Obama, he said.", 136 | "Fox News Sunday host Chris Wallace needled Texas Gov. Rick Perry on Sunday for pledging in his first paid campaign ad to create at least 2.5 million jobs as president. Two-and-a-half million jobs is terrible, the host told the Republican presidential candidate. We would roughly need 6 million jobs in the first four years just to stay even with population growth. So two-and-a-half-million jobs, the unemployment rate would increase! Jimmy Carter created 10.5 million jobs in his first four years. Perry said he would get criticized for not being realistic if he said he was going to create 10.5 million jobs. Let me tell you, any job at this particular point in time helps, he said. Wallace said the national unemployment rate would increase if Perry merely met the goal he set. You give this plan a chance, Perry told him, live in Austin." 137 | ], 138 | "target": "While freely admitting 'I'm not the greatest debater' in the GOP field, Rick Perry said today that he can go mano a mano with the debater-in-chief and come out ahead: “We’ve got a great debater, a smooth politician, in the White House right now. That’s not what we want right now. I think I am going to be able to stand on that stage and draw a clear contrast with Barack Obama." 139 | }, 140 | { 141 | "id_train": "Performances/KROG/mnews-train.139540.A", 142 | "index": 334, 143 | "aspect_label": "Performances", 144 | "documents": [ 145 | "Although the 2012 Grammy Awards will no doubt be overshadowed by Saturday's passing of music legend Whitney Houston, the show must go on. Even with the somber overtones, musicians and celebrities will still flock to the award show to celebrate the icon and their own achievements over the past year. Giulianna Rancic, who recently underwent a double mastectomy after her battle with breast cancer, was glowing in a black strapless cocktail dress. And as the red carpet began to heat up, the E! News host caught up with indie pop band Foster The People who will perform later in the broadcast with members of The Beach Boys and Maroon 5 in what will be a historic gathering of talent -- the surviving Beach Boys members have not performed together in 20 years, according to MTV. 
Bringing a bit of fun to the red carpet are Ellen Degeneres' pint-sized \"Ellen\" correspondents Sophia Grace and Rosie -- the tutu-clad duo who became YouTube sensations turned daytime darlings after belting out Nicki Minaj's \"Super Bass. \" Check back for more details on the 54th Grammy Awards red carpet arrivals. Check out the Grammys red carpet scene below:", 146 | "Sing it, Bruce! Bruce Springsteen and his E Street band have released the video for \"We Take Care Of Our Own,\" off their upcoming album \"Wrecking Ball. \" Shot in both black and white and color, Springsteen takes one for the working class with a factory rock-out session and begs the questions, \"Where's the works that has set my hands, my soul free? Where's the spirit that will reign, reign over me? Where's the promise from sea to shining sea?\" Springsteen recently announced that he had hired late band member Clarence Clemons' nephew, Jake Clemons, to join the band on saxophone. \"Wrecking Ball,\" which draws inspiration from the Occupy movement, hits stores on March 6th." 147 | ], 148 | "target": "The Grammys got started a performance.\nBruce Springsteen and the E Street Band did a rousing rendition of \"We Take Care of Our Own\".\nThen: Nope, it's still not time for an award: Bruno Mars takes the stage, then the first presenters appear—20 minutes into the ceremony.\nBut Alicia Keys and Bonnie Raitt first perform a tribute to Etta James." 149 | }, 150 | { 151 | "id_train": "Karl Rove/KROG/mnews-train.134832.A", 152 | "index": 327, 153 | "aspect_label": "Karl Rove", 154 | "documents": [ 155 | "4:31 January 6, 2012 Romney's post-Iowa momentum real? A few months ago, people really didn't seem to like Mitt Romney. Now, after the Iowa caucuses, it seems the tide has changed. Or has it? Nancy Cordes and Jim Axelrod speak with Time magazine columnist Joe Klein and CBS News political director John Dickerson about Romney's rise.", 156 | "Former White House spokesman and Obama campaign adviser Robert Gibbs was gleeful Tuesday over former House speaker Newt Gingrich's rising stature and the prospect of Donald Trump moderating a Republican presidential debate, likening it to \"shooting fish in a barrel.\" \"Cancel Christmas. Just get ready for the Donald Trump debate. ... Pay-per-view. Pay-per-view. We ought to have this, the new rule, I think you've heard, you have to have your birth certificate to participate in the debate. It's like shooting fish in a barrel,\" said Gibbs on MSNBC's \"Morning Joe,\" mocking Trump's fascination with Obama's birth certificate. Text Size - + reset Gibbs: 'Shooting fish in a barrel' POLITICO 44 \"I've never seen so many people fall over a guy who FOX's own poll showed that if you got his endorsement, 5 percent of the people would be more in favor of you and 30 percent of the people would be less in favor of you,\" he added. The former White House spokesman continued through with what has been the Obama campaign's strategy recently -- to subtly compliment Republican candidates who are not Mitt Romney, and to minimize Romney's own chances. Gingrich \"has done very well in these debates. He is very glib. I think he's done it in a way that's interesting ... not by trying to tear any one or two people down but by trying just to distinguish himself in these debates,\" he said. \"I think the Romney campaign has probably made some real fundamental strategic mistakes about not playing in Iowa a lot earlier,\" Gibbs continued." 
157 | ], 158 | "target": "Karl Rove: Also on Fox News, he wondered last night, \"How can we have any confidence [Trump is] going to be impartial in his questions?\nWe’ve got a guy who’s not only saying, I am going to make a decision about who I am going to endorse shortly after this debate and I am already leaning some way and I may run myself.\nAnd we expect him to be the impartial moderator at this debate?\"" 159 | } 160 | ] 161 | } 162 | -------------------------------------------------------------------------------- /model_wrappers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/model_wrappers/__init__.py -------------------------------------------------------------------------------- /model_wrappers/hf_pipline_wrap.py: -------------------------------------------------------------------------------- 1 | 2 | from transformers import pipeline, logging 3 | import torch 4 | 5 | logging.set_verbosity_info() 6 | 7 | 8 | class HfPipelineWrap: 9 | max_new_tokens = 500 10 | 11 | def __init__(self, model_name, temperature, batch_size, torch_dtype=torch.bfloat16, load_in_4_bit=False, load_in_8_bit=False): 12 | self.model_name = model_name 13 | print(f"loading {model_name}") 14 | # model_kwargs = {"torch_dtype": torch_dtype, "device_map": "auto", 15 | # "load_in_4bit": load_in_4_bit, "load_in_8bit": load_in_8_bit, "attn_implementation":"flash_attention_2"} 16 | model_kwargs = {"torch_dtype": torch_dtype, "device_map": "auto", 17 | "load_in_4bit": load_in_4_bit, "load_in_8bit": load_in_8_bit} 18 | 19 | print(model_kwargs) 20 | self.pipe = pipeline("text-generation", model=model_name, model_kwargs=model_kwargs, batch_size=batch_size) 21 | self.temperature = temperature 22 | 23 | if model_name != "meta-llama/Llama-3.1-8B-Instruct": 24 | self.pipe.tokenizer.padding_side = "left" 25 | # print(self.pipe.model.config.eos_token_id) 26 | self.pipe.tokenizer.pad_token_id = self.pipe.model.config.eos_token_id 27 | else: 28 | # self.terminators = [ 29 | # self.pipe.tokenizer.eos_token_id, 30 | # self.pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>") 31 | # ] 32 | # pipeline.tokenizer.pad_token_id = pipeline.model.config.eos_token_id[0] 33 | self.pipe.tokenizer.pad_token_id = self.pipe.model.config.eos_token_id[0] 34 | # print(self.pipe.model) 35 | 36 | def get_max_window(self): 37 | return self.pipe.model.config.max_position_embeddings 38 | 39 | def batch(self, prompts, num_truncation_tokens): 40 | encoded_input = [self.pipe.tokenizer.apply_chat_template( 41 | p, truncation=True, max_length=num_truncation_tokens-self.max_new_tokens) 42 | for p in prompts] 43 | if not self.model_name == "google/gemma-2-9b-it": 44 | decoded_input = self.pipe.tokenizer.batch_decode(encoded_input, skip_special_tokens=True) 45 | else: 46 | decoded_input = prompts 47 | 48 | if self.model_name != "meta-llama/Llama-3.1-8B-Instruct": 49 | outputs = self.pipe( 50 | decoded_input, 51 | temperature=self.temperature, 52 | do_sample=True, 53 | pad_token_id=self.pipe.tokenizer.eos_token_id, 54 | max_new_tokens=self.max_new_tokens, 55 | num_workers = 5 56 | ) 57 | else: 58 | outputs = self.pipe( 59 | decoded_input, 60 | temperature=self.temperature, 61 | do_sample=True, 62 | pad_token_id=self.pipe.tokenizer.eos_token_id, 63 | max_new_tokens=self.max_new_tokens, 64 | num_workers=5 65 | ) 66 | 67 | 68 | 69 | only_responses = [output[0]["generated_text"] for i, output in enumerate(outputs)] 70 | 
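# Note: with the default text-generation pipeline settings used here, the
# "generated_text" field typically contains the prompt followed by the model's
# continuation, so only_responses carries full prompt+answer strings; the
# evaluation code (e.g. get_only_response in scripts/compare_results_jsons.py)
# later strips the prompt and chat markers to recover the answer itself.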
return only_responses 71 | 72 | 73 | -------------------------------------------------------------------------------- /model_wrappers/model_wrap.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class Model: 4 | def __init__(self, model): 5 | self.model = model 6 | 7 | def get_max_window(self): 8 | raise NotImplementedError("This method needs to be implemented by subclasses") 9 | 10 | def batch(self, prompts, num_truncation_tokens): 11 | raise NotImplementedError("This method needs to be implemented by subclasses") 12 | 13 | 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==1.1.1 3 | aiohappyeyeballs==2.4.4 4 | aiohttp==3.11.10 5 | aiosignal==1.3.1 6 | annotated-types==0.7.0 7 | async-timeout==5.0.1 8 | attrs==24.2.0 9 | certifi==2024.8.30 10 | charset-normalizer==3.4.0 11 | click==8.1.7 12 | datasets==3.1.0 13 | dill==0.3.8 14 | eval_type_backport==0.2.0 15 | filelock==3.16.1 16 | frozenlist==1.5.0 17 | fsspec==2024.9.0 18 | huggingface-hub==0.26.3 19 | idna==3.10 20 | Jinja2==3.1.4 21 | joblib==1.4.2 22 | markdown-it-py==3.0.0 23 | MarkupSafe==3.0.2 24 | mdurl==0.1.2 25 | mpmath==1.3.0 26 | multidict==6.1.0 27 | multiprocess==0.70.16 28 | networkx==3.4.2 29 | nltk==3.9.1 30 | numpy==2.1.3 31 | nvidia-cublas-cu12==12.4.5.8 32 | nvidia-cuda-cupti-cu12==12.4.127 33 | nvidia-cuda-nvrtc-cu12==12.4.127 34 | nvidia-cuda-runtime-cu12==12.4.127 35 | nvidia-cudnn-cu12==9.1.0.70 36 | nvidia-cufft-cu12==11.2.1.3 37 | nvidia-curand-cu12==10.3.5.147 38 | nvidia-cusolver-cu12==11.6.1.9 39 | nvidia-cusparse-cu12==12.3.1.170 40 | nvidia-nccl-cu12==2.21.5 41 | nvidia-nvjitlink-cu12==12.4.127 42 | nvidia-nvtx-cu12==12.4.127 43 | packaging==24.2 44 | pandas==2.2.3 45 | pillow==10.4.0 46 | propcache==0.2.1 47 | psutil==6.1.0 48 | pyarrow==18.1.0 49 | pydantic==2.10.3 50 | pydantic_core==2.27.1 51 | Pygments==2.18.0 52 | python-dateutil==2.9.0.post0 53 | pytz==2024.2 54 | PyYAML==6.0.2 55 | regex==2024.11.6 56 | requests==2.32.3 57 | rich==13.9.4 58 | rouge_score==0.1.2 59 | safetensors==0.4.5 60 | shellingham==1.5.4 61 | six==1.17.0 62 | sympy==1.13.1 63 | tabulate==0.9.0 64 | tiktoken==0.8.0 65 | together==1.3.5 66 | tokenizers==0.21.0 67 | torch==2.5.1 68 | tqdm==4.67.1 69 | transformers==4.47.0 70 | triton==3.1.0 71 | typer==0.13.1 72 | typing_extensions==4.12.2 73 | tzdata==2024.2 74 | urllib3==2.2.3 75 | xxhash==3.5.0 76 | yarl==1.18.3 77 | -------------------------------------------------------------------------------- /scripts/Lama_models.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | ## Llama - 70B: 15 | ## metlama/Ma-leta-Llama-3-70B-Instruct-Lite - More efficient 16 | ## "meta-llama/Meta-Llama-3-70B-Instruct-Turbo" - More efficient 17 | ## meta-llama/Llama-3-70b-chat-hf - for Chat with human -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaharl6000/MoreDocsSameLen/37fa664353f81fd11c775244064f33bb48e0c218/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/check_pkl.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import os 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | if __name__ == "__main__": 17 | pass 18 | data = np.load("/cs/labs/tomhope/nirm/MusiQue/pkl_files_original_questions/datasets_orig_q_Expanded_Set.pkl", allow_pickle= True) 19 | print(data) 20 | file_path = "/cs/labs/tomhope/nirm/MusiQue/pkl_files_original_questions/datasets_orig_q_replaced_Hybrid8.pkl" 21 | file_size_bytes = os.path.getsize(file_path) 22 | file_size_mb = file_size_bytes / (1024 * 1024) 23 | 24 | print(f"Size of the local file: {file_size_mb:.2f} MB") 25 | # data = np.load("/cs/labs/tomhope/nirm/MusiQue/pkl_files/datasets_extendedSet.pkl", allow_pickle=True) 26 | # print(data) 27 | 28 | # data = np.load("/cs/labs/tomhope/nirm/MusiQue/pkl_files/datasets_fullSet.pkl", allow_pickle=True) 29 | # print(data) 30 | 31 | # data = np.load("/cs/labs/tomhope/nirm/MusiQue/pkl_files/datasets_noDocSet.pkl", allow_pickle=True) 32 | # print(data) 33 | 34 | # data = np.load("/cs/labs/tomhope/nirm/MusiQue/pkl_files/datasets_pad_token_mistral.pkl", allow_pickle=True) 35 | # print(data) 36 | 37 | # data = np.load("/cs/labs/tomhope/nirm/MusiQue/pkl_files/datasets_raplacedSet.pkl", allow_pickle=True) 38 | # print(data) 39 | -------------------------------------------------------------------------------- /scripts/compare_results_jsons.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from glob import glob 5 | 6 | import numpy as np 7 | 8 | MODELS_NAME_MAPPING = { 9 | "Meta-Llama-3-8B-Instruct": "Llama3-8B", 10 | "Meta-Llama-3-70B-Instruct": "Llama3-70B", 11 | "gemma-1.1-7b-it": "Gemma1.1-7B", 12 | "gemma-1.1-2b-it": "Gemma1.1-2B", 13 | "Mistral-7B-Instruct-v0.2": "Mistral-7B", 14 | "Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B", 15 | "Mixtral-8x22B-Instruct-v0.1": "Mixtral-8x22B", 16 | } 17 | 18 | class Eval: 19 | def __init__(self, id_key, predictions_dir, out_path): 20 | self.id_key = id_key 21 | self.predictions_dir = predictions_dir 22 | self.out_path = out_path 23 | 24 | def _evaluate(self, data, model_name, sample_index): 25 | raise NotImplementedError("This method needs to be implemented by subclasses") 26 | 27 | def _evaluate_new(self, data, model_name, sample_index): 28 | raise NotImplementedError("This method needs to be implemented by subclasses") 29 | 30 | def eval_all_in_dir(self): 31 | print("Results will be saved in:", self.out_path) 32 | all_results = {} 33 | sample_lengths = {} 34 | if os.path.exists(self.out_path): 35 | with open(self.out_path, 'rt') as f: 36 | existing = json.load(f) 37 | all_results = existing["models"] 38 | sample_lengths = existing["sample_lengths"] 39 | for f_name in glob(f'{self.predictions_dir}/*/*.json'): 40 | sample_name = f_name.replace(self.predictions_dir, "").replace(".json", "") 41 | model = sample_name[:-2].split(os.sep)[-1] 42 | model_name = MODELS_NAME_MAPPING[model] 43 | sample_index = int(sample_name[-1]) 44 | print("Evaluating model:", model_name, "sample:", sample_index) 45 | if model_name in all_results and sample_index in all_results[model_name]["run_index"]: 46 | print("Skipping", model_name, sample_index) 47 | continue 48 | if model_name not in all_results: 49 | all_results[model_name] = {"scores": [], "run_index": [], "ids": []} 50 | all_results[model_name]["run_index"].append(sample_index) 51 | 52 | with open(f_name, 'rt') as f: 53 | predictions = json.load(f) 54 | 55 | current_ids = [] 56 | for pred in predictions: 57 | id_sample = str(pred[self.id_key]) 58 | 
current_ids.append(id_sample) 59 | if id_sample not in sample_lengths: 60 | length = len(pred["final_msgs"][0]['content']) 61 | sample_lengths[id_sample] = length 62 | all_results[model_name]["ids"].append(current_ids) 63 | results = self._evaluate_new(predictions, model_name, sample_index) 64 | all_results[model_name]["scores"].append(results) 65 | out_dict = {"sample_lengths": sample_lengths, "models": all_results} 66 | with open(self.out_path, 'wt') as f: 67 | json.dump(out_dict, f) 68 | 69 | @staticmethod 70 | def get_only_response(prediction): 71 | pred = prediction["prediction"] 72 | if '[/INST]' in pred: 73 | pred = pred[pred.find("[/INST]")+len("[/INST]"):].strip() 74 | elif "So, the answer is:" in pred: 75 | pred = pred[pred.rfind("So, the answer is:") + len("So, the answer is:"):].strip() 76 | if pred == "": 77 | # print(prediction["prediction"]) 78 | pred = "No answer" 79 | elif "Aspect-based summary:" in pred: 80 | pred = pred[pred.rfind("Aspect-based summary:") + len("Aspect-based summary:"):].strip() 81 | else: 82 | pred = pred[pred.rfind("*Answer*:") + len("*Answer*:"):].strip() 83 | return pred 84 | 85 | class QA(Eval): 86 | 87 | def __init__(self, id_key, predictions_dir, out_path): 88 | self.correlations = [] 89 | super().__init__(id_key, predictions_dir, out_path) 90 | 91 | def _evaluate(self, predictions, model_name, sample_index): 92 | follow_format = [] 93 | for sample in predictions: 94 | gt = sample["target"] 95 | pred = self.postprocess(sample) 96 | follow_format.append(type(pred) is dict) 97 | self.correlations.append([np.mean(follow_format), np.mean([])]) 98 | metrics = {"all_f1": []} 99 | return metrics 100 | 101 | def _evaluate_new(self, predictions, model_name, sample_index): 102 | id_to_answers = {} 103 | for sample in predictions: 104 | gt = sample["target"] 105 | id_sample = sample["id"] 106 | pred = self.postprocess(sample) 107 | # gt, pred = self.parse(gt, pred) 108 | pred = extract_answer_content(pred) 109 | 110 | if id_sample not in id_to_answers: 111 | id_to_answers[id_sample] = {} 112 | id_to_answers[id_sample] = pred 113 | return id_to_answers 114 | 115 | def parse(self, ground_truth_answer, predicted_answer): 116 | gt = self._extract_answer_from_dict(ground_truth_answer) 117 | if type(predicted_answer) is dict: 118 | pred = self._extract_answer_from_dict(predicted_answer) 119 | else: 120 | pred = predicted_answer 121 | return gt, pred 122 | 123 | 124 | @staticmethod 125 | def _extract_answer_from_dict(answer_dict): 126 | if answer_dict != "**": 127 | if "is_answerable" in answer_dict: 128 | if answer_dict["is_answerable"]: 129 | answer = answer_dict["answer_content"] 130 | else: 131 | answer = answer_dict["is_answerable"] 132 | return answer 133 | else: 134 | return "Not answerable" 135 | 136 | def postprocess(self, pred): 137 | only_response = self.get_only_response(pred) 138 | answer_dict = re.search(r'\{.*\}', only_response) 139 | if answer_dict is None: 140 | return only_response 141 | 142 | str_dict = answer_dict.group(0) 143 | str_dict = str_dict.replace("'s", "\\'s").replace("'t", "\\'t").replace("s' ", "s\\' ") 144 | str_dict = str_dict.replace("\\\\_", "_").replace("\\_", "_") 145 | try: 146 | answer = eval(str_dict) 147 | except Exception as e: 148 | try: 149 | str_dict = str_dict.replace("}", "'}") 150 | answer = eval(str_dict) 151 | except Exception as e: 152 | answer = only_response 153 | return answer 154 | 155 | def find_json_files(directory): 156 | json_files = [] 157 | for root, dirs, files in os.walk(directory): 158 | for file in 
files: 159 | if "results" not in file and file.endswith('.json'): 160 | json_files.append(os.path.join(root, file)) 161 | return json_files 162 | 163 | def extract_answer_content(result): 164 | pattern = r"['\"]?answer_content['\"]?\s*:\s*['\"]?([^'\"]*)['\"]?" 165 | if type(result) is not str: 166 | result = str(result) 167 | 168 | match = re.search(pattern, result) 169 | if match: 170 | answer_content = match.group(1) 171 | return answer_content 172 | else: 173 | return result 174 | 175 | 176 | if __name__ == "__main__": 177 | input_directory = r"C:\Users\LIHI\Downloads\base-20240721T055544Z-001\base\MusiQue" 178 | id_to_answers = {} 179 | qa_evaluator = QA("id", input_directory, "results_k.json") 180 | 181 | json_files = find_json_files(input_directory) 182 | print("Found JSON files:") 183 | # print(json_files) 184 | 185 | for file_path in json_files: 186 | model_name = os.path.splitext(os.path.basename(file_path))[0] 187 | print(file_path) 188 | with open(file_path, 'rt') as f: 189 | predictions = json.load(f) 190 | model_answers = qa_evaluator._evaluate_new(predictions, model_name, "sample_idx") 191 | for key,value in model_answers.items(): 192 | if key not in id_to_answers: 193 | id_to_answers[key] = {} 194 | id_to_answers[key][model_name] = value 195 | # id_to_answers.update(model_answers) 196 | 197 | # Print or save the results as needed 198 | print("\nFinal id_to_answers:") 199 | print(id_to_answers) 200 | output_file = os.path.join(input_directory, "results_k.json") 201 | with open(output_file, 'wt') as f: 202 | json.dump(id_to_answers, f) 203 | -------------------------------------------------------------------------------- /scripts/create_various_sets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tiktoken 3 | import random 4 | import wikipedia 5 | import torch 6 | import os 7 | from argparse import ArgumentParser 8 | import openai 9 | import copy 10 | from tqdm import tqdm 11 | 12 | """ -------------GLOBAL VARIABLES--------------- """ 13 | 14 | os.environ['CURL_CA_BUNDLE'] = '' 15 | VOCAB_SIZE_TOKENIZER = 100256 16 | wikipedia_cache = {} 17 | 18 | """ -------------HELPER FUNCTIONS--------------- """ 19 | 20 | 21 | def count_tokens(text, encoding): 22 | return len(encoding.encode(text)) 23 | 24 | 25 | def add_random_tokens_end(paragraph_text, num_tokens_to_add, encoding): 26 | vocab_size = VOCAB_SIZE_TOKENIZER 27 | random_text = "".join(encoding.decode([random.randint(0, vocab_size - 1)]) for _ in range(num_tokens_to_add)) 28 | return paragraph_text + random_text 29 | 30 | 31 | def add_pad_tokens_end(paragraph_text, num_tokens_to_add, token_end): 32 | random_text = "".join(token_end for _ in range(num_tokens_to_add)) 33 | return paragraph_text + random_text 34 | 35 | 36 | def get_wikipedia_page_content(title): 37 | if title in wikipedia_cache: 38 | return wikipedia_cache[title] 39 | 40 | search_results = wikipedia.search(title) 41 | if not search_results: 42 | wikipedia_cache[title] = False 43 | return False 44 | try: 45 | page = wikipedia.page(search_results[0]) 46 | wikipedia_cache[title] = page.content 47 | return page.content 48 | except wikipedia.exceptions.PageError: 49 | try: 50 | page = wikipedia.page(search_results[1]) 51 | wikipedia_cache[title] = page.content 52 | return page.content 53 | except Exception as ex: 54 | wikipedia_cache[title] = False 55 | return False 56 | except wikipedia.exceptions.DisambiguationError as e: 57 | # print(f"Disambiguation page. 
Options: {e.options}") 58 | try: 59 | page = wikipedia.page(e.options[0]) 60 | wikipedia_cache[title] = page.content 61 | return page.content 62 | except Exception as ex: 63 | wikipedia_cache[title] = False 64 | return False 65 | except wikipedia.exceptions.RedirectError as e: 66 | # print(f"Redirected page. New title: {e.title}") 67 | try: 68 | page = wikipedia.page(e.title) 69 | wikipedia_cache[title] = page.content 70 | return page.content 71 | except Exception as ex: 72 | wikipedia_cache[title] = False 73 | return False 74 | except Exception as e: 75 | wikipedia_cache[title] = False 76 | return False 77 | 78 | 79 | def add_wikipedia_tokens_end(paragraph_text, paragraph_title, num_tokens_to_add, encoding): 80 | wiki_page = get_wikipedia_page_content(paragraph_title) 81 | wiki_page_encoded = encoding.encode(paragraph_text if wiki_page is False else wiki_page) 82 | tokens_to_add = (wiki_page_encoded * ((num_tokens_to_add // len(wiki_page_encoded)) + 1))[:num_tokens_to_add] 83 | added_content = "".join(encoding.decode([t]) for t in tokens_to_add) 84 | return paragraph_text + added_content 85 | 86 | 87 | def add_wikipedia_tokens_wrap(paragraph_text, paragraph_title, num_tokens_to_add_before, num_tokens_to_add_after, 88 | encoding): 89 | wiki_page = get_wikipedia_page_content(paragraph_title) 90 | wiki_page_encoded = encoding.encode(paragraph_text if wiki_page is False else wiki_page) 91 | 92 | # remove start of before cintent since it is frequently match the paragraph text 93 | paragraph_text_encoded_len = len(encoding.encode(paragraph_text)) 94 | num_tokens_to_add_before += paragraph_text_encoded_len 95 | 96 | tokens_to_add_before = (wiki_page_encoded * ((num_tokens_to_add_before // len(wiki_page_encoded)) + 1))[ 97 | :num_tokens_to_add_before] 98 | tokens_to_add_after = (wiki_page_encoded * ((num_tokens_to_add_after // len(wiki_page_encoded)) + 1))[ 99 | :num_tokens_to_add_after] 100 | 101 | added_content_before = "".join(encoding.decode([t]) for t in tokens_to_add_before[paragraph_text_encoded_len:]) 102 | added_content_after = "".join(encoding.decode([t]) for t in tokens_to_add_after) 103 | 104 | return added_content_before + paragraph_text + added_content_after 105 | 106 | 107 | def truncate_or_pad_text(text, target_tokens, encoding, add_random=False): 108 | current_tokens = encoding.encode(text) 109 | current_length = len(current_tokens) 110 | if current_length > target_tokens: 111 | truncated_text = encoding.decode(current_tokens[:target_tokens]) 112 | return truncated_text 113 | elif current_length < target_tokens: 114 | if add_random: 115 | num_extra_tokens = target_tokens - current_length 116 | return add_random_tokens_end(text, num_extra_tokens, encoding) 117 | else: 118 | return truncate_or_pad_text(text + text, target_tokens, encoding, add_random=False) 119 | else: 120 | return text 121 | 122 | 123 | def create_rephrased_questions(input_path, output_path): 124 | demonstration = "Original question: What's #1 's hockey club named? " \ 125 | "Rephrased question: What is the hockey club called for the team ranked number one?" 
126 | 127 | with open(input_path, 'r') as infile: 128 | total_lines = sum(1 for line in infile) 129 | 130 | with open(input_path, 'r') as infile, open(output_path, 'w') as outfile: 131 | for i, line in enumerate(tqdm(infile, total=total_lines, desc="Processing lines")): 132 | data = json.loads(line) 133 | original_question = data["question"] 134 | prompt = f"Given a following question, rephrase it to maintain the exact idea but change the phrasing as much as possible." \ 135 | f" For example: {demonstration} Rephrase the original question: {original_question} Rephrased question:" 136 | 137 | completion = openai.chat.completions.create( 138 | model="gpt-4", 139 | messages=[ 140 | { 141 | "role": "user", 142 | "content": prompt, 143 | }, 144 | ], 145 | ) 146 | data["original_question"] = original_question 147 | 148 | data["question"] = completion.choices[0].message.content 149 | outfile.write(json.dumps(data) + '\n') 150 | 151 | 152 | def rephrase_question_in_set(rephrased_path, input_path, output_path): 153 | with open(rephrased_path, 'r') as rffile, open(input_path, 'r') as infile, open(output_path, 'w') as outfile: 154 | for rf_line, in_line in zip(rffile, infile): 155 | data = json.loads(in_line) 156 | reference = json.loads(rf_line) 157 | data["question"] = reference["question"] 158 | outfile.write(json.dumps(data) + '\n') 159 | 160 | 161 | """ -------------CREATE SETS FUNCTIONS--------------- """ 162 | 163 | 164 | def create_original_collection(input_path, output_path): 165 | with open(input_path, 'r') as infile, open(output_path, 'w') as outfile: 166 | for line in infile: 167 | data = json.loads(line) 168 | if data.get('answerable') is True: 169 | outfile.write(json.dumps(data) + '\n') 170 | 171 | 172 | def create_oracle(input_path, output_path): 173 | with open(input_path, 'r') as infile, open(output_path, 'w') as outfile: 174 | for line in infile: 175 | data = json.loads(line) 176 | supporting_paragraphs = [p for p in data.get('paragraphs', []) if p.get('is_supporting') is True] 177 | if supporting_paragraphs: 178 | # Replace the paragraphs with only the supporting ones 179 | data['paragraphs'] = supporting_paragraphs 180 | outfile.write(json.dumps(data) + '\n') 181 | 182 | 183 | def create_no_questions(input_path, output_path): 184 | with open(input_path, 'r') as infile, open(output_path, 'w') as outfile: 185 | for line in infile: 186 | data = json.loads(line) 187 | data['paragraphs'] = [] 188 | outfile.write(json.dumps(data) + '\n') 189 | 190 | 191 | def create_replace_distractors(input_path, output_path, num_to_retain=None): 192 | encoding = tiktoken.encoding_for_model("gpt-4") 193 | 194 | with open(input_path, 'r') as infile: 195 | lines = [json.loads(line) for line in infile] 196 | 197 | all_non_supporting_paragraphs = [] 198 | for line in lines: 199 | non_supporting_paragraphs = [p for p in line.get('paragraphs', []) if not p.get('is_supporting')] 200 | all_non_supporting_paragraphs.extend(non_supporting_paragraphs) 201 | 202 | with open(output_path, 'w') as outfile: 203 | for line in lines: 204 | supporting_paragraphs = [p for p in line.get('paragraphs', []) if p.get('is_supporting')] 205 | non_supporting_paragraphs = [p for p in line.get('paragraphs', []) if not p.get('is_supporting')] 206 | 207 | if num_to_retain is not None and num_to_retain < len(non_supporting_paragraphs): 208 | retained_paragraphs = non_supporting_paragraphs[:num_to_retain] 209 | paragraphs_to_replace = non_supporting_paragraphs[num_to_retain:] 210 | else: 211 | retained_paragraphs = [] 212 | 
paragraphs_to_replace = non_supporting_paragraphs 213 | 214 | for p in paragraphs_to_replace: 215 | target_tokens = count_tokens(p['paragraph_text'], encoding) 216 | new_paragraph = random.choice(all_non_supporting_paragraphs) 217 | new_paragraph_text = truncate_or_pad_text(new_paragraph['paragraph_text'], target_tokens, encoding) 218 | p['paragraph_text'] = new_paragraph_text 219 | 220 | data = { 221 | 'id': line['id'], 222 | 'paragraphs': supporting_paragraphs + retained_paragraphs + paragraphs_to_replace, 223 | 'question': line['question'], 224 | 'question_decomposition': line['question_decomposition'], 225 | 'answer': line['answer'], 226 | 'answer_aliases': line['answer_aliases'], 227 | 'answerable': line['answerable'] 228 | } 229 | 230 | outfile.write(json.dumps(data) + '\n') 231 | 232 | 233 | def create_expanded(input_path, output_path, num_to_retain=None): 234 | encoding = tiktoken.encoding_for_model("gpt-4") 235 | 236 | with open(input_path, 'r') as infile: 237 | total_lines = sum(1 for line in infile) 238 | 239 | with open(input_path, 'r') as infile, open(output_path, 'w') as outfile: 240 | for i, line in enumerate(tqdm(infile, total=total_lines, desc="Processing lines")): 241 | data = json.loads(line) 242 | supporting_paragraphs = [p for p in data.get('paragraphs', []) if p.get('is_supporting') is True] 243 | non_supporting_paragraphs = [p for p in data.get('paragraphs', []) if p.get('is_supporting') is False] 244 | 245 | if num_to_retain is not None: 246 | cur_num_to_retain = min(num_to_retain - len(supporting_paragraphs), len(non_supporting_paragraphs)) 247 | random_non_supporting = random.sample(non_supporting_paragraphs, cur_num_to_retain) 248 | supporting_paragraphs.extend(random_non_supporting) 249 | supporting_paragraphs.sort(key=lambda p: p['idx']) 250 | non_supporting_paragraphs = [p for p in non_supporting_paragraphs if p not in random_non_supporting] 251 | 252 | non_supporting_paragraphs_tokens = [[p['idx'], count_tokens(p['paragraph_text'], encoding)] 253 | for p in non_supporting_paragraphs] 254 | 255 | num_supporting_paragraphs = len(supporting_paragraphs) 256 | 257 | if num_supporting_paragraphs > 0: 258 | for i, p in enumerate(supporting_paragraphs): 259 | if i == 0: 260 | num_tokens_to_add_before = \ 261 | sum([tokens[1] for tokens in non_supporting_paragraphs_tokens if tokens[0] < p['idx']]) 262 | else: 263 | num_tokens_to_add_before = \ 264 | sum([tokens[1] for tokens in non_supporting_paragraphs_tokens 265 | if p['idx'] > tokens[0] > supporting_paragraphs[i - 1]['idx']]) / 2 266 | 267 | if i == len(supporting_paragraphs) - 1: 268 | num_tokens_to_add_after = \ 269 | sum([tokens[1] for tokens in non_supporting_paragraphs_tokens if tokens[0] > p['idx']]) 270 | else: 271 | num_tokens_to_add_after = \ 272 | sum([tokens[1] for tokens in non_supporting_paragraphs_tokens 273 | if supporting_paragraphs[i + 1]['idx'] > tokens[0] > p['idx']]) / 2 274 | 275 | p['paragraph_text'] = add_wikipedia_tokens_wrap(p['paragraph_text'], p['title'], 276 | int(num_tokens_to_add_before), 277 | int(num_tokens_to_add_after), 278 | encoding) 279 | data['paragraphs'] = supporting_paragraphs 280 | outfile.write(json.dumps(data) + '\n') 281 | 282 | 283 | if __name__ == '__main__': 284 | parser = ArgumentParser() 285 | parser.add_argument('--input', type=str, required=True, help='Path to the input file') 286 | parser.add_argument('--output', type=str, required=True, help='Path to the output file') 287 | parser.add_argument('--dataset_name', type=str, required=True, 288 | choices=['original', 
'noQuestion', 'expanded', 'replaced', 'oracle', 'hybrid', 'rephrase_questions'],
289 | help='Name of the dataset. Must be one of: original, noQuestion, expanded, replaced, oracle, hybrid, rephrase_questions')
290 | parser.add_argument('--num_of_documents', type=int, default=None, help='Number of documents to include in the dataset; used by the hybrid set')
291 | args = parser.parse_args()
292 |
293 | input_path = args.input
294 | output_path = args.output
295 | dataset = args.dataset_name
296 |
297 | if dataset == 'original':
298 | # this set keeps only the answerable questions from the original MuSiQue dataset
299 | create_original_collection(input_path, output_path)
300 | elif dataset == 'noQuestion':
301 | # this set keeps only the questions, without the documents
302 | create_no_questions(input_path, output_path)
303 | elif dataset == 'expanded':
304 | # this set keeps only the supporting documents
305 | # and adds Wikipedia page content to match the original token count,
306 | # while keeping the original document information in the same position
307 | create_expanded(input_path, output_path)
308 | elif dataset == 'replaced':
309 | # this set replaces the non-supporting documents with documents taken from other instances
310 | create_replace_distractors(input_path, output_path)
311 | elif dataset == 'oracle':
312 | # this set keeps only the supporting documents
313 | create_oracle(input_path, output_path)
314 | elif dataset == 'hybrid':
315 | # this set is similar to the "expanded" set, but retains more documents, according to the provided num_of_documents
316 | create_replace_distractors(input_path, output_path, args.num_of_documents)
317 | elif dataset == 'rephrase_questions':
318 | # this set rephrases the questions using GPT-4; the documents remain the same as in the input
319 | create_rephrased_questions(input_path, output_path)
320 | else:
321 | raise ValueError(f"Unknown dataset name: {dataset}")
322 |
323 | print(f"Saving {args.dataset_name} dataset to {output_path}")
324 |
325 |
--------------------------------------------------------------------------------
/scripts/eval_all.sh:
--------------------------------------------------------------------------------
1 | while [[ $# -gt 0 ]]; do
2 | key="$1"
3 | case $key in
4 | -p|--predictions_path)
5 | predictions_path="$2"
6 | shift 2
7 | ;;
8 | -o|--out_dir)
9 | out_dir="$2"
10 | shift 2
11 | ;;
12 | *)
13 | echo "Error: Unknown option: $1"
14 | usage
15 | exit 1
16 | ;;
17 | esac
18 | done
19 |
20 | export PYTHONPATH=./
21 |
22 | echo ${predictions_path}
23 | echo ${out_dir}
24 |
25 | datasets=("ECB" "SciCo" "MultiNews" "MusiQue" "OpenASP" "FuseReviews")
26 |
27 | for ds in "${datasets[@]}"; do
28 | echo "Evaluating ${ds}"
29 | python scripts/evaluate_dataset.py --ds_name "${ds}" --predictions_dir "${predictions_path}/${ds}" --output_path "${out_dir}/${ds}/results.json"
30 | python scripts/evaluation/statistical_analysis.py --results_path "${out_dir}/${ds}/results.json"
31 | done
32 |
33 | models=("Mistral-7B" "Llama3-8B" "Llama3-70B" "Mixtral-8x7B" "Mixtral-8x22B" "Gemma1.1-2B" "Gemma1.1-7B")
34 |
35 | for model in "${models[@]}"; do
36 | echo "Evaluating ${model}"
37 | python scripts/evaluation/sample_length_function.py --results_dir "${out_dir}" --model_name "${model}"
38 | done
39 |
40 | python scripts/evaluation/compare_models.py --results_dir "${out_dir}"
--------------------------------------------------------------------------------
/scripts/evaluate_dataset.py:
--------------------------------------------------------------------------------
1 | import
os.path 2 | from argparse import ArgumentParser 3 | from evaluation_classes.fuse_reviews import FuseReviews 4 | from evaluation_classes.question_answering import QA 5 | from evaluation_classes.coreference import Coref 6 | from evaluation_classes.summarization import Summarization 7 | 8 | 9 | def main(predictions_dir, output_path, ds_name): 10 | if ds_name == 'FuseReviews': 11 | eval_class = FuseReviews("guid", predictions_dir, output_path) 12 | elif ds_name == 'MusiQue': 13 | eval_class = QA("id", predictions_dir, output_path) 14 | elif ds_name == 'ECB': 15 | eval_class = Coref("topic_id", predictions_dir, output_path) 16 | elif ds_name == 'SciCo': 17 | eval_class = Coref("sample_id", predictions_dir, output_path) 18 | elif ds_name == 'MultiNews': 19 | eval_class = Summarization("id", predictions_dir, output_path) 20 | elif ds_name == 'OpenASP': 21 | eval_class = Summarization("guid", predictions_dir, output_path) 22 | else: 23 | raise ValueError(f"Unknown dataset name: {ds_name}") 24 | eval_class.eval_all_in_dir() 25 | 26 | 27 | if __name__ == '__main__': 28 | parser = ArgumentParser() 29 | parser.add_argument("--predictions_dir", type=str, required=True) 30 | parser.add_argument("--output_path", type=str, required=True) 31 | parser.add_argument("--ds_name", type=str, required=True) 32 | args = parser.parse_args() 33 | dir_out = os.path.dirname(args.output_path) 34 | os.makedirs(dir_out, exist_ok=True) 35 | main(args.predictions_dir, args.output_path, args.ds_name) 36 | -------------------------------------------------------------------------------- /scripts/evaluation/ASD.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from scipy.stats import norm as normal 3 | from scipy.stats import mannwhitneyu as Utest 4 | import numpy as np 5 | # import matplotlib.pyplot as plt 6 | 7 | 8 | F = [] 9 | G = [] 10 | n = 0 11 | m = 0 12 | def buildOrigCDFs(f, g): 13 | global F 14 | global G 15 | global n 16 | global m 17 | F = np.sort(f) 18 | n = len(F) 19 | G = np.sort(g) 20 | m = len(G) 21 | 22 | 23 | def buildNewCDFs(f, g): 24 | global Fb 25 | global Gb 26 | Fb = np.sort(f) 27 | Gb = np.sort(g) 28 | 29 | 30 | def invG(p): 31 | index = int(np.ceil(p*m)) 32 | if index >= m: 33 | return G[m-1] 34 | elif index == 0: 35 | return G[0] 36 | return G[index-1] 37 | 38 | 39 | def invF(p): 40 | index = int(np.ceil(p*n)) 41 | if index >= n: 42 | return F[n-1] 43 | elif index == 0: 44 | return F[0] 45 | return F[index-1] 46 | 47 | 48 | def invGnew(p, M): 49 | index = int(np.ceil(p*M)) 50 | if index >= M: 51 | return Gb[M-1] 52 | elif index == 0: 53 | return Gb[0] 54 | return Gb[index-1] 55 | 56 | 57 | def invFnew(p, N): 58 | index = int(np.ceil(p*N)) 59 | if index >= N: 60 | return Fb[N-1] 61 | elif index == 0: 62 | return Fb[0] 63 | return Fb[index-1] 64 | 65 | 66 | def epsilon(dp): 67 | s = 0.0 68 | se = 0.0 69 | for p in np.arange(0, 1, dp): 70 | temp = invG(p)-invF(p) 71 | tempe = max(temp, 0) 72 | s = s+temp*temp*dp 73 | se = se+tempe*tempe*dp 74 | if s != 0: 75 | return se/s 76 | else: 77 | print("The denominator is 0") 78 | return 0.0 79 | 80 | 81 | def epsilonNew(dp, N, M): 82 | denom = 0.0 83 | numer = 0.0 84 | for p in np.arange(0, 1, dp): 85 | diff = invGnew(p, M) - invFnew(p, N) # check when F-1(t) 20 111 | if n<20 or m<20: 112 | print("Use only when the number of observation in each sample is > 20") 113 | return 1.0 114 | _, pval = Utest(data_A, data_B, alternative='less') 115 | return pval 116 | 117 | 118 | 
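# Summary of how calc_dominance (below) combines the helpers above: epsilon(dp)
# integrates the positive part of (invG(p) - invF(p))^2 over p in [0, 1) and divides
# by the total squared quantile difference, i.e. the share of the difference where
# data_B's quantiles exceed data_A's (the "violation" of A dominating B).
# calc_dominance then bootstraps B resamples through invF/invG and epsilonNew to
# estimate the statistic's standard deviation, and reports the minimal epsilon for
# which data_A is almost stochastically greater than data_B at level alpha:
# epsilon == 0 means full stochastic dominance, epsilon <= 0.5 is read as A being
# better than B, and epsilon > 0.5 as no evidence that A is better.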
############################################################## 119 | def calc_dominance(data_A, data_B, alpha, name_A, name_B, out_file): 120 | 121 | buildOrigCDFs(data_A, data_B) 122 | 123 | # constants 124 | dp = 0.005 # differential of the variable p - for integral calculations 125 | N = 1000 # num of samples from F for sigma estimate 126 | M = 1000 # num of samples from G for sigma estimate 127 | B = 1000 # bootstrap iterations for sigma estimate 128 | 129 | # calculate the epsilon quotient 130 | eps_FnGm = epsilon(dp) 131 | 132 | # estimate the variance 133 | lamda = (0.0 + N) / (N + M) 134 | const = np.sqrt((1.0 * N * M) / (N + M + 0.0)) 135 | samples = [] 136 | for b in range(B): 137 | Fb = [] 138 | Gb = [] 139 | Fvalues = [] 140 | Gvalues = [] 141 | uniF = np.random.uniform(0, 1, N) 142 | uniG = np.random.uniform(0, 1, M) 143 | for i in range(0, N): 144 | Fvalues.append(invF(uniF[i])) 145 | for j in range(0, M): 146 | Gvalues.append(invG(uniG[j])) 147 | buildNewCDFs(Fvalues, Gvalues) 148 | distance = epsilonNew(dp, N, M) 149 | samples.append(distance) 150 | 151 | sigma = np.std(samples) 152 | 153 | min_epsilon = min(max(eps_FnGm - (1/const) * sigma * normal.ppf(alpha), 0.0), 1.0) 154 | print(f"The minimal epsilon for which {name_A} is almost " 155 | f"stochastically greater than {name_B} is ", min_epsilon, file=out_file) 156 | if min_epsilon <= 0.5 and min_epsilon > 0.0: 157 | print(f"since epsilon <= 0.5 we will claim that {name_A} is " 158 | f"better than {name_B} with significance level alpha=", alpha, file=out_file) 159 | elif min_epsilon == 0.0: 160 | print(f'since epsilon = 0, {name_A} is stochatically dominant over {name_B}', file=out_file) 161 | 162 | else: 163 | print(f"since epsilon > 0.5 we will claim that {name_A} " 164 | f"is not better than {name_B} with significance level alpha=", alpha, file=out_file) 165 | return min_epsilon 166 | 167 | # print(MannWhitney(data_A, data_B) 0: 32 | models[model_name]['normalized_var'].append(model_variance/model_mean) 33 | dict_std_by_task[task].append(model_variance/model_mean) 34 | print(f_name) 35 | out_dict_rank = {} 36 | out_dict_std = {} 37 | for model_name in models: 38 | out_dict_rank[model_name] = {} 39 | out_dict_std[model_name] = {} 40 | out_dict_rank[model_name]['Rank'] = np.mean(models[model_name]['ranks']) 41 | out_dict_std[model_name]['Standard Deviation'] = np.mean(models[model_name]['normalized_var'])*100 42 | out_dict_std_by_task = {task: {} for task in dict_std_by_task} 43 | for task in out_dict_std_by_task: 44 | out_dict_std_by_task[task]['Standard Deviation'] = np.mean(dict_std_by_task[task])*100 45 | 46 | # COLORS.reverse() 47 | output = pd.DataFrame(out_dict_rank) 48 | output = output[sorted(output.columns)] 49 | plot(sorted_output=output, COLORS=COLORS, x_label="Average rank (AR ↓)\nacross models\n", 50 | figs_dir=results_dir, fig_name='ranking') 51 | output = pd.DataFrame(out_dict_std) 52 | output = output[sorted(output.columns)] 53 | plot(sorted_output=output, COLORS=COLORS, x_label="Average relative standard\ndeviation (ARSD ↓) across\nmodels", 54 | figs_dir=results_dir, fig_name='rsd_for_model') 55 | output = pd.DataFrame(out_dict_std_by_task) 56 | output = output[reversed(tasks_order)] 57 | plot(sorted_output=output, COLORS=None, x_label="Average relative standard\ndeviation (ARSD ↓) across\ndatasets", 58 | figs_dir=results_dir, fig_name='rsd_for_task') 59 | 60 | 61 | def plot(sorted_output, COLORS, x_label, figs_dir, fig_name): 62 | plt.figure(fig_name) 63 | plt.barh(sorted_output.columns, 
sorted_output.iloc[0], color=COLORS) 64 | if fig_name != 'ranking': 65 | plt.xticks(np.arange(5, 25, 5), fontsize=19) 66 | else: 67 | plt.xticks(np.arange(1, 8, 1), fontsize=19) 68 | plt.xlabel(x_label, fontsize=20) 69 | plt.yticks(fontsize=20) 70 | plt.tight_layout() 71 | plt.savefig(f'{figs_dir}/{fig_name}.pdf') 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = ArgumentParser() 76 | parser.add_argument("--results_dir", type=str, required=True) 77 | args = parser.parse_args() 78 | main(args.results_dir) -------------------------------------------------------------------------------- /scripts/evaluation/sample_length_function.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | import json 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from scipy.stats import hmean 7 | import matplotlib.pyplot as plt 8 | from scipy.interpolate import CubicSpline 9 | from sklearn.metrics import r2_score 10 | 11 | from scripts.evaluation.compare_models import tasks_order 12 | TOKENS_TRUNCATED = 7500 13 | # LABELS = {'rouge1': 'R-1', 'rouge2': 'R-2', 'rougeL': 'R-L', 'faithfulness_score': 'faithfulness', 14 | # 'coverage_score': 'coverage', 'all_f1': 'F1', 'Roug'} 15 | 16 | COLOR_MAP = {'FuseReviews': 'blue', 'MusiQue': 'red', 'ECB': 'green', 'SciCo': 'purple', 'MultiNews': 'orange', 'OpenASP': 'brown'} 17 | MARKER_MAP = {'FuseReviews': ',', 'MusiQue': '*', 'MultiNews': '.', 'OpenASP': '+'} 18 | 19 | 20 | def eval_length(res_dir, model_name): 21 | deg = 3 22 | # plt.figure(figsize=(9, 6)) 23 | for task in tasks_order: 24 | task_dir = os.path.join(res_dir, task) 25 | with open(os.path.join(task_dir, "results.json"), 'rt') as f: 26 | results = json.load(f) 27 | scores = results["models"][model_name]["scores"] 28 | metrics = scores[0].keys() 29 | mean = [] 30 | for metric in metrics: 31 | metric_scores = np.array([score[metric] for score in scores]) 32 | mean.append(metric_scores) 33 | if task in ["OpenASP", "MultiNews"]: 34 | # geometric mean for rouge 35 | mean = np.prod(mean, axis=0) ** (1 / 3) 36 | name = 'ROUGE' 37 | scores = [{name: mean[i]} for i in range(len(mean))] 38 | elif task == 'FuseReviews': 39 | mean = hmean(mean, axis=0) 40 | name = 'F1' 41 | scores = [{name: mean[i]} for i in range(len(mean))] 42 | elif task == 'MusiQue': 43 | scores = [{'F1': s['all_f1']} for s in scores] 44 | metrics = scores[0].keys() 45 | lengths = results["sample_lengths"] 46 | sorted_by_lengths = sorted(lengths.keys(), key=lambda x: lengths[x]) 47 | sorted_lengths = np.array(sorted(lengths.values())) 48 | valid_indices = np.where(np.array(sorted_lengths) <= TOKENS_TRUNCATED)[0] 49 | for metric in metrics: 50 | if metric.lower() == 'conll f1': 51 | continue 52 | mean_scores_per_sample = [] 53 | stds = [] 54 | metric_scores = np.array([score[metric] for score in scores]) 55 | # normalize scores to be in [0, 100] 56 | metric_scores = (metric_scores - np.min(metric_scores)) / (np.max(metric_scores) - np.min(metric_scores)) 57 | if metric_scores.ndim == 1: 58 | continue 59 | for sample_id in sorted_by_lengths: 60 | total_for_sample = [] 61 | for run_index in range(len(metric_scores)): 62 | if sample_id in results["models"][model_name]['ids'][run_index]: 63 | index = results["models"][model_name]['ids'][run_index].index(sample_id) 64 | score = metric_scores[run_index, index] 65 | total_for_sample.append(score) 66 | mean_scores_per_sample.append(np.mean(total_for_sample)) 67 | stds.append(np.std(total_for_sample)) 68 | 
# sampled_valid_indices = np.random.choice(valid_indices, 100) 69 | valid_lengths = np.array(sorted_lengths)[valid_indices] 70 | args_sorted = np.argsort(valid_lengths) 71 | valid_lengths = valid_lengths[args_sorted] 72 | valid_scores = np.array(mean_scores_per_sample)[valid_indices] 73 | valid_scores = valid_scores[args_sorted] 74 | 75 | # remove duplicates from valid lengths, and mean over the duplicates scores in valid_scores 76 | # unique_lengths = np.unique(valid_lengths) 77 | # unique_scores = [] 78 | # for length in unique_lengths: 79 | # indices = np.where(valid_lengths == length) 80 | # unique_scores.append(np.mean(valid_scores[indices])) 81 | 82 | 83 | 84 | # Assuming you have your data points in separate lists named 'x' and 'y' 85 | 86 | # # Plot the data points 87 | # plt.plot(unique_lengths, unique_scores, 'o', label=f'{task} {metric}') 88 | # 89 | # # Perform cubic spline interpolation 90 | # spline = CubicSpline(unique_lengths, unique_scores) 91 | # 92 | # # Generate smoother curve using spline interpolation 93 | # smooth_x = np.linspace(min(unique_lengths), max(unique_lengths), 100) # Adjust number of points for smoothness 94 | # smooth_y = spline(smooth_x) 95 | # 96 | # # Plot the spline curve 97 | # plt.plot(smooth_x, smooth_y, label='Spline Interpolation') 98 | # 99 | # # Calculate and display R-squared 100 | # y_pred = spline(unique_lengths) # Predicted values using spline 101 | # r2 = r2_score(unique_scores, y_pred) 102 | # print(f"R-squared: {r2}") 103 | # 104 | # # Add labels and title 105 | # plt.xlabel("Sample length (# tokens)") 106 | # plt.ylabel("Score") 107 | # plt.title("Data with Spline Interpolation") 108 | # plt.legend() 109 | # plt.grid(True) 110 | # plt.show() 111 | 112 | 113 | 114 | 115 | coefficients = np.polyfit(valid_lengths, valid_scores, deg=deg) 116 | polynomial = np.poly1d(coefficients) 117 | 118 | # Generate fitted line 119 | x_fit = np.linspace(min(valid_lengths), max(valid_lengths), 100) 120 | y_fit = polynomial(x_fit) 121 | 122 | label = f'{task} {metric}' 123 | 124 | # Plot fitted line 125 | # plt.plot(x_fit, y_fit, color=COLOR_MAP[task], alpha=0.5) 126 | plt.scatter(valid_lengths, valid_scores, label=label, alpha=0.5, marker=MARKER_MAP[task], color=COLOR_MAP[task]) 127 | # plt.show() 128 | # plt.title("Data with Spline Interpolation") 129 | # plt.legend() 130 | # plt.grid(True) 131 | # plt.show() 132 | 133 | 134 | 135 | 136 | 137 | plt.xlabel("Sample length (# tokens)", fontsize=15) 138 | plt.ylabel("Relative score", fontsize=15) 139 | plt.xticks(fontsize=14) 140 | plt.legend(loc='upper right') 141 | plt.tight_layout() 142 | 143 | path = os.path.join(res_dir, f"polyfit_{model_name}.pdf") 144 | # Show plot 145 | plt.savefig(path) 146 | # plt.show() 147 | 148 | 149 | if __name__ == '__main__': 150 | parser = ArgumentParser() 151 | parser.add_argument("--results_dir", type=str, required=True) 152 | parser.add_argument("--model_name", type=str, required=True) 153 | args = parser.parse_args() 154 | eval_length(args.results_dir, args.model_name) -------------------------------------------------------------------------------- /scripts/evaluation/statistical_analysis.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path 3 | import sys 4 | from argparse import ArgumentParser 5 | import json 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | from ASD import calc_dominance 10 | from itertools import product 11 | import seaborn as sns 12 | from scipy.stats import 
hmean 13 | 14 | TOKENS_TRUNCATED = 7500 15 | COLORS = ['plum', 'mediumpurple', 'cadetblue', 'cornflowerblue', 'coral', 'palevioletred', 'peru'] 16 | 17 | def analyze_results(results_path): 18 | dataset_name = os.path.dirname(results_path).split("/")[-1] 19 | out_dir = os.path.dirname(results_path) 20 | figs_dir = os.path.join(out_dir, "figs") 21 | scores_path = os.path.join(out_dir, "scores.txt") 22 | scores_file = open(scores_path, "w") 23 | os.makedirs(figs_dir, exist_ok=True) 24 | with open(results_path, 'r') as f: 25 | results = json.load(f) 26 | lengths = results["sample_lengths"] 27 | sorted_by_lengths = sorted(lengths.keys(), key=lambda x: lengths[x]) 28 | sorted_lengths = np.array(sorted(lengths.values())) 29 | # remove outlier lengths 30 | valid_indices = np.where(np.array(sorted_lengths) <= TOKENS_TRUNCATED)[0] 31 | total = {} 32 | vectors = {} 33 | for_box_plot = {} 34 | for model in results["models"]: 35 | plt.figure() 36 | print(f"Model: {model}", file=scores_file) 37 | scores = results["models"][model]["scores"] 38 | metrics = scores[0].keys() 39 | total[model] = {} 40 | if len(list(metrics)) > 1: 41 | # calc geometric mean per sample 42 | mean = [] 43 | for metric in metrics: 44 | metric_scores = np.array([score[metric] for score in scores]) 45 | mean.append(metric_scores) 46 | if 'rouge' in list(metrics)[0].lower(): 47 | # geometric mean for rouge 48 | mean = np.prod(mean, axis=0) ** (1 / len(metrics)) * 100 49 | name = 'ROUGE' 50 | elif 'faithfulness' in list(metrics)[0].lower(): 51 | mean = hmean(mean, axis=0) 52 | name = 'F1' 53 | else: 54 | print() 55 | raise ValueError("Unknown metric") 56 | metrics = [name] 57 | scores = [{name: mean[i]} for i in range(len(mean))] 58 | results["models"][model]["scores"] = scores 59 | for metric in metrics: 60 | if metric not in for_box_plot: 61 | for_box_plot[metric] = {'models': [], 'scores': []} 62 | metric_scores = np.array([score[metric] for score in scores]) 63 | if metric_scores.ndim == 1: 64 | metric_scores = metric_scores[:, None] 65 | mean_per_run = np.mean(metric_scores, axis=1) 66 | if metric not in vectors: 67 | vectors[metric] = {} 68 | vectors[metric][model] = mean_per_run 69 | std_between_runs = np.std(mean_per_run) 70 | total_mean = np.mean(metric_scores) 71 | # if total_mean < 1: 72 | # total_mean *= 100 73 | # std_between_runs *= 100 74 | total[model][metric] = (total_mean, std_between_runs) 75 | for_box_plot[metric]['models'].append(model) 76 | for_box_plot[metric]['scores'].append(mean_per_run) 77 | print(f"\t{metric}: {total_mean} +/- {std_between_runs}", file=scores_file) 78 | if metric_scores.shape[1] == 1: 79 | continue 80 | mean_scores_per_sample = [] 81 | stds = [] 82 | for sample_id in sorted_by_lengths: 83 | total_for_sample = [] 84 | for run_index in range(len(metric_scores)): 85 | if sample_id in results["models"][model]['ids'][run_index]: 86 | index = results["models"][model]['ids'][run_index].index(sample_id) 87 | score = metric_scores[run_index, index] 88 | total_for_sample.append(score) 89 | mean_scores_per_sample.append(np.mean(total_for_sample)) 90 | stds.append(np.std(total_for_sample)) 91 | # plt.errorbar(sorted_lengths, mean_scores_per_sample, yerr=stds, label=metric) 92 | plt.scatter(sorted_lengths[valid_indices], np.array(mean_scores_per_sample)[valid_indices], label=metric) 93 | plt.title(f"{model} scores") 94 | plt.xlabel("Sample length") 95 | plt.ylabel("Score") 96 | plt.legend() 97 | plt.savefig(os.path.join(figs_dir, f'{model.replace("/", "_")}_scores.pdf')) 98 | f.close() 99 | 100 | 101 | 
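# The loop below draws one horizontal box plot per metric: models on the y-axis,
# per-run mean scores on the x-axis, boxes filled with COLORS and medians in black;
# CoNLL F1 and ROUGE get fixed x-limits, while other metrics are scaled by 100
# before plotting. After that, the pairwise ASD dominance table and ranking CSV
# are produced via calc_dominance.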
for metric in for_box_plot: 102 | plt.figure(figsize=(6.4, 3.8)) 103 | argsort_by_model_name = np.argsort(for_box_plot[metric]['models']) 104 | for_box_plot[metric]['models'] = [for_box_plot[metric]['models'][i] for i in argsort_by_model_name] 105 | for_box_plot[metric]['scores'] = [for_box_plot[metric]['scores'][i] for i in argsort_by_model_name] 106 | name = metric 107 | if metric == 'conll_F1': 108 | name = 'CoNLL F1' 109 | plt.xlim(-1, 34) 110 | elif metric == 'ROUGE': 111 | plt.xlim(-1, 24) 112 | plt.xticks(np.arange(0, 24, 2)) 113 | else: 114 | for_box_plot[metric]['scores'] = [np.array(scores) * 100 for scores in for_box_plot[metric]['scores']] 115 | bplot = plt.boxplot(for_box_plot[metric]['scores'], labels=for_box_plot[metric]['models'], 116 | vert=False, patch_artist=True) 117 | 118 | # fill with colors 119 | for patch, color in zip(bplot['boxes'], COLORS): 120 | patch.set_facecolor(color) 121 | 122 | for line in bplot['medians']: 123 | line.set_color("black") 124 | 125 | if dataset_name == 'ECB': 126 | dataset_name = 'ECB+' 127 | xlabel = f"{dataset_name} {name}" 128 | plt.xlabel(xlabel, fontsize=14) 129 | plt.yticks(fontsize=14) 130 | plt.xticks(fontsize=13) 131 | plt.tight_layout() 132 | plt.savefig(os.path.join(figs_dir, f"box_plot_{dataset_name}_{metric}.pdf")) 133 | 134 | total_df = pd.DataFrame(total).T 135 | total_df.to_csv(os.path.join(out_dir, "total_scores.csv")) 136 | asd_file = open(os.path.join(out_dir, "asd.txt"), "w") 137 | asd_table = os.path.join(out_dir, "ranking_asd.csv") 138 | sorted_model_names = sorted(total.keys()) 139 | for metric in vectors: 140 | asd_file.write(f"\n\n===== {metric} =====\n\n") 141 | tab = {} 142 | visited = set() 143 | for a, b in product(total, total): 144 | if a == b or (b, a) in visited: 145 | continue 146 | dom = calc_dominance(data_A=vectors[metric][a], data_B=vectors[metric][b], alpha=0.05, 147 | name_A=a, name_B=b, out_file=asd_file) 148 | if a not in tab: 149 | tab[a] = {} 150 | if b not in tab: 151 | tab[b] = {} 152 | tab[a][b] = round(dom, 2) 153 | tab[b][a] = 1 - round(dom, 2) 154 | visited.add((a, b)) 155 | asd_file.write("\n") 156 | plt.figure() 157 | sum_dom = {k: sum(tab[k].values()) for k in tab} 158 | ranking_model_names = sorted(sum_dom, key=lambda x: sum_dom[x]) 159 | one_minus_alpha = [1 - tab[ranking_model_names[i]][ranking_model_names[i+1]] for i in range(len(ranking_model_names)-1)] 160 | one_minus_alpha.append(None) 161 | ranking_df = pd.DataFrame(columns=["model_name", "dominance"]) 162 | ranking_df["model_name"] = ranking_model_names 163 | ranking_df["dominance"] = one_minus_alpha 164 | scores = {model: (round(total_df.loc[model, metric][0] * 100, 1), round(total_df.loc[model, metric][1] * 100, 1)) 165 | for model in ranking_model_names} 166 | sorted_scores = [scores[model] for model in ranking_model_names] 167 | ranking_df["score"] = sorted_scores 168 | ranking_df.to_csv(asd_table) 169 | asd_file.close() 170 | 171 | 172 | if __name__ == '__main__': 173 | parser = ArgumentParser() 174 | parser.add_argument("--results_path", type=str, required=True) 175 | args = parser.parse_args() 176 | analyze_results(args.results_path) 177 | 178 | -------------------------------------------------------------------------------- /scripts/generate_benchmark.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from argparse import ArgumentParser 3 | import json 4 | from datasets_classes.base import Dataset 5 | from datasets_classes.qa.MuSiQue import MusiQue 6 | import random 
7 | import pickle as pkl 8 | from datasets import load_dataset 9 | 10 | def read_metadata(metadata_path): 11 | with open(metadata_path, 'r') as f: 12 | meta = json.load(f) 13 | fill_default_arguments(meta) 14 | return meta 15 | 16 | 17 | def fill_default_arguments(metadata): 18 | metadata['num_demonstrations'] = metadata.get('num_demonstrations', 3) 19 | metadata['max_num_samples'] = metadata.get('max_num_samples', -1) 20 | 21 | 22 | def load_all_datasets(num_demos, l_datasets, num_runs, max_num_samples, random_seed): 23 | Dataset.update_common_data("num_demos", num_demos) 24 | Dataset.update_common_data("max_num_samples", max_num_samples) 25 | Dataset.update_common_data("random", random.Random(random_seed)) 26 | all_names = set() 27 | for ds_instance in l_datasets: 28 | name = ds_instance['name'] 29 | split = ds_instance['split_name'] 30 | all_names.add(name) 31 | path = ds_instance.get('path') 32 | print(f"Loading dataset {name}") 33 | ds_instance['instances'] = [] 34 | for i in range(num_runs): 35 | ds = MusiQue(name, path, split) 36 | ds_instance["instances"].append(ds) 37 | print(f"Loaded all datasets_classes:\n{all_names} * {num_runs} runs.") 38 | return l_datasets 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = ArgumentParser() 43 | parser.add_argument('--config', required=True) 44 | args = parser.parse_args() 45 | metadata_dict = read_metadata(args.config) 46 | out_dir = metadata_dict['out_dir'] 47 | pickle_out = os.path.join(out_dir, f"datasets_{metadata_dict['run_name']}.pkl") 48 | print(f"Saving datasets_classes to {pickle_out}") 49 | datasets_list = load_all_datasets(metadata_dict['num_demonstrations'], 50 | metadata_dict['datasets_classes'], 51 | num_runs=metadata_dict['num_different_runs'], 52 | max_num_samples=metadata_dict['max_num_samples'], 53 | random_seed=metadata_dict['random_seed']) 54 | pickle_out = os.path.join(out_dir, f"datasets_{metadata_dict['run_name']}.pkl") 55 | print(f"Saving datasets_classes to {pickle_out}") 56 | with open(pickle_out, 'wb') as f: 57 | pkl.dump(datasets_list, f) 58 | -------------------------------------------------------------------------------- /scripts/graph.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | ######################################################################################################################## 5 | # For gimini 6 | # Data 7 | datasets = ['Baseline', 'Extended', 'Full','Replaced', 'No Doc'] 8 | rephrased = [0.75, 0.80, 0.78, 0.85, 0.82] 9 | without_rephrasing = [0.21, 0.022, 0.01, 0.83, 0.80] 10 | 11 | x = np.arange(len(datasets)) # the label locations 12 | width = 0.35 # the width of the bars 13 | 14 | fig, ax = plt.subplots(figsize=(10, 6)) 15 | 16 | # Plotting the bars 17 | bars = ax.bar(x - width/2, rephrased, width, label='Rephrased', color='skyblue') 18 | bars2 = ax.bar(x + width/2, without_rephrasing, width, label='Without Rephrasing', color='lightcoral') 19 | 20 | # Adding lines 21 | ax.plot(x, rephrased, color='blue', marker='o', label='Rephrased Line') 22 | ax.plot(x, without_rephrasing, color='red', marker='s', label='Without Rephrasing Line') 23 | 24 | # Labels, title, and legend 25 | ax.set_xlabel('Datasets') 26 | ax.set_ylabel('Performance') 27 | ax.set_title('Performance on Different Datasets with and without Rephrasing') 28 | ax.set_xticks(x) 29 | ax.set_xticklabels(datasets) 30 | ax.legend() 31 | 32 | # Show the plot 33 | plt.show() 
--------------------------------------------------------------------------------
/scripts/results_processing/. …:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/scripts/run_all.sh:
--------------------------------------------------------------------------------
1 | while [[ $# -gt 0 ]]; do
2 |     key="$1"
3 |     case $key in
4 |         -c|--config)
5 |             config="$2"
6 |             shift 2
7 |             ;;
8 |         *)
9 |             echo "Error: Unknown option: $1"
10 |             echo "Usage: $0 -c|--config <config_file>"
11 |             exit 1
12 |             ;;
13 |     esac
14 | done
15 | 
16 | export PYTHONPATH=./
17 | 
18 | #models=("meta-llama/Meta-Llama-3-8B-Instruct" "mistralai/Mistral-7B-Instruct-v0.2" "google/gemma-1.1-2b-it" "mistralai/Mixtral-8x7B-Instruct-v0.1" "google/gemma-1.1-7b-it" "meta-llama/Meta-Llama-3-70B-Instruct" "mistralai/Mixtral-8x22B-Instruct-v0.1")
19 | models=("meta-llama/Meta-Llama-3-8B-Instruct" "mistralai/Mistral-7B-Instruct-v0.2" "google/gemma-1.1-2b-it" "mistralai/Mixtral-8x7B-Instruct-v0.1" "google/gemma-1.1-7b-it" "meta-llama/Meta-Llama-3-70B-Instruct" "mistralai/Mixtral-8x22B-Instruct-v0.1")
20 | 
21 | for model in "${models[@]}"; do
22 |     python scripts/run_model_predictions.py --model_name "${model}" --config "${config}"
23 | done
--------------------------------------------------------------------------------
/scripts/run_model_predictions.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from toghter_api_keys import API_KEY
3 | os.environ['TOGETHER_API_KEY'] = API_KEY
4 | import shutil
5 | from argparse import ArgumentParser
6 | import json
7 | import sys
8 | print("path to model wrapper", sys.path)
9 | import os
10 | os.environ["CUDA_LAUNCH_BLOCKING"]="1"
11 | from model_wrappers.hf_pipline_wrap import HfPipelineWrap
12 | import gc
13 | import torch
14 | from tqdm import tqdm
15 | import pickle as pkl
16 | from accelerate import Accelerator
17 | # os.environ['CUDA_VISIBLE_DEVICES'] = '6'
18 | from accelerate.utils import DeepSpeedPlugin
19 | QUANTIZED_MODELS = {'mistralai/Mixtral-8x22B-Instruct-v0.1': '4bit',
20 |                     "meta-llama/Meta-Llama-3-70B-Instruct": '4bit',
21 |                     'mistralai/Mixtral-8x7B-Instruct-v0.1':'4bit',
22 |                     'Qwen/Qwen2-72B-Instruct':'4bit'}
23 | 
24 | 
25 | def read_metadata(metadata_path):
26 |     with open(metadata_path, 'r') as f:
27 |         meta = json.load(f)
28 |     fill_default_arguments(meta)
29 |     return meta
30 | 
31 | 
32 | def handle_out_path_overrides(metadata):
33 |     output_dir = os.path.join(metadata['out_dir'], metadata['run_name'])
34 |     out_dir_exists = os.path.exists(output_dir)
35 |     override_all = metadata['override']
36 |     path_to_metadata = os.path.join(output_dir, "metadata.json")
37 |     run_mode = 'resume'
38 | 
39 |     if out_dir_exists and not override_all:
40 |         print("Warning: resuming run from current point")
41 |     if override_all:
42 |         run_mode = 'override'
43 |         if out_dir_exists:
44 |             shutil.rmtree(output_dir)
45 |             print("Warning: out_dir exists, overriding all files in it.")
46 |         else:
47 |             print("Warning: out_dir does not exist, has nothing to override. Ignoring override flag.")
48 | 
49 |     os.makedirs(output_dir, exist_ok=True)
50 |     with open(path_to_metadata, 'w') as f_meta:
51 |         json.dump(metadata, f_meta, indent=2)
52 |     print(f"Saving the current metadata to {path_to_metadata}")
53 | 
54 |     return output_dir, run_mode
55 | 
56 | 
57 | def fill_default_arguments(metadata):
58 |     metadata['resume'] = metadata.get('resume', False)
59 |     metadata['temperature'] = metadata.get('temperature', 0.8)
60 |     metadata['batch_size'] = metadata.get('batch_size', 8)
61 |     metadata['num_demonstrations'] = metadata.get('num_demonstrations', 3)
62 | 
63 | 
64 | def run_all(all_datasets, metadata, output_dir, truncation, run_mode, model_name, togehter_mode):
65 |     if not togehter_mode:
66 |         model = load(model_name, metadata)
67 |     else:
68 |         model = None
69 |     for dataset in all_datasets:
70 |         all_ds_instances = dataset["instances"]
71 |         print(f"\nRunning model {model_name} on dataset {dataset['name']}")
72 |         for i, ds_instance in tqdm(enumerate(all_ds_instances), total=len(all_ds_instances)):
73 |             out = os.path.join(output_dir, dataset['name'], f"{model_name}_{i}.json")
74 |             if os.path.exists(out) and run_mode == 'resume':
75 |                 print(f"Output file {out} exists, skipping...")
76 |                 continue
77 |             if not togehter_mode:
78 |                 if truncation == 'max':
79 |                     truncation = model.get_max_window()
80 | 
81 |                 ds_instance.predict(model=model,
82 |                                     out_path=out,
83 |                                     num_truncation_tokens=truncation)
84 |             else:
85 |                 ds_instance.predict_togehter(model_name=model_name,
86 |                                              out_path=out,
87 |                                              num_truncation_tokens=None)
88 |             gc.collect()
89 | 
90 | def run_all_not_togather(all_datasets, metadata, output_dir, truncation, run_mode, model_name):
91 |     model = load(model_name, metadata)
92 |     for dataset in all_datasets:
93 |         all_ds_instances = dataset["instances"]
94 |         print(f"\nRunning model {model_name} on dataset {dataset['name']}")
95 |         for i, ds_instance in tqdm(enumerate(all_ds_instances), total=len(all_ds_instances)):
96 |             out = os.path.join(output_dir, dataset['name'], f"{model_name}_{i}.json")
97 |             if os.path.exists(out) and run_mode == 'resume':
98 |                 print(f"Output file {out} exists, skipping...")
99 |                 continue
100 |             if truncation == 'max':
101 |                 # truncation = model.get_max_window()
102 |                 truncation = None
103 |             ds_instance.predict(model=model,
104 |                                 out_path=out,
105 |                                 num_truncation_tokens=truncation)
106 | 
107 |         gc.collect()
108 | 
109 | def load(model_name, metadata):
110 |     gc.collect()
111 |     torch.cuda.empty_cache()
112 |     load_in_4_bit = load_in_8_bit = False
113 |     if model_name in QUANTIZED_MODELS:
114 |         if QUANTIZED_MODELS[model_name] == '4bit':
115 |             load_in_4_bit = True
116 |         elif QUANTIZED_MODELS[model_name] == '8bit':
117 |             load_in_8_bit = True
118 |         else:
119 |             print("QUANTIZATION MODE NOT SUPPORTED. LOADING IN DEFAULT DTYPE.")
120 | 
121 |     model = HfPipelineWrap(model_name, metadata['temperature'], metadata['batch_size'],
122 |                            load_in_4_bit=load_in_4_bit, load_in_8_bit=load_in_8_bit)
123 |     gc.collect()
124 |     torch.cuda.empty_cache()
125 |     # deepspeed_config = {
126 |     #     "fp16": {
127 |     #         "enabled": True  # Enable FP16 for memory efficiency
128 |     #     },
129 |     #     "zero_optimization": {
130 |     #         "stage": 2,  # You can adjust the ZeRO stage based on memory needs
131 |     #         "offload_optimizer": {
132 |     #             "device": "cpu"  # Optional: Offload optimizer to CPU for saving GPU memory
133 |     #         }
134 |     #     }
135 |     # }
136 |     #
137 |     # # Set up DeepSpeedPlugin with the config
138 |     # deepspeed_plugin = DeepSpeedPlugin(config=deepspeed_config)
139 |     # accelerator = Accelerator()
140 |     # model = accelerator.prepare(model)
141 |     return model
142 | 
143 | 
144 | def get_truncation_strategy(metadata):
145 |     truncation = metadata.get('truncation_strategy')  # one of {max, min, set}. If set, max_num_tokens must be set.
146 |     max_num_tokens = metadata.get('max_num_tokens')
147 |     if truncation is None:
148 |         raise ValueError("truncation_strategy must be set.")
149 |     if truncation == 'max' and max_num_tokens is not None:
150 |         raise ValueError("If truncation_strategy is 'max', max_num_tokens must be None.")
151 |     elif truncation == 'set':
152 |         if max_num_tokens is None:
153 |             raise ValueError("If truncation_strategy is 'set', max_num_tokens must be set.")
154 |         truncation = max_num_tokens
155 | 
156 |     return truncation
157 | 
158 | 
159 | if __name__ == '__main__':
160 |     parser = ArgumentParser()
161 |     parser.add_argument('--config', required=True)
162 |     parser.add_argument('--model_name', required=True)
163 |     parser.add_argument('--run_together', action='store_true', help="If set, runs in together mode.")
164 |     args = parser.parse_args()
165 |     metadata_dict = read_metadata(args.config)
166 |     out_dir, running_mode = handle_out_path_overrides(metadata_dict)
167 |     datasets_pickle = os.path.join(out_dir, metadata_dict["datasets_pickle_path"])
168 |     print(f"Loading datasets from {datasets_pickle}")
169 |     with open(datasets_pickle, 'rb') as f:
170 |         datasets_list = pkl.load(f)
171 |     truncation_strategy = get_truncation_strategy(metadata_dict)
172 |     print(f"run output dir in {out_dir}")
173 |     run_all(all_datasets=datasets_list, metadata=metadata_dict, output_dir=out_dir,
174 |             truncation=truncation_strategy, run_mode=running_mode, model_name=args.model_name, togehter_mode=args.run_together)
175 | 
176 | 
--------------------------------------------------------------------------------
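[Editor's note] A minimal end-to-end usage sketch for the scripts above. The config path and model name below are placeholders and assumptions rather than files shipped with the repository. Two details follow directly from the code: run_model_predictions.py imports API_KEY from toghter_api_keys at module load, so that module must exist even for local (non-Together) runs, and datasets_pickle_path is resolved relative to out_dir/run_name, so it should point at the pickle written by generate_benchmark.py. Beyond the generation keys, the prediction script also reads out_dir, run_name, override, datasets_pickle_path and truncation_strategy (plus max_num_tokens when the strategy is 'set'; temperature and batch_size default to 0.8 and 8).

```bash
# Placeholder path: substitute your own JSON configuration file.
CONFIG=path/to/predict_config.json
export PYTHONPATH=./

# 1. Build the dataset instances and pickle them into out_dir.
python scripts/generate_benchmark.py --config "$CONFIG"

# 2a. Run every model listed in scripts/run_all.sh on the pickled datasets ...
bash scripts/run_all.sh --config "$CONFIG"

# 2b. ... or run a single model directly; add --run_together to query the
#      Together API instead of loading a local HF pipeline.
python scripts/run_model_predictions.py \
    --model_name meta-llama/Meta-Llama-3-8B-Instruct \
    --config "$CONFIG"
```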