├── requirements.txt ├── experiments ├── api │ ├── run_api_model.py │ ├── run_api_single_task.py │ ├── openai_api.py │ ├── anthropic_api.py │ └── api.py └── hf │ └── run_hf_model.py ├── LICENSE ├── .gitignore ├── submission ├── verify_submission.py ├── prepare_submission.py └── verify_task.py ├── README.md └── zero_scrolls_datasets.bib /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | fire 3 | pandas==2.0.3 4 | datasets==2.13.1 5 | 6 | # hf 7 | tokenizers==0.13.3 8 | sentencepiece==0.1.99 9 | accelerate==0.21.0 10 | 11 | # api experiments 12 | openai==0.27.8 13 | tiktoken==0.4.0 14 | anthropic==0.3.4 -------------------------------------------------------------------------------- /experiments/api/run_api_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from fire import Fire 5 | 6 | 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 8 | from experiments.hf.run_hf_model import datasets 9 | from experiments.api.run_api_single_task import generate_predictions_using_api 10 | 11 | 12 | def main(model_name: str, limit_to_n_examples: int = None): 13 | for dataset in datasets: 14 | print(f"Starting with {dataset}") 15 | generate_predictions_using_api(dataset_name=dataset, model_name=model_name, 16 | limit_to_n_examples=limit_to_n_examples) 17 | 18 | 19 | if __name__ == '__main__': 20 | Fire(main) 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 TAU NLP Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /experiments/api/run_api_single_task.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | from pathlib import Path 4 | 5 | from fire import Fire 6 | from tqdm import tqdm 7 | 8 | from experiments.api.anthropic_api import AnthropicAPI 9 | from experiments.api.openai_api import OpenAIAPI 10 | from datasets import load_dataset 11 | 12 | def generate_predictions_using_api(dataset_name: str, model_name: str = "text-davinci-003", 13 | log_progress_every_n_examples=20, 14 | limit_to_n_examples=None): 15 | model_folder_name = model_name.replace("-", "_") 16 | if model_name in ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"]: 17 | api = OpenAIAPI(model_name, dataset_name) 18 | elif model_name in ["claude-v1","claude-v1.3"]: 19 | api = AnthropicAPI(model_name, dataset_name) 20 | else: 21 | raise ValueError(f"model_name {model_name} not supported") 22 | 23 | api.init_api() 24 | # load task data 25 | zero_scrolls_dataset = load_dataset("tau/zero_scrolls",dataset_name)["test"] 26 | preds_folder_path = Path(f"generations/api/{model_folder_name}") 27 | preds_folder_path.mkdir(parents=True, exist_ok=True) 28 | 29 | 30 | print(f"generating predictions for {dataset_name} with OpenAI {model_name}") 31 | 32 | # API setup and parameters 33 | parameters = api.init_params() 34 | # with open(predictions_file_path, 'a') as f_out: 35 | generations = dict() 36 | for i, example in tqdm(enumerate(zero_scrolls_dataset)): 37 | if limit_to_n_examples is not None and i >= limit_to_n_examples: 38 | print( 39 | f"Breaking when limit_to_n_examples is reached. i={i}, limit_to_n_examples={limit_to_n_examples}, generated {len(generations)} predictions") 40 | break 41 | 42 | prompt = api.build_prompt(example) 43 | api.preprocess_parameters(parameters, prompt) 44 | 45 | time.sleep(0.5) # helps with rate limits 46 | response = api.call(parameters) 47 | output = api.build_output(example, prompt, parameters, response) 48 | 49 | generations[example["id"]] = output["prediction"] 50 | if i % log_progress_every_n_examples == 0: 51 | print( 52 | f'generated {len(generations)} examples from {dataset_name} using {model_name}') 53 | 54 | predictions_file_path = preds_folder_path / f"preds_{dataset_name}.json" 55 | with open(predictions_file_path, 'w') as f_out: 56 | json.dump(generations, f_out, indent=4) 57 | print( 58 | f'finished generating {len(generations)} predictions for {dataset_name} using OpenAI {model_name}') 59 | 60 | 61 | if __name__ == '__main__': 62 | Fire(generate_predictions_using_api) 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /submission/verify_submission.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import namedtuple 3 | import pandas as pd 4 | import logging 5 | 6 | from verify_task import main as evaluate_dataset, DATASETS 7 | 8 | log = logging.getLogger(__name__) 9 | 10 | EXPECTED_DF_COLS = {"Task", "ID", "Prediction"} 11 | EXPECTED_TASKS = [ 12 | "gov_report", 13 | "summ_screen_fd", 14 | "qmsum", 15 | "narrative_qa", 16 | "qasper", 17 | "quality", 18 | "squality", 19 | "musique", 20 | "space_digest", 21 | "book_sum_sort" 22 | ] 23 | assert set(EXPECTED_TASKS).issubset(DATASETS) 24 | BenchmarkEvaluatorArgs = namedtuple( 25 | "BenchmarkEvaluatorArgs", 26 | "all_predictions split cache_dir output_dir internal_call", 27 | ) 28 | DatasetEvaluatorArgs = namedtuple( 29 | "DatasetEvaluatorArgs", 30 | "predictions dataset_name split cache_dir output_dir internal_call", 31 | ) 32 | 33 | 34 | def main(args): 35 | all_predictions = args.all_predictions 36 | if isinstance(all_predictions, str): 37 | all_predictions = load_predictions_df(all_predictions) 38 | errors = 0 39 | for task in EXPECTED_TASKS: 40 | 41 | log.info(f"Evaluating the results for task {task} with task {task}...") 42 | task_json = ( 43 | all_predictions[all_predictions.Task == task][["ID", "Prediction"]] 44 | .set_index("ID")["Prediction"] 45 | .to_dict() 46 | ) 47 | evaluator_obj = DatasetEvaluatorArgs( 48 | predictions=task_json, 49 | dataset_name=task, 50 | split=args.split, 51 | cache_dir=args.cache_dir, 52 | output_dir=args.output_dir, 53 | internal_call=True 54 | ) 55 | try: 56 | evaluate_dataset(evaluator_obj, raise_on_errors=True) 57 | except Exception as e: 58 | errors += 1 59 | log.exception(f"Error for task: {task}:\n{e}") 60 | continue 61 | log.info(f"task: {task} is valid") 62 | if errors: 63 | msg = f"Found {errors} errors in the submission, see output files in {args.output_dir} for details." 64 | raise ValueError(msg) 65 | else: 66 | print("The verification was successful.") 67 | 68 | 69 | def load_predictions_df(file_path): 70 | try: 71 | df = safe_read_csv(file_path) 72 | except Exception as e: 73 | raise ValueError(f"Failed to read the csv with pandas: {e}") 74 | 75 | cols = set(df.columns) 76 | if cols != EXPECTED_DF_COLS: 77 | raise ValueError(f"csv file has invalid format. Expected columns {EXPECTED_DF_COLS} and got {cols} instead") 78 | 79 | tasks = set(df.Task.unique()) 80 | if tasks != set(EXPECTED_TASKS): 81 | raise ValueError( 82 | f"csv file does not contain predictions for the expected tasks. 
" 83 | f"Expected tasks {sorted(EXPECTED_TASKS)} and got {sorted(tasks)} instead" 84 | ) 85 | 86 | return df 87 | 88 | 89 | def safe_read_csv(file_path): 90 | # https://stackoverflow.com/a/33952294 91 | return pd.read_csv(file_path, dtype=object, keep_default_na=False, na_values=["!@#$%^&*()"]) 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser(description="Evaluate the predictions for the full SCROLLS benchmark") 96 | parser.add_argument( 97 | "--all_predictions", 98 | type=str, 99 | help="Path to the file with all of the predictions or the actual predictions", 100 | required=True, 101 | ) 102 | parser.add_argument("--output_dir", type=str, help="Directory of the output metrics file", required=True) 103 | parser.add_argument("--split", type=str, help="The split to evaluate on", default="test") 104 | parser.add_argument("--internal_call", type=str, help="For internal use", default=False) 105 | parser.add_argument( 106 | "--cache_dir", type=str, help="Cache dir for the dataset download", default=None, required=False 107 | ) 108 | args = parser.parse_args() 109 | 110 | main(args) 111 | -------------------------------------------------------------------------------- /experiments/api/openai_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import openai 5 | import tiktoken 6 | 7 | sys.path.append(os.path.dirname((os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) 8 | from experiments.api.api import APIRunner 9 | 10 | DAVINCI = "text_davinci_003" 11 | ChatGPT = "gpt_3.5_turbo" 12 | GPT4 = "gpt_4" 13 | DAVINCI_MAX_INPUT_OUTPUT_TOKENS = 4096 14 | GPT4_MAX_INPUT_OUTPUT_TOKENS = 8192 15 | 16 | 17 | class OpenAIAPI(APIRunner): 18 | 19 | def __init__(self, model_name: str, dataset_name: str): 20 | super().__init__(model_name, dataset_name) 21 | self.temperature = 0 22 | # self.top_p = 0 23 | self.top_k = None 24 | self.tokenizer = tiktoken.encoding_for_model(self.model_name) 25 | model_name_underscore = self.model_name.replace("-", "_") 26 | self.is_chat_api = model_name_underscore in {ChatGPT, GPT4} 27 | self.is_gpt4 = model_name_underscore == GPT4 28 | 29 | @property 30 | def max_input_output_tokens(self): 31 | return GPT4_MAX_INPUT_OUTPUT_TOKENS if self.is_gpt4 else DAVINCI_MAX_INPUT_OUTPUT_TOKENS 32 | 33 | def init_api(self): 34 | openai.organization = os.getenv("OPENAI_ORG") 35 | openai.api_key = os.getenv("OPENAI_API_KEY") 36 | 37 | @property 38 | def max_generation_tokens_key(self): 39 | return "max_tokens" 40 | 41 | def parse_finish_reason(self, response): 42 | return response.choices[0]["finish_reason"] 43 | 44 | def build_output(self, example, prompt, parameters, response): 45 | output = super().build_output(example, prompt, parameters, response) 46 | 47 | output.update({ 48 | "index": response.choices[0]["index"], 49 | }) 50 | 51 | return output 52 | 53 | def parse_model_name(self, parameters, response): 54 | return response.model 55 | 56 | def build_prompt(self, example): 57 | if self.is_chat_api: 58 | if self.dataset_name not in APIRunner.summarization_datasets: 59 | format_request, _, answer_type = self.get_chat_format_request_tag_and_answer_type(example) 60 | self.insert_format_request(example, format_request) 61 | 62 | example["input"] = example["input"][:example['query_end_index']] 63 | tokenized = self.tokenizer.encode(example['input']) 64 | if len(tokenized) <= self.max_input_output_tokens - self.max_generation_tokens: 65 | return example['input'] 66 | 67 | 
query_and_answer_prompt = example['input'][example['query_start_index']:] 68 | truncation_seperator = example['truncation_seperator'] 69 | 70 | suffix_tokenized = self.tokenizer.encode(truncation_seperator + query_and_answer_prompt) 71 | 72 | max_tokens_for_input = self.get_max_document_tokens(len(suffix_tokenized)) 73 | 74 | tokenized_trimmed = tokenized[:max_tokens_for_input] 75 | prompt = self.tokenizer.decode(tokenized_trimmed) + truncation_seperator + query_and_answer_prompt 76 | 77 | return prompt 78 | 79 | def get_max_document_tokens(self, n_suffix_tokens): 80 | max_input_tokens = super().get_max_document_tokens(n_suffix_tokens) 81 | if self.is_chat_api: 82 | max_input_tokens -= 9 83 | return max_input_tokens 84 | def preprocess_parameters(self, parameters, prompt): 85 | if self.is_chat_api: 86 | parameters["messages"] = [ 87 | {"role": "user", "content": prompt} 88 | ] 89 | 90 | if self.is_gpt4: 91 | parameters["model"] = GPT4.replace("_", "-") 92 | 93 | else: 94 | super().preprocess_parameters(parameters, prompt) 95 | 96 | def get_number_of_tokens(self, prompt): 97 | return len(self.tokenizer.encode(prompt)) 98 | 99 | def parse_prediction(self, response): 100 | if self.is_chat_api: 101 | return response.choices[0].message.content 102 | return response.choices[0].text 103 | 104 | def call(self, parameters): 105 | if self.is_chat_api: 106 | return openai.ChatCompletion.create(**parameters) 107 | return openai.Completion.create(**parameters) 108 | -------------------------------------------------------------------------------- /experiments/api/anthropic_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import anthropic 5 | 6 | sys.path.append(os.path.dirname((os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) 7 | from experiments.api.api import APIRunner 8 | 9 | CLAUDE_V1 = "claude_v1" 10 | 11 | class AnthropicAPI(APIRunner): 12 | 13 | def __init__(self, model_name: str, dataset_name: str, format_prompt_loc: str = "start", 14 | max_input_output_tokens: int = None): 15 | super().__init__(model_name, dataset_name) 16 | self._client = None 17 | self.top_p = -1 18 | self.temperature = 1 # https://console.anthropic.com/docs/api/reference 19 | self.top_k = 1 20 | 21 | self.max_tokens = 8000 22 | self.tags_in_prompt = True 23 | 24 | 25 | @property 26 | def max_input_output_tokens(self): 27 | return self.max_tokens # https://console.anthropic.com/docs/prompt-design#prompt-length 28 | 29 | def init_api(self): 30 | self._client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) 31 | self.tokenizer = self._client.get_tokenizer() 32 | 33 | @property 34 | def max_generation_tokens_key(self): 35 | return "max_tokens_to_sample" 36 | 37 | @property 38 | def max_generation_tokens(self): 39 | return super().max_generation_tokens # + 64 40 | 41 | def get_number_of_tokens(self, string): 42 | return self._client.count_tokens(string) 43 | 44 | def parse_finish_reason(self, response): 45 | return response.stop_reason 46 | 47 | def init_params(self): 48 | params = super().init_params() 49 | params["stop_sequences"] = [anthropic.HUMAN_PROMPT] 50 | return params 51 | 52 | def build_prompt(self, example): 53 | 54 | if self.dataset_name in APIRunner.summarization_datasets: 55 | anthropic_suffix = anthropic.AI_PROMPT 56 | else: 57 | format_request, tag, answer_type = self.get_chat_format_request_tag_and_answer_type(example) 58 | self.insert_format_request(example, format_request) 59 | anthropic_suffix = 
f"{anthropic.AI_PROMPT}" 60 | if self.tags_in_prompt: 61 | anthropic_suffix += f" {answer_type}: <{tag}>" 62 | 63 | input_without_suffix = example['input'][:example['query_end_index']] 64 | anthropic_prompt = f"{anthropic.HUMAN_PROMPT} {input_without_suffix}{anthropic_suffix}" # https://console.anthropic.com/docs/prompt-design/classification 65 | 66 | tokenized = self.tokenizer.encode(anthropic_prompt) 67 | if len(tokenized.ids) <= self.max_input_output_tokens - self.max_generation_tokens: 68 | return anthropic_prompt 69 | 70 | query = example['input'][example['query_start_index']:example['query_end_index']] 71 | truncation_seperator = example['truncation_seperator'] 72 | anthropic_suffix = f"{truncation_seperator}{query}{anthropic_suffix}" 73 | anthropic_prefix_and_suffix_tokenized = self.tokenizer.encode( 74 | f"{anthropic.HUMAN_PROMPT} {anthropic_suffix}").ids 75 | max_tokens_for_input = self.get_max_document_tokens(len(anthropic_prefix_and_suffix_tokenized)) 76 | char_idx_of_max_tokens_for_input = tokenized.offsets[max_tokens_for_input][0] 77 | input_without_suffix_trimmed = input_without_suffix[:char_idx_of_max_tokens_for_input] 78 | anthropic_prompt = f"{anthropic.HUMAN_PROMPT} {input_without_suffix_trimmed}{anthropic_suffix}" 79 | return anthropic_prompt 80 | 81 | def parse_prediction(self, response): 82 | return response.completion 83 | 84 | def parse_model_name(self, parameters, response): 85 | return response.model 86 | 87 | def call(self, parameters): 88 | return self._client.completions.create(**parameters) 89 | 90 | def get_chat_format_request_tag_and_answer_type(self, example): 91 | format_request, tag, answer_type = super().get_chat_format_request_tag_and_answer_type(example) 92 | if self.tags_in_prompt: 93 | if format_request[-1] == ".": 94 | format_request = format_request[:-1] 95 | format_request += f", and please highlight your final {answer_type.lower()} with <{tag}> tags." 
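            # e.g. if the prompt ends with "Answer:", answer_type is "Answer" and tag is "answer",
            # so the model is asked to wrap its final answer in <answer> tags; the closing tag
            # should be stripped from the generations before evaluation (see the README).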
96 | return format_request, tag, answer_type 97 | -------------------------------------------------------------------------------- /submission/prepare_submission.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import json 7 | 8 | SUBMISSION_LINK = "https://zero.scrolls-benchmark.com/submission" 9 | TASKS_MAPPING = { 10 | "gov_report_file": "gov_report", 11 | "summ_screen_fd_file": "summ_screen_fd", 12 | "qmsum_file": "qmsum", 13 | "squality_file": "squality", 14 | "qasper_file": "qasper", 15 | "narrative_qa_file": "narrative_qa", 16 | "quality_file": "quality", 17 | "musique_file": "musique", 18 | "space_digest_file": "space_digest", 19 | "book_sum_sort_file": "book_sum_sort", 20 | } 21 | COLUMNS = ["Task", "ID", "Prediction"] 22 | 23 | 24 | def safe_read_csv(file_path): 25 | # https://stackoverflow.com/a/33952294 26 | return pd.read_csv(file_path, dtype=object, keep_default_na=False, na_values=["!@#$%^&*()"]) 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser(description="Prepare ZeroSCROLLS prediction") 31 | parser.add_argument("--output_dir", type=str, help="Path to output the prediction file", required=True) 32 | parser.add_argument( 33 | "--qmsum_file", type=str, help="The path to the qmsum dataset json file containing prediction", required=True 34 | ) 35 | parser.add_argument( 36 | "--qasper_file", 37 | type=str, 38 | help="The path to the qasper dataset json file containing prediction", 39 | required=True, 40 | ) 41 | parser.add_argument( 42 | "--summ_screen_fd_file", 43 | type=str, 44 | help="The path to the summ_screen dataset json file containing prediction", 45 | required=True, 46 | ) 47 | parser.add_argument( 48 | "--quality_file", 49 | type=str, 50 | help="The path to the quality dataset json file containing prediction", 51 | required=True, 52 | ) 53 | parser.add_argument( 54 | "--narrative_qa_file", 55 | type=str, 56 | help="The path to the narrative_qa dataset json file containing prediction", 57 | required=True, 58 | ) 59 | parser.add_argument( 60 | "--gov_report_file", 61 | type=str, 62 | help="The path to the gov_report dataset json file containing prediction", 63 | required=True, 64 | ) 65 | parser.add_argument( 66 | "--squality_file", 67 | type=str, 68 | help="The path to the squality dataset json file containing prediction", 69 | required=True, 70 | ) 71 | parser.add_argument( 72 | "--musique_file", 73 | type=str, 74 | help="The path to the musique dataset json file containing prediction", 75 | required=True, 76 | ) 77 | parser.add_argument( 78 | "--space_digest_file", 79 | type=str, 80 | help="The path to the space_digest dataset json file containing prediction", 81 | required=True, 82 | ) 83 | parser.add_argument( 84 | "--book_sum_sort_file", 85 | type=str, 86 | help="The path to the book_sum_sort dataset json file containing prediction", 87 | required=True, 88 | ) 89 | args = parser.parse_args() 90 | 91 | tasks_dfs = pd.DataFrame(columns=COLUMNS, data=[]) 92 | for file_key, task_name in TASKS_MAPPING.items(): 93 | print(f"Adding prediction for {task_name} from {file_key}...") 94 | with open(getattr(args, file_key)) as f: 95 | task_data = json.load(f) 96 | task_df = pd.DataFrame.from_dict(task_data, orient="index", columns=COLUMNS[-1:]).reset_index(drop=False) 97 | task_df[COLUMNS[0]] = task_name 98 | task_df[COLUMNS[1]] = task_df["index"] 99 | tasks_dfs = pd.concat((tasks_dfs, task_df[COLUMNS])) 100 | 101 | 
os.makedirs(args.output_dir, exist_ok=True) 102 | outfile = os.path.join(args.output_dir, "zero_scrolls_predictions.csv") 103 | print(f"Saving the complete prediction file to: {outfile}") 104 | tasks_dfs = tasks_dfs.reset_index(drop=True) 105 | tasks_dfs.to_csv(outfile, index=False) 106 | 107 | print("validating submission file is exactly the same as expected") 108 | recovered_tasks_dfs = safe_read_csv(outfile) 109 | assert len(recovered_tasks_dfs) == len(tasks_dfs) 110 | assert recovered_tasks_dfs.columns.tolist() == tasks_dfs.columns.tolist() 111 | assert np.all(recovered_tasks_dfs.values == tasks_dfs.values) 112 | 113 | print(f"Your benchmark prediction file is ready. If it contains prediction for the test sets please head over to {SUBMISSION_LINK} to submit to the ZeroSCROLLS leaderboard.") 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | -------------------------------------------------------------------------------- /experiments/api/api.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | dataset_to_leave_tokens_for_generations = { 4 | "gov_report": 1024, 5 | "summ_screen_fd": 512, 6 | "qmsum": 512, 7 | "qasper": 128, 8 | "narrative_qa": 64, 9 | "quality": 10, 10 | "musique": 32, 11 | "squality": 512, 12 | "space_digest": 36, 13 | "book_sum_sort": 256, 14 | } 15 | 16 | class APIRunner: 17 | summarization_datasets = ["gov_report", "summ_screen_fd", "qmsum", "squality"] 18 | 19 | def __init__(self, model_name: str, dataset_name: str, min_ms_between_api_calls: int = 20): 20 | self.top_k = None 21 | self.temperature = None # meaning could change from one api to another 22 | self.top_p = None 23 | self.model_name = model_name 24 | self.min_ms_between_api_calls = min_ms_between_api_calls 25 | self.dataset_name = dataset_name 26 | self.init_api() 27 | 28 | @property 29 | def max_input_output_tokens(self): 30 | # max tokens for input + output 31 | raise NotImplementedError("max_input_output_tokens") 32 | 33 | def init_api(self): 34 | raise NotImplementedError("init_api") 35 | 36 | def init_params(self): 37 | params = { 38 | self.max_generation_tokens_key: self.max_generation_tokens, 39 | "model": self.model_name 40 | } 41 | 42 | if self.temperature is not None: 43 | params["temperature"] = self.temperature 44 | 45 | if self.top_k is not None: 46 | params[self.top_k_key] = self.top_k 47 | 48 | if self.top_p is not None: 49 | params[self.top_p_key] = self.top_p 50 | 51 | return params 52 | 53 | @property 54 | def max_generation_tokens_key(self): 55 | raise NotImplementedError("max_generation_tokens_key_name") 56 | 57 | @property 58 | def top_k_key(self): 59 | return "top_k" 60 | 61 | @property 62 | def top_p_key(self): 63 | return "top_p" 64 | 65 | @property 66 | def max_generation_tokens(self): 67 | return dataset_to_leave_tokens_for_generations[self.dataset_name] 68 | 69 | def get_number_of_tokens(self, string): 70 | raise NotImplementedError("get_number_of_tokens") 71 | 72 | def get_max_document_tokens(self, n_suffix_tokens): 73 | return self.max_input_output_tokens - n_suffix_tokens - self.max_generation_tokens - 1 74 | 75 | def parse_finish_reason(self, response): 76 | raise NotImplementedError("parse_finish_reason") 77 | 78 | def parse_prediction(self, response): 79 | raise NotImplementedError("parse_prediction") 80 | 81 | def build_prompt(self, example, ): 82 | raise NotImplementedError("build_input") 83 | 84 | def preprocess_parameters(self, parameters, prompt): 85 | parameters["prompt"] = prompt 
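    # Note: this base implementation targets completion-style endpoints that take a single
    # "prompt" string; chat-style subclasses (e.g. OpenAIAPI for gpt-3.5-turbo / gpt-4)
    # override preprocess_parameters to send a "messages" list instead.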
86 | 87 | def parse_model_name(self, parameters, response): 88 | raise NotImplementedError("build_input") 89 | 90 | def build_output(self, example, prompt, parameters, response): 91 | prediction = self.parse_prediction(response) 92 | output = { 93 | "id": example["id"], 94 | "model": self.parse_model_name(parameters, response), 95 | "original_example_input": example["input"], 96 | "prompt": prompt, 97 | "max_input_output_tokens": self.max_input_output_tokens, 98 | "n_input_tokens": self.get_number_of_tokens(prompt), 99 | "n_generated_tokens": self.get_number_of_tokens(prediction), 100 | "finish_reason": self.parse_finish_reason(response), 101 | "temperature": parameters["temperature"] 102 | } 103 | 104 | if self.top_k_key in parameters: 105 | output[self.top_k_key] = parameters[self.top_k_key] 106 | 107 | if self.top_p_key in parameters: 108 | output[self.top_p_key] = parameters[self.top_p_key] 109 | 110 | output["prediction"] = prediction 111 | 112 | return output 113 | 114 | def call(self, prompt): 115 | raise NotImplementedError("call") 116 | 117 | def get_chat_format_request_tag_and_answer_type(self, example): 118 | answer_type = example["input"][example['query_end_index']:].strip().replace(":", "") 119 | tag = answer_type.lower().replace(" ", "_") 120 | format_request = "Do not provide any explanation." 121 | return format_request, tag, answer_type 122 | 123 | def insert_format_request(self, example, format_request): 124 | instruction_end_index = example['input'].find("\n\n") 125 | instruction = example['input'][:instruction_end_index] 126 | example['input'] = f"{instruction.strip()} {format_request}{example['input'][instruction_end_index:]}" 127 | 128 | for key in ["document_start_index", "document_end_index", "query_start_index", "query_end_index"]: 129 | example[key] += len(format_request) + 1 130 | 131 | 132 | def _ms_since_epoch(): 133 | epoch = datetime.utcfromtimestamp(0) 134 | now = datetime.utcnow() 135 | return int((now - epoch).total_seconds() * 1000) 136 | -------------------------------------------------------------------------------- /submission/verify_task.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | 5 | from datasets import load_dataset 6 | 7 | DATASETS = [ 8 | "narrative_qa", 9 | "qasper", 10 | "summ_screen_fd", 11 | "gov_report", 12 | "qmsum", 13 | "quality", 14 | "squality", 15 | "musique", 16 | "space_digest", 17 | "book_sum_sort", 18 | ] 19 | 20 | 21 | def main(args, raise_on_errors=False): 22 | """ 23 | If raise_on_errors is True, raises ValueError on verification errors (after dumping the error descriptions). 
24 | Otherwise, exists with an error code 25 | """ 26 | predictions = args.predictions 27 | dataset_name = args.dataset_name 28 | 29 | # Downloading and loading the dataset from the hub 30 | load_dataset_kwargs = { 31 | "path": "tau/zero_scrolls", 32 | "name": dataset_name, 33 | } 34 | if args.cache_dir is not None: 35 | load_dataset_kwargs["cache_dir"] = args.cache_dir 36 | load_dataset_kwargs["split"] = "test" 37 | seq2seq_dataset = load_dataset(**load_dataset_kwargs) 38 | 39 | # Prepare reference 40 | untokenized_dataset = drop_duplicates_in_input(seq2seq_dataset) 41 | id_to_labels = {instance["id"]: instance["outputs"] for instance in untokenized_dataset} 42 | 43 | # Prepare predictions 44 | if isinstance(predictions, str): 45 | with open(predictions) as f: 46 | id_to_pred = json.load(f) 47 | else: 48 | id_to_pred = predictions 49 | 50 | # Check for format errors 51 | errors, details = verify(id_to_pred, id_to_labels) 52 | 53 | out_file_path = get_errors_filename(args.output_dir, dataset_name) 54 | os.makedirs(args.output_dir, exist_ok=True) 55 | 56 | if len(errors) > 0: 57 | # Output errors 58 | errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors)) 59 | print(json.dumps(errors, indent=4)) 60 | print(f"See details in: {out_file_path}") 61 | with open(out_file_path, mode="w") as f: 62 | json.dump({"errors": errors, "details": details}, f, indent=4) 63 | if raise_on_errors: 64 | raise ValueError(f"Failed to evaluate due to: {errors_msg}") 65 | exit(os.EX_DATAERR) 66 | 67 | 68 | # Copied from baselines/src/utils/duplicates.py 69 | def drop_duplicates_in_input(untokenized_dataset): 70 | indices_to_keep = [] 71 | id_to_idx = {} 72 | outputs = [] 73 | for i, (id_, output) in enumerate(zip(untokenized_dataset["id"], untokenized_dataset["output"])): 74 | if id_ in id_to_idx: 75 | outputs[id_to_idx[id_]].append(output) 76 | continue 77 | indices_to_keep.append(i) 78 | id_to_idx[id_] = len(outputs) 79 | outputs.append([output]) 80 | untokenized_dataset = untokenized_dataset.select(indices_to_keep).flatten_indices() 81 | untokenized_dataset = untokenized_dataset.remove_columns("output") 82 | untokenized_dataset = untokenized_dataset.add_column("outputs", outputs) 83 | return untokenized_dataset 84 | 85 | 86 | def get_errors_filename(outdir, dataset_name): 87 | return os.path.join(outdir, f"{dataset_name}_errors.json") 88 | 89 | 90 | def verify(id_to_pred, id_to_labels): 91 | errors = [] 92 | details = {"missing_keys": [], "redundant_keys": []} 93 | if not isinstance(id_to_pred, dict): 94 | errors.append('The predictions must be saved a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}') 95 | else: 96 | if not all(isinstance(key, str) for key in id_to_pred.keys()): 97 | errors.append("All keys of the predictions dictionary must be strings") 98 | if not all(isinstance(value, str) for value in id_to_pred.values()): 99 | errors.append("All values of the predictions dictionary must be strings") 100 | if len(errors) == 0: 101 | predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys()) 102 | missing_keys = reference_keys - predictions_keys 103 | redundant_keys = predictions_keys - reference_keys 104 | 105 | if len(missing_keys) > 0: 106 | details["missing_keys"] = list(missing_keys) 107 | errors.append(f"There are missing example IDs.") 108 | else: 109 | del details["missing_keys"] 110 | 111 | if len(redundant_keys) > 0: 112 | details["redundant_keys"] = list(redundant_keys) 113 | errors.append(f"There are redundant 
example IDs.") 114 | else: 115 | del details["redundant_keys"] 116 | 117 | return errors, details 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser(description="verify ZeroSCROLLS predictions per dataset") 122 | parser.add_argument( 123 | "--predictions", type=str, help="Path to the predictions file or the actual predictions", required=True 124 | ) 125 | parser.add_argument("--dataset_name", type=str, help="Name of the dataset", choices=DATASETS, required=True) 126 | parser.add_argument("--output_dir", type=str, help="Directory of the output file", required=True) 127 | parser.add_argument("--internal_call", type=str, help="For internal use", default=False) 128 | parser.add_argument( 129 | "--cache_dir", type=str, help="Cache dir for the dataset download", default=None, required=False 130 | ) 131 | args = parser.parse_args() 132 | 133 | main(args) 134 | -------------------------------------------------------------------------------- /experiments/hf/run_hf_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from datetime import datetime 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | from datasets import load_dataset 10 | from fire import Fire 11 | from transformers import T5Tokenizer, T5ForConditionalGeneration 12 | from transformers import set_seed as hf_set_seed 13 | 14 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 15 | 16 | datasets = ['gov_report', 17 | 'summ_screen_fd', 18 | 'qmsum', 19 | 'qasper', 20 | 'narrative_qa', 21 | 'quality', 22 | 'musique', 23 | 'squality', 24 | 'space_digest', 25 | 'book_sum_sort'] 26 | 27 | model_to_max_input_tokens = { 28 | "google/flan-t5-xxl": 8192, 29 | "google/flan-t5-xl": 8192, 30 | "google/flan-t5-large": 8192, 31 | "google/flan-t5-base": 8192, 32 | "google/flan-t5-small": 8192, 33 | "google/flan-ul2": 8192, 34 | "bigscience/T0pp": 8192, 35 | } 36 | 37 | 38 | def trim_doc_keeping_suffix(tokenizer, tokenized_input_full, example, suffix_index, max_tokens, device): 39 | seperator_and_suffix = f"{example['truncation_seperator'].strip()}\n\n{example['input'][suffix_index:].strip()}\n" 40 | tokenized_seperator_and_suffix = tokenizer(seperator_and_suffix, return_tensors="pt").input_ids.to(device) 41 | tokenized_input_trimmed = tokenized_input_full[:, :max_tokens - tokenized_seperator_and_suffix.shape[1]] 42 | tokenized_input = torch.cat([tokenized_input_trimmed, tokenized_seperator_and_suffix], dim=1) 43 | return tokenized_input 44 | 45 | 46 | def process_model_input(tokenizer, example, max_tokens, device): 47 | tokenized_input_full = tokenizer(example["input"], return_tensors="pt").input_ids.to(device) 48 | if tokenized_input_full.shape[1] <= max_tokens: 49 | return tokenized_input_full 50 | 51 | seperator_and_query_text = example['truncation_seperator'] + example["input"][example['query_start_index']:] 52 | tokenized_seperator_and_query = tokenizer(seperator_and_query_text, return_tensors="pt").input_ids.to(device) 53 | input_without_query = example['input'][:example['query_start_index']] 54 | tokenized_input_without_query = tokenizer(input_without_query, return_tensors="pt").input_ids.to(device) 55 | tokenized_input_without_query = tokenized_input_without_query[:, 56 | :max_tokens - tokenized_seperator_and_query.shape[1]] 57 | 58 | tokenized_input = torch.cat([tokenized_input_without_query, tokenized_seperator_and_query], dim=1) 59 | return tokenized_input 60 | 61 | 62 | def 
main(model_name="google/flan-t5-small", generations_dir="generations", max_examples_per_task=-1): 63 | seed = 43 64 | random.seed(seed) 65 | np.random.seed(seed) 66 | hf_set_seed(seed) 67 | print("Params:") 68 | print(f"model: {model_name}") 69 | generations_dir = os.path.join(generations_dir, model_name.replace("/", "_").replace("-", "_")) 70 | print(f"generations_dir: {generations_dir}") 71 | print(f"max_examples_per_task: {max_examples_per_task}") 72 | print("=" * 50) 73 | time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S") 74 | print(f"time as start: {time}") 75 | 76 | print("Loading tokenizer") 77 | tokenizer = T5Tokenizer.from_pretrained(model_name) 78 | print(f"Loading model: {model_name}") 79 | device = "cuda" if torch.cuda.is_available() else "cpu" 80 | 81 | max_input_length = model_to_max_input_tokens[model_name] 82 | 83 | model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto", 84 | torch_dtype=torch.float16) 85 | 86 | model = model.eval() 87 | 88 | print(f"{model} model loaded!, device:{model.device}") 89 | 90 | print("Will write to:", generations_dir) 91 | os.makedirs(generations_dir, exist_ok=True) 92 | for dataset in datasets: 93 | generations = dict() 94 | print(f"Processing {dataset}") 95 | time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S") 96 | print(f"time as start {dataset}: {time}") 97 | print(f"Loading {dataset}") 98 | data = load_dataset("tau/zero_scrolls", dataset) 99 | print(f"Loaded {dataset}") 100 | 101 | for i, example in enumerate(data["test"]): 102 | 103 | if 0 < max_examples_per_task == i: 104 | print(f"Reached {max_examples_per_task} for {dataset}. Breaking") 105 | break 106 | 107 | model_input = process_model_input(tokenizer, example, max_input_length, device) 108 | 109 | prediction_token_ids = model.generate(model_input, 110 | max_new_tokens=1024, 111 | do_sample=False, 112 | top_p=0, 113 | top_k=0, 114 | temperature=1) 115 | 116 | predicted_text = tokenizer.decode(prediction_token_ids[0], skip_special_tokens=True) 117 | generations[example["id"]] = predicted_text 118 | 119 | out_file_path = os.path.join(generations_dir, f"preds_{dataset}.json") 120 | with open(out_file_path, 'w') as f_out: 121 | json.dump(generations, f_out, indent=4) 122 | 123 | print(f"Done generating {len(generations)} examples from {dataset}") 124 | time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S") 125 | print(f"time at end: {time}") 126 | print(f"Look for predictions in {generations_dir}") 127 | 128 | 129 | if __name__ == '__main__': 130 | Fire(main) 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ZeroSCROLLS 2 | 3 | This repository contains code to run inference on the [ZeroSCROLLS](https://www.zero.scrolls-benchmark.com/) benchmark. 
4 | 
5 | ## Setup
6 | 
7 | * Install [torch](https://pytorch.org/get-started/locally/)
8 | * Install transformers 4.30.2
9 | * `pip install -r requirements.txt`
10 | 
11 | 
12 | ## Load the data
13 | - via [🤗 Datasets (huggingface/datasets)](https://huggingface.co/datasets/tau/zero_scrolls/viewer/book_sum_sort/test) library (recommended):
14 | ```python
15 | from datasets import load_dataset
16 | 
17 | gov_report = load_dataset("tau/zero_scrolls", "gov_report", split="test")
18 | """
19 | Options are: ["gov_report", "summ_screen_fd", "qmsum", "squality", "qasper", "narrative_qa", "quality", "musique", "space_digest", "book_sum_sort"]
20 | There is also a small number of examples (~20 per task) in a "validation" split, meant for eyeballing purposes.
21 | """
22 | 
23 | ```
24 | 
25 | - via ZIP files, where each split is in a JSONL file:
26 | - [GovReport](https://huggingface.co/datasets/tau/zero_scrolls/resolve/main/gov_report.zip)
27 | - [SummScreenFD](https://huggingface.co/datasets/tau/zero_scrolls/resolve/main/summ_screen_fd.zip)
28 | - [QMSum](https://huggingface.co/datasets/tau/zero_scrolls/resolve/main/qmsum.zip)
29 | - [SQuALITY](https://huggingface.co/datasets/tau/zero_scrolls/resolve/main/squality.zip)
30 | - [Qasper](https://huggingface.co/datasets/tau/zero_scrolls/resolve/main/qasper.zip)
31 | - [NarrativeQA](https://huggingface.co/datasets/tau/zero_scrolls/resolve/main/narrative_qa.zip)
32 | - [QuALITY](https://huggingface.co/datasets/tau/zero_scrolls/resolve/main/quality.zip)
33 | - [MuSiQue](https://huggingface.co/datasets/tau/zero_scrolls/resolve/main/musique.zip)
34 | - [SpaceDigest](https://huggingface.co/datasets/tau/zero_scrolls/resolve/main/space_digest.zip)
35 | - [BookSumSort](https://huggingface.co/datasets/tau/zero_scrolls/resolve/main/book_sum_sort.zip)
36 | 
37 | 
38 | ## Inference with Huggingface models
39 | ```bash
40 | python experiments/hf/run_hf_model.py --model-name=google/flan-t5-small
41 | ```
42 | 
43 | Supported models:
44 | * google/flan-t5-small
45 | * google/flan-t5-base
46 | * google/flan-t5-large
47 | * google/flan-t5-xl
48 | * google/flan-t5-xxl
49 | * google/flan-ul2
50 | * bigscience/T0pp
51 | 
52 | To add new models:
53 | * Add them to `model_to_max_input_tokens` in [experiments/hf/run_hf_model.py](https://github.com/tau-nlp/zero_scrolls/blob/main/experiments/hf/run_hf_model.py)
54 | * Make sure to load them with the appropriate architecture (i.e., modify the model initialization from T5ForConditionalGeneration in the same file, if needed)
55 | 
56 | ## Inference with APIs
57 | To run with the models used in the [paper](https://arxiv.org/pdf/2305.14196.pdf)*:
58 | 
59 | ```bash
60 | # if you want to use openai models
61 | export OPENAI_API_KEY=
62 | export OPENAI_ORG=
63 | 
64 | # if you want to use anthropic models
65 | export ANTHROPIC_API_KEY=
66 | 
67 | # if you want to limit the number of examples to run per task
68 | export MAX_EXAMPLES=10
69 | 
70 | python experiments/api/run_api_model.py --model_name=gpt-3.5-turbo --limit_to_n_examples=$MAX_EXAMPLES
71 | ```
72 | *These models and APIs tend to update; see the paper for the versions used in the baselines.
73 | 
74 | Models supported:
75 | * text-davinci-003
76 | * gpt-3.5-turbo
77 | * gpt-4
78 | * claude-v1
79 | 
80 | To add a new API, you need to:
81 | * Implement a new class that inherits from [APIRunner](https://github.com/tau-nlp/zero_scrolls/blob/main/experiments/api/api.py#L16).
82 | * Working examples for OpenAI and Anthropic APIs can be found in [openai_api.py](https://github.com/tau-nlp/zero_scrolls/blob/main/experiments/api/openai_api.py) and [anthropic_api.py](https://github.com/tau-nlp/zero_scrolls/blob/main/experiments/api/anthropic_api.py).
83 | 
84 | When using a prompt that includes opening XML tags (e.g. "... Assistant: <answer>"), make sure to post-process the generations and keep only the text that precedes the model-generated closing XML tag before submitting (a minimal sketch is shown below).
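For example, the Anthropic prompts built in anthropic_api.py end with an opening tag such as `<answer>`, so the raw generation typically contains the answer followed by a closing `</answer>` tag (and sometimes extra text). A minimal post-processing sketch (the helper name and the default tag are illustrative and not part of this repository; the actual tag for each task is derived from the prompt suffix):

```python
def strip_xml_answer(generation: str, tag: str = "answer") -> str:
    """Keep only the text that precedes the model-generated closing tag."""
    opening, closing = f"<{tag}>", f"</{tag}>"
    # Drop everything from the closing tag onwards.
    if closing in generation:
        generation = generation.split(closing)[0]
    # Remove a leading opening tag if the model repeated it.
    generation = generation.strip()
    if generation.startswith(opening):
        generation = generation[len(opening):]
    return generation.strip()


# Example: strip_xml_answer("(B) Because of the storm.</answer>\n\nHuman:") -> "(B) Because of the storm."
```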
85 | 
86 | ## Prepare submission
87 | To create a CSV file in the correct format for a leaderboard submission, we recommend using our conversion script, [prepare_submission.py](https://github.com/tau-nlp/zero_scrolls/blob/main/submission/prepare_submission.py).
88 | 
89 | Its inputs:
90 | 
91 | For each task, the predictions should be in a JSON file that is a mapping from an ID to a textual prediction:
92 | ```JSON
93 | {
94 |     "example_id1": "prediction1",
95 |     "example_id2": "prediction2",
96 |     ...
97 | }
98 | ```
99 | Please set:
100 | * `{dataset_name}_PREDS_FILE` to be the path to a JSON file in the format above containing your predictions for `{dataset_name}`.
101 | * `OUTPUT_DIR` to be the path where the submission file will be saved.
102 | 
103 | Run:
104 | ```bash
105 | python submission/prepare_submission.py \
106 |     --gov_report_file GOV_REPORT_PREDS_FILE \
107 |     --summ_screen_fd_file SUMM_SCREEN_FD_PREDS_FILE \
108 |     --qmsum_file QMSUM_PREDS_FILE \
109 |     --squality_file SQUALITY_PREDS_FILE \
110 |     --qasper_file QASPER_PREDS_FILE \
111 |     --narrative_qa_file NARRATIVE_QA_PREDS_FILE \
112 |     --quality_file QUALITY_PREDS_FILE \
113 |     --musique_file MUSIQUE_PREDS_FILE \
114 |     --space_digest_file SPACE_DIGEST_PREDS_FILE \
115 |     --book_sum_sort_file BOOK_SUM_SORT_PREDS_FILE \
116 |     --output_dir OUTPUT_DIR
117 | ```
118 | ### Verify your submission file
119 | Run:
120 | ```bash
121 | python submission/verify_submission.py \
122 |     --all_predictions SUBMISSION_FILE \
123 |     --output_dir OUTPUT_DIR
124 | ```
125 | A valid submission file will result in the following line being printed:
126 | ```bash
127 | The verification was successful.
128 | ```
129 | Please fix any errors before making your submission.
130 | 
131 | 
132 | ## Leaderboard
133 | The live leaderboard is [here](https://www.zero.scrolls-benchmark.com/leaderboard).
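Before submitting to the leaderboard, it can also help to spot-check the CSV locally. A small sketch (not part of the repository) that mirrors the read settings used by `submission/verify_submission.py`:

```python
import pandas as pd

# Same read settings as safe_read_csv in the submission scripts.
df = pd.read_csv("zero_scrolls_predictions.csv", dtype=object,
                 keep_default_na=False, na_values=["!@#$%^&*()"])

print(df.columns.tolist())        # expected: ['Task', 'ID', 'Prediction']
print(df["Task"].value_counts())  # number of predictions per task
```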
134 | 135 | 136 | 137 | ## Citation 138 | ``` 139 | @inproceedings{shaham-etal-2023-zeroscrolls, 140 | title = "{Z}ero{SCROLLS}: A Zero-Shot Benchmark for Long Text Understanding", 141 | author = "Shaham, Uri and 142 | Ivgi, Maor and 143 | Efrat, Avia and 144 | Berant, Jonathan and 145 | Levy, Omer", 146 | editor = "Bouamor, Houda and 147 | Pino, Juan and 148 | Bali, Kalika", 149 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2023", 150 | month = dec, 151 | year = "2023", 152 | address = "Singapore", 153 | publisher = "Association for Computational Linguistics", 154 | url = "https://aclanthology.org/2023.findings-emnlp.536", 155 | doi = "10.18653/v1/2023.findings-emnlp.536", 156 | pages = "7977--7989" 157 | } 158 | ``` 159 | If you find the ZeroSCROLLS data useful, please make sure to cite also the original dataset papers: [[bibtex]](https://github.com/tau-nlp/zero_scrolls/tree/main/zero_scrolls_datasets.bib) 160 | -------------------------------------------------------------------------------- /zero_scrolls_datasets.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{huang2021govreport, 2 | title = "Efficient Attentions for Long Document Summarization", 3 | author = "Huang, Luyang and 4 | Cao, Shuyang and 5 | Parulian, Nikolaus and 6 | Ji, Heng and 7 | Wang, Lu", 8 | booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies", 9 | month = jun, 10 | year = "2021", 11 | address = "Online", 12 | publisher = "Association for Computational Linguistics", 13 | url = "https://aclanthology.org/2021.naacl-main.112", 14 | doi = "10.18653/v1/2021.naacl-main.112", 15 | pages = "1419--1436" 16 | } 17 | 18 | @inproceedings{chen-etal-2022-summscreen, 19 | title = "{S}umm{S}creen: A Dataset for Abstractive Screenplay Summarization", 20 | author = "Chen, Mingda and 21 | Chu, Zewei and 22 | Wiseman, Sam and 23 | Gimpel, Kevin", 24 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 25 | month = may, 26 | year = "2022", 27 | address = "Dublin, Ireland", 28 | publisher = "Association for Computational Linguistics", 29 | url = "https://aclanthology.org/2022.acl-long.589", 30 | doi = "10.18653/v1/2022.acl-long.589", 31 | pages = "8602--8615" 32 | } 33 | 34 | @inproceedings{zhong2021qmsum, 35 | title = "{QMS}um: A New Benchmark for Query-based Multi-domain Meeting Summarization", 36 | author = "Zhong, Ming and 37 | Yin, Da and 38 | Yu, Tao and 39 | Zaidi, Ahmad and 40 | Mutuma, Mutethia and 41 | Jha, Rahul and 42 | Awadallah, Ahmed Hassan and 43 | Celikyilmaz, Asli and 44 | Liu, Yang and 45 | Qiu, Xipeng and 46 | Radev, Dragomir", 47 | booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies", 48 | month = jun, 49 | year = "2021", 50 | address = "Online", 51 | publisher = "Association for Computational Linguistics", 52 | url = "https://aclanthology.org/2021.naacl-main.472", 53 | doi = "10.18653/v1/2021.naacl-main.472", 54 | pages = "5905--5921" 55 | } 56 | 57 | @inproceedings{dasigi2021qasper, 58 | title = "A Dataset of Information-Seeking Questions and Answers Anchored in Research Papers", 59 | author = "Dasigi, Pradeep and 60 | Lo, Kyle and 61 | Beltagy, Iz and 62 | Cohan, Arman and 63 | Smith, Noah A. 
and 64 | Gardner, Matt", 65 | booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies", 66 | month = jun, 67 | year = "2021", 68 | address = "Online", 69 | publisher = "Association for Computational Linguistics", 70 | url = "https://aclanthology.org/2021.naacl-main.365", 71 | doi = "10.18653/v1/2021.naacl-main.365", 72 | pages = "4599--4610" 73 | } 74 | 75 | @article{kocisky2018narrativeqa, 76 | title = "The {N}arrative{QA} Reading Comprehension Challenge", 77 | author = "Ko{\v{c}}isk{\'y}, Tom{\'a}{\v{s}} and 78 | Schwarz, Jonathan and 79 | Blunsom, Phil and 80 | Dyer, Chris and 81 | Hermann, Karl Moritz and 82 | Melis, G{\'a}bor and 83 | Grefenstette, Edward", 84 | journal = "Transactions of the Association for Computational Linguistics", 85 | volume = "6", 86 | year = "2018", 87 | address = "Cambridge, MA", 88 | publisher = "MIT Press", 89 | url = "https://aclanthology.org/Q18-1023", 90 | doi = "10.1162/tacl_a_00023", 91 | pages = "317--328" 92 | } 93 | 94 | @inproceedings{pang-etal-2022-quality, 95 | title = "{Q}u{ALITY}: Question Answering with Long Input Texts, Yes!", 96 | author = "Pang, Richard Yuanzhe and 97 | Parrish, Alicia and 98 | Joshi, Nitish and 99 | Nangia, Nikita and 100 | Phang, Jason and 101 | Chen, Angelica and 102 | Padmakumar, Vishakh and 103 | Ma, Johnny and 104 | Thompson, Jana and 105 | He, He and 106 | Bowman, Samuel", 107 | booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies", 108 | month = jul, 109 | year = "2022", 110 | address = "Seattle, United States", 111 | publisher = "Association for Computational Linguistics", 112 | url = "https://aclanthology.org/2022.naacl-main.391", 113 | doi = "10.18653/v1/2022.naacl-main.391", 114 | pages = "5336--5358" 115 | } 116 | 117 | @inproceedings{shaham-etal-2022-scrolls, 118 | title = "{SCROLLS}: Standardized {C}ompa{R}ison Over Long Language Sequences", 119 | author = "Shaham, Uri and 120 | Segal, Elad and 121 | Ivgi, Maor and 122 | Efrat, Avia and 123 | Yoran, Ori and 124 | Haviv, Adi and 125 | Gupta, Ankit and 126 | Xiong, Wenhan and 127 | Geva, Mor and 128 | Berant, Jonathan and 129 | Levy, Omer", 130 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing", 131 | month = dec, 132 | year = "2022", 133 | address = "Abu Dhabi, United Arab Emirates", 134 | publisher = "Association for Computational Linguistics", 135 | url = "https://aclanthology.org/2022.emnlp-main.823", 136 | pages = "12007--12021" 137 | } 138 | 139 | @inproceedings{wang-etal-2022-squality, 140 | title = "{SQ}u{ALITY}: Building a Long-Document Summarization Dataset the Hard Way", 141 | author = "Wang, Alex and 142 | Pang, Richard Yuanzhe and 143 | Chen, Angelica and 144 | Phang, Jason and 145 | Bowman, Samuel R.", 146 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing", 147 | month = dec, 148 | year = "2022", 149 | address = "Abu Dhabi, United Arab Emirates", 150 | publisher = "Association for Computational Linguistics", 151 | url = "https://aclanthology.org/2022.emnlp-main.75", 152 | pages = "1139--1156" 153 | } 154 | 155 | @article{trivedi-etal-2022-musique, 156 | title = "♫ {M}u{S}i{Q}ue: Multihop Questions via Single-hop Question Composition", 157 | author = "Trivedi, Harsh and 158 | Balasubramanian, Niranjan and 159 | Khot, Tushar and 160 | Sabharwal, 
Ashish", 161 | journal = "Transactions of the Association for Computational Linguistics", 162 | volume = "10", 163 | year = "2022", 164 | address = "Cambridge, MA", 165 | publisher = "MIT Press", 166 | url = "https://aclanthology.org/2022.tacl-1.31", 167 | doi = "10.1162/tacl_a_00475", 168 | pages = "539--554" 169 | } 170 | 171 | @inproceedings{kryscinski-etal-2022-booksum, 172 | title = "{BOOKSUM}: A Collection of Datasets for Long-form Narrative Summarization", 173 | author = "Kryscinski, Wojciech and 174 | Rajani, Nazneen and 175 | Agarwal, Divyansh and 176 | Xiong, Caiming and 177 | Radev, Dragomir", 178 | booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2022", 179 | month = dec, 180 | year = "2022", 181 | address = "Abu Dhabi, United Arab Emirates", 182 | publisher = "Association for Computational Linguistics", 183 | url = "https://aclanthology.org/2022.findings-emnlp.488", 184 | pages = "6536--6558" 185 | } 186 | 187 | @article{angelidis-etal-2021-extractive, 188 | title = "Extractive Opinion Summarization in Quantized Transformer Spaces", 189 | author = "Angelidis, Stefanos and 190 | Amplayo, Reinald Kim and 191 | Suhara, Yoshihiko and 192 | Wang, Xiaolan and 193 | Lapata, Mirella", 194 | journal = "Transactions of the Association for Computational Linguistics", 195 | volume = "9", 196 | year = "2021", 197 | address = "Cambridge, MA", 198 | publisher = "MIT Press", 199 | url = "https://aclanthology.org/2021.tacl-1.17", 200 | doi = "10.1162/tacl_a_00366", 201 | pages = "277--293" 202 | } 203 | --------------------------------------------------------------------------------