├── api_secrets.py ├── experiment_saver.py ├── experiment_config.py ├── utils.py ├── viz_prompts.py ├── constants.py ├── .gitignore ├── README.md ├── main.py ├── prompts.py ├── analyze.py ├── models.py ├── LICENSE └── dataset_utils.py /api_secrets.py: -------------------------------------------------------------------------------- 1 | def get_api_key_by_name(name): 2 | return { 3 | "openai": "MY-OPENAI-KEY", # Add your key here 4 | "jurassic": "MY-JURASSIC-KEY" # Add your key here 5 | }[name] 6 | -------------------------------------------------------------------------------- /experiment_saver.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import pandas as pd 4 | 5 | 6 | class ExperimentSaver(defaultdict): 7 | 8 | def __init__(self, save_fname): 9 | super().__init__(list) 10 | self.save_fname = save_fname 11 | 12 | def save(self): 13 | print("Saving to", self.save_fname) 14 | pd.DataFrame(self).to_pickle(self.save_fname) 15 | -------------------------------------------------------------------------------- /experiment_config.py: -------------------------------------------------------------------------------- 1 | from constants import RESULTS_DIR_NAME 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class ExperimentConfig: 7 | ds_name: str 8 | model_name: str 9 | style_name: str 10 | n_shots: int 11 | do_strong_shuffle: bool 12 | do_perm: bool 13 | 14 | def get_save_fname(self): 15 | vals = [str(v) for v in vars(self).values()] 16 | return f"{RESULTS_DIR_NAME}/{'_'.join(vals)}.pkl" 17 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | 4 | 5 | def idx_to_ltr(idx): 6 | return chr(idx + ord("A")) 7 | 8 | 9 | def ltr_to_idx(ltr): 10 | return ord(ltr) - ord("A") 11 | 12 | 13 | def make_dir_if_does_not_exist(dir_name): 14 | if not os.path.exists(dir_name): 15 | os.makedirs(dir_name) 16 | 17 | 18 | def prep_openai_obj_for_save(obj, prompt_text=None): 19 | obj = dict(obj) 20 | for key in obj.keys(): 21 | if isinstance(obj[key], openai.openai_object.OpenAIObject): 22 | obj[key] = prep_openai_obj_for_save(obj[key]) 23 | if isinstance(obj[key], list): 24 | for i in range(len(obj[key])): 25 | if isinstance(obj[key][i], openai.openai_object.OpenAIObject): 26 | obj[key][i] = prep_openai_obj_for_save(obj[key][i]) 27 | if prompt_text is not None: 28 | obj["prompt_text"] = prompt_text 29 | return obj 30 | -------------------------------------------------------------------------------- /viz_prompts.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from dataset_utils import get_dataset_info, get_questions_with_exemplars 3 | 4 | import random 5 | 6 | 7 | def get_config_from_args(): 8 | parser = ArgumentParser() 9 | parser.add_argument("ds_name", help="Dataset name") 10 | parser.add_argument("style_name", help="Style name") 11 | parser.add_argument("n_shots", type=int, help="# of shots") 12 | parser.add_argument("--longest", action="store_true") 13 | args = parser.parse_args() 14 | return vars(args) 15 | 16 | 17 | def viz_prompts(ds_name, style_name, n_shots, longest): 18 | 19 | # Get questions with exemplars 20 | qwes = get_questions_with_exemplars( 21 | info=get_dataset_info(ds_name), 22 | n_shots=n_shots, 23 | do_strong_shuffle=False 24 | ) 25 | 26 | if style_name == 
"natural": 27 | prompt_texts = [q.get_natural_prompt() for q in qwes] 28 | elif style_name == "brown": 29 | prompt_texts = [q.get_brown_prompt() for q in qwes] 30 | 31 | if longest: 32 | p = max(prompt_texts, key=len) 33 | else: 34 | random.seed() 35 | p = random.choice(prompt_texts) 36 | print(p) 37 | 38 | 39 | if __name__ == "__main__": 40 | viz_prompts(**get_config_from_args()) 41 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | from utils import make_dir_if_does_not_exist 2 | 3 | 4 | CODEX_MODEL_NAME = "code-davinci-002" 5 | CP_MODEL_NAME = "codeparrot/codeparrot" 6 | GPT2_MODEL_NAME = "gpt2" 7 | GPT3_MODEL_NAME = "davinci" 8 | CURIE_MODEL_NAME = "text-curie-001" # Instruct 9 | INSTRUCT_MODEL_NAME = "text-davinci-002" 10 | JURASSIC_MODEL_NAME = "j1-jumbo" 11 | 12 | JURASSIC_SPACE = "▁" 13 | REPRODUCIBILITY_SEED = 0 14 | RETRY_SLEEP_TIME = 30 15 | SAVE_EVERY = 25 16 | 17 | MMLU_NAMES = [ 18 | "abstract_algebra", 19 | "anatomy", 20 | "astronomy", 21 | "business_ethics", 22 | "clinical_knowledge", 23 | "college_biology", 24 | "college_chemistry", 25 | "college_computer_science", 26 | "college_mathematics", 27 | "college_medicine", 28 | "college_physics", 29 | "computer_security", 30 | "conceptual_physics", 31 | "econometrics", 32 | "electrical_engineering", 33 | "elementary_mathematics", 34 | "formal_logic", 35 | "global_facts", 36 | "high_school_biology", 37 | "high_school_chemistry", 38 | "high_school_computer_science", 39 | "high_school_european_history", 40 | "high_school_geography", 41 | "high_school_government_and_politics", 42 | "high_school_macroeconomics", 43 | "high_school_mathematics", 44 | "high_school_microeconomics", 45 | "high_school_physics", 46 | "high_school_psychology", 47 | "high_school_statistics", 48 | "high_school_us_history", 49 | "high_school_world_history", 50 | "human_aging", 51 | "human_sexuality", 52 | "international_law", 53 | "jurisprudence", 54 | "logical_fallacies", 55 | "machine_learning", 56 | "management", 57 | "marketing", 58 | "medical_genetics", 59 | "miscellaneous", 60 | "moral_disputes", 61 | "moral_scenarios", 62 | "nutrition", 63 | "philosophy", 64 | "prehistory", 65 | "professional_accounting", 66 | "professional_law", 67 | "professional_medicine", 68 | "professional_psychology", 69 | "public_relations", 70 | "security_studies", 71 | "sociology", 72 | "us_foreign_policy", 73 | "virology", 74 | "world_religions" 75 | ] 76 | 77 | 78 | HF_CACHE_DIR_NAME = "hf_cache" 79 | RESULTS_DIR_NAME = "results" 80 | 81 | 82 | for dir_name in [HF_CACHE_DIR_NAME, RESULTS_DIR_NAME]: 83 | make_dir_if_does_not_exist(dir_name) 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # leveraging-llms-for-mcqa 2 | 3 | ## Overview 4 | This is the code for the ICLR 2023 paper "[Leveraging Large Language Models for Multiple Choice Question Answering](https://arxiv.org/abs/2210.12353)." It can be used to reproduce results in the paper and is designed to be extensible. 5 | 6 | ## Setup 7 | * Start by using your favorite package manager to install `datasets`, `numpy`, `openai`, `pandas`, `scipy`, `tqdm`, and `transformers`. 8 | * Now register your API keys in `api_secrets.py`. To do this, add a key and value to the dictionary in the `get_api_key_by_name` function for each API key you want to register. You'll need an OpenAI key for OpenAI API experiments, and a Jurassic key for Jurassic API experiments. You can keep the existing key names or choose your own. 9 | 10 | ## Running Experiments 11 | To run experiments and reproduce the results from the paper, you will use `main.py`. 12 | 13 | The positional command line arguments are: 14 | * The name of the dataset to use (must be a key from the dictionary inside `get_dataset_info` in `dataset_utils.py`) e.g., "mmlu" 15 | * The name of the model to use (must be a key in one of the dictionaries in `get_model_by_name` in `models.py`) e.g., "codex" 16 | * The name of the prompting style to use (either "brown" (called CP in the paper) or "natural" (called MCP in the paper)) 17 | * The number of shots to use ("0" for zero-shot, "1" for one-shot, etc.) 
18 | * The name of the API key to use (must be a key from the dictionary inside `get_api_key_by_name` in `api_secrets.py`) 19 | 20 | The optional command line arguments are: 21 | * `--do_strong_shuffle`: For strong shuffling as used in Appendix C 22 | * `--do_perm`: For passing all permutations of each question to the model, as in the experiments in Section 4 23 | 24 | Running `main.py` will save a pickle file with experiment results. 25 | 26 | ## Analyzing Results 27 | To analyze the results of an experiment (from its saved pickle file) you will use `analyze.py`. The positional and optional command line arguments are the same, except that you don't need to supply the name of an API key. These arguments will be used to look up the saved experiment pickle file. 28 | 29 | ## Other Functionality 30 | * You can **visualize prompts** that will be used by an experiment with `viz_prompts.py`. The positional command line arguments are dataset name, style name, and number of shots (as you'd use with `main.py`). The optional argument `--longest` will show the longest prompt instead of a random one. 31 | * You can **add a custom model** by adding a custom key and value to a dictionary in `get_model_by_name` within `models.py`. 32 | * You can **add a custom dataset** by adding a custom key and value to the dictionary in `get_dataset_info` within `dataset_utils.py`. 33 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from api_secrets import get_api_key_by_name 2 | from argparse import ArgumentParser 3 | from constants import SAVE_EVERY 4 | from dataset_utils import get_dataset_info, get_questions_with_exemplars 5 | from experiment_config import ExperimentConfig 6 | from experiment_saver import ExperimentSaver 7 | from itertools import permutations 8 | from models import get_model_by_name 9 | from tqdm import tqdm 10 | 11 | import copy 12 | 13 | 14 | def get_config_and_api_key_name_from_args(): 15 | parser = ArgumentParser() 16 | parser.add_argument("ds_name", help="Dataset name") 17 | parser.add_argument("model_name", help="Model name") 18 | parser.add_argument("style_name", help="Style name") 19 | parser.add_argument("n_shots", type=int, help="# of shots") 20 | parser.add_argument("api_key_name", help="API key name") 21 | parser.add_argument( 22 | "--do_strong_shuffle", 23 | action="store_true", 24 | help="Force correct answer index to change for each example" 25 | ) 26 | parser.add_argument( 27 | "--do_perm", 28 | action="store_true", 29 | help="Process every example with all possible answer orderings" 30 | ) 31 | args = parser.parse_args() 32 | api_key_name = args.api_key_name 33 | args = vars(args) 34 | del args["api_key_name"] 35 | return ExperimentConfig(**args), api_key_name 36 | 37 | 38 | def run_experiment(config, api_key_name): 39 | 40 | # Get API key 41 | api_key = get_api_key_by_name(name=api_key_name) 42 | 43 | # Load model 44 | model = get_model_by_name( 45 | name=config.model_name, 46 | api_key=api_key 47 | ) 48 | model = { 49 | "natural": model.process_question_natural, 50 | "brown": model.process_question_brown 51 | }[config.style_name] 52 | 53 | # Get questions with exemplars 54 | qwes = get_questions_with_exemplars( 55 | info=get_dataset_info(config.ds_name), 56 | n_shots=config.n_shots, 57 | do_strong_shuffle=config.do_strong_shuffle 58 | ) 59 | 60 | # Run experiment, saving results 61 | saver = ExperimentSaver(save_fname=config.get_save_fname()) 62 | 
for q_idx, qwe in enumerate(tqdm(qwes)): 63 | 64 | if config.do_perm: 65 | for perm_order in permutations(range(qwe.get_n_choices())): 66 | qwe_copy = copy.deepcopy(qwe) 67 | qwe_copy.permute_choices(perm_order) 68 | response = model(qwe_copy) 69 | saver["question_idx"].append(q_idx) 70 | saver["perm_order"].append(perm_order) 71 | saver["qwe"].append(vars(qwe_copy)) 72 | saver["model_response"].append(vars(response)) 73 | 74 | # When doing permutations we ignore SAVE_EVERY and 75 | # save after every question 76 | saver.save() 77 | else: 78 | response = model(qwe) 79 | 80 | saver["question_idx"].append(q_idx) 81 | if qwe.task is not None: 82 | saver["task"].append(qwe.task) 83 | saver["qwe"].append(vars(qwe)) 84 | saver["model_response"].append(vars(response)) 85 | 86 | if q_idx % SAVE_EVERY == 0 and q_idx != 0: 87 | saver.save() 88 | 89 | saver.save() 90 | 91 | 92 | if __name__ == "__main__": 93 | run_experiment(*get_config_and_api_key_name_from_args()) 94 | -------------------------------------------------------------------------------- /prompts.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from utils import idx_to_ltr 3 | 4 | import random 5 | 6 | 7 | @dataclass 8 | class QuestionPart: 9 | text: str 10 | tag: str = None 11 | 12 | def __str__(self): 13 | if self.tag is not None: 14 | return f"{self.tag}: {self.text}" 15 | else: 16 | return self.text 17 | 18 | 19 | @dataclass 20 | class Question: 21 | parts: list 22 | choices: list 23 | answer_idx: int 24 | task: str = None 25 | 26 | def get_n_choices(self): 27 | return len(self.choices) 28 | 29 | def get_answer_str(self): 30 | return self.choices[self.answer_idx] 31 | 32 | def _get_prompt(self, include_choices): 33 | prompt = "" 34 | for part in self.parts: 35 | prompt += f"{str(part)}\n" 36 | if include_choices: 37 | for i, choice in enumerate(self.choices): 38 | prompt += f"{idx_to_ltr(i)}. 
{choice}\n" 39 | return prompt + "Answer:" 40 | 41 | def get_natural_prompt(self): 42 | return self._get_prompt(include_choices=True) 43 | 44 | def get_brown_prompt(self): 45 | return self._get_prompt(include_choices=False) 46 | 47 | def strong_shuffle(self): 48 | # This method shuffles choices such that choosing 49 | # the answer at the originally correct 50 | # index will mean getting the question wrong 51 | 52 | # For degenerate questions where all choices are the same 53 | if len(set(self.choices)) == 1: 54 | return 55 | 56 | answer_idx = self.answer_idx 57 | answer_str = self.get_answer_str() 58 | while self.choices[answer_idx] == answer_str: 59 | random.shuffle(self.choices) 60 | self.answer_idx = self.choices.index(answer_str) 61 | 62 | def permute_choices(self, perm): 63 | self.choices = [self.choices[i] for i in perm] 64 | self.answer_idx = perm.index(self.answer_idx) 65 | 66 | 67 | class QuestionWithExemplars(Question): 68 | 69 | def __init__(self, parts, choices, answer_idx, exemplars, task=None): 70 | super().__init__(parts, choices, answer_idx, task) 71 | self.exemplars = exemplars 72 | 73 | def get_natural_prompt(self): 74 | prompt = super().get_natural_prompt() 75 | if len(self.exemplars): 76 | exemplar_prompts = [e.get_natural_prompt() for e in self.exemplars] 77 | exemplars = "\n\n".join(exemplar_prompts) 78 | return f"{exemplars}\n\n{prompt}" 79 | else: 80 | return prompt 81 | 82 | def get_brown_prompt(self): 83 | prompt = super().get_brown_prompt() 84 | if len(self.exemplars): 85 | exemplar_prompts = [e.get_brown_prompt() for e in self.exemplars] 86 | exemplars = "\n\n".join(exemplar_prompts) 87 | return f"{exemplars}\n\n{prompt}" 88 | else: 89 | return prompt 90 | 91 | 92 | class Exemplar(Question): 93 | 94 | def get_natural_prompt(self): 95 | prompt = super().get_natural_prompt() 96 | answer_ltr = idx_to_ltr(self.answer_idx) 97 | return f"{prompt} {answer_ltr}" 98 | 99 | def get_brown_prompt(self): 100 | prompt = super().get_brown_prompt() 101 | return f"{prompt} {self.get_answer_str()}" 102 | -------------------------------------------------------------------------------- /analyze.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from experiment_config import ExperimentConfig 3 | from scipy import stats 4 | from utils import idx_to_ltr 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | def get_config_from_args(): 11 | parser = ArgumentParser() 12 | parser.add_argument("ds_name", help="Dataset name") 13 | parser.add_argument("model_name", help="Model name") 14 | parser.add_argument("style_name", help="Style name") 15 | parser.add_argument("n_shots", type=int, help="# of shots") 16 | parser.add_argument( 17 | "--do_strong_shuffle", 18 | action="store_true", 19 | help="Force correct answer index to change for each example" 20 | ) 21 | parser.add_argument( 22 | "--do_perm", 23 | action="store_true", 24 | help="Process every example with all possible answer orderings" 25 | ) 26 | args = parser.parse_args() 27 | return ExperimentConfig(**vars(args)) 28 | 29 | 30 | def add_correct_answer_col(df): 31 | df["correct_answer"] = df.apply( 32 | lambda row: idx_to_ltr(row["qwe"]["answer_idx"]), 33 | axis=1 34 | ) 35 | 36 | 37 | def div_dicts(a, b): 38 | # Divide each value in dictionary a by the matching 39 | # value in dictionary b 40 | new_dict = dict() 41 | for key in a.keys(): 42 | if key in b.keys(): 43 | new_dict[key] = a[key] / b[key] 44 | return new_dict 45 | 46 | 47 | def 
sub_dicts(a, b): 48 | # Subtract from each value in dictionary a the matching 49 | # value in dictionary b 50 | new_dict = dict() 51 | for key in a.keys(): 52 | if key in b.keys(): 53 | new_dict[key] = a[key] - b[key] 54 | return new_dict 55 | 56 | 57 | def analyze_results(config): 58 | # Get file name of experiment to load 59 | fname = config.get_save_fname() 60 | 61 | # Load file 62 | df = pd.read_pickle(fname) 63 | 64 | if config.style_name == "natural": 65 | if config.do_perm: 66 | # We start by calculating the logprob of each 67 | # answer option irrespective of the order the 68 | # options were presented in 69 | def get_lp(ltr, lps): 70 | if f"Ġ{ltr}" in lps.keys(): 71 | return lps[f"Ġ{ltr}"] 72 | elif f" {ltr}" in lps.keys(): 73 | return lps[f" {ltr}"] 74 | else: 75 | return -np.inf 76 | 77 | df["ord_lps"] = df.apply( 78 | lambda row: [ 79 | get_lp( 80 | idx_to_ltr(row['perm_order'].index(i)).upper(), 81 | row["model_response"]["logprobs"] 82 | ) for i in range(len(row["perm_order"]))], 83 | axis=1 84 | ) 85 | 86 | df["coverage"] = df.apply( 87 | lambda row: np.sum(np.exp(row["ord_lps"])), 88 | axis=1 89 | ) 90 | print(f"Coverage: {df['coverage'].mean()}") 91 | 92 | # Add a column for if model got question right 93 | df["correct"] = df.apply( 94 | lambda row: max( 95 | row["model_response"]["logprobs"].items(), 96 | key=lambda x: x[1] 97 | # In line below [0] is the key (as opposed to value) 98 | # Additionally we use 1: instead of lstrip because 99 | # we want the prediction "A" to be wrong when " A" 100 | # is expected, for example 101 | )[0][1:] == idx_to_ltr(row["qwe"]["answer_idx"]), 102 | axis=1 103 | ) 104 | print(f"Accuracy: {df['correct'].mean()}") 105 | 106 | # Making lists of lists 107 | grouped = df.groupby("question_idx")["ord_lps"].apply(list) 108 | lps_by_question = grouped.tolist() 109 | 110 | # HOW MANY OF THE CHOSEN ANSWERS MATCH THE MAJORITY 111 | # ANSWER? 
112 | props = list() 113 | for q_lps in lps_by_question: 114 | majority_choice = stats.mode( 115 | [np.argmax(x) for x in q_lps] 116 | )[0][0] 117 | props.append( 118 | sum( 119 | [np.argmax(x) == majority_choice for x in q_lps] 120 | ) / len(q_lps) 121 | ) 122 | 123 | print("PPA:", np.mean(props)) 124 | 125 | else: 126 | add_correct_answer_col(df) 127 | df["chosen_answer_raw"] = df.apply( 128 | lambda row: max( 129 | row["model_response"]["logprobs"].items(), 130 | key=lambda x: x[1] 131 | # In line below [0] is the key (as opposed to value) 132 | # Additionally we use 1: instead of lstrip because 133 | # we want the prediction "A" to be wrong when " A" 134 | # is expected, for example 135 | )[0][1:], 136 | axis=1 137 | ) 138 | 139 | df["correct"] = df.apply( 140 | lambda row: row["chosen_answer_raw"] == row["correct_answer"], 141 | axis=1 142 | ) 143 | 144 | print( 145 | "Accuracy:", 146 | df["correct"].mean() 147 | ) 148 | 149 | # If config.ds_name == "mmlu" we'll present accuracy 150 | # after grouping by "task" 151 | if config.ds_name == "mmlu": 152 | print("Accuracy by task:") 153 | g = df.groupby("task")["correct"].mean() 154 | for i, task_name in enumerate(g.index): 155 | print(task_name, round(g[i]*100, 1)) 156 | else: 157 | add_correct_answer_col(df) 158 | df["chosen_answer_raw"] = df.apply( 159 | lambda row: max( 160 | row["model_response"]["logprobs"].items(), 161 | key=lambda x: x[1] 162 | # In line below [0] is the key (as opposed to value) 163 | # No need for 1: here because we assign the letters 164 | # manually in models.py 165 | )[0], 166 | axis=1 167 | ) 168 | print( 169 | "Accuracy (raw):", 170 | (df["chosen_answer_raw"] == df["correct_answer"]).mean() 171 | ) 172 | 173 | # Answer with length normalization 174 | df["chosen_answer_ln"] = df.apply( 175 | lambda row: max( 176 | div_dicts( 177 | row["model_response"]["logprobs"], 178 | row["model_response"]["lens"] 179 | ).items(), 180 | key=lambda x: x[1] 181 | )[0], 182 | axis=1 183 | ) 184 | print( 185 | "Accuracy (length-normalized):", 186 | (df["chosen_answer_ln"] == df["correct_answer"]).mean() 187 | ) 188 | 189 | # Answer with special normalization 190 | df["chosen_answer_sn"] = df.apply( 191 | lambda row: max( 192 | sub_dicts( 193 | row["model_response"]["logprobs"], 194 | row["model_response"]["unconditional_logprobs"] 195 | ).items(), 196 | key=lambda x: x[1] 197 | )[0], 198 | axis=1 199 | ) 200 | print( 201 | "Accuracy (unconditional-normalized):", 202 | (df["chosen_answer_sn"] == df["correct_answer"]).mean() 203 | ) 204 | 205 | 206 | if __name__ == "__main__": 207 | analyze_results(get_config_from_args()) 208 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from constants import ( 2 | CODEX_MODEL_NAME, 3 | CP_MODEL_NAME, 4 | CURIE_MODEL_NAME, 5 | GPT2_MODEL_NAME, 6 | GPT3_MODEL_NAME, 7 | HF_CACHE_DIR_NAME, 8 | INSTRUCT_MODEL_NAME, 9 | JURASSIC_MODEL_NAME, 10 | JURASSIC_SPACE, 11 | RETRY_SLEEP_TIME 12 | ) 13 | from dataclasses import dataclass 14 | from transformers import AutoTokenizer, AutoModelForCausalLM 15 | from utils import idx_to_ltr, prep_openai_obj_for_save 16 | 17 | import numpy as np 18 | import openai 19 | import requests 20 | import time 21 | 22 | 23 | @dataclass 24 | class ModelResponseNatural: 25 | logprobs: dict 26 | response_list: list 27 | 28 | 29 | @dataclass 30 | class ModelResponseBrown: 31 | logprobs: dict 32 | unconditional_logprobs: dict 33 | lens: dict 34 | 
response_list: list 35 | 36 | 37 | class Test: 38 | 39 | def _get_uniform_response(self, n_choices): 40 | return {idx_to_ltr(i): np.log(1/n_choices) for i in range(n_choices)} 41 | 42 | def process_question_natural(self, question): 43 | n_choices = question.get_n_choices() 44 | logprobs = self._get_uniform_response(n_choices=n_choices) 45 | return ModelResponseNatural( 46 | logprobs=logprobs, 47 | response_list=list() 48 | ) 49 | 50 | def process_question_brown(self, question): 51 | n_choices = question.get_n_choices() 52 | logprobs = self._get_uniform_response(n_choices=n_choices) 53 | lens = {idx_to_ltr(i): 1 for i in range(n_choices)} 54 | return ModelResponseBrown( 55 | logprobs=logprobs, 56 | unconditional_logprobs=logprobs, 57 | lens=lens, 58 | response_list=list() 59 | ) 60 | 61 | 62 | class GPT2Model: 63 | 64 | def __init__(self, model_name): 65 | self.tokenizer = AutoTokenizer.from_pretrained( 66 | model_name, 67 | cache_dir=HF_CACHE_DIR_NAME 68 | ) 69 | self.model = AutoModelForCausalLM.from_pretrained( 70 | model_name, 71 | cache_dir=HF_CACHE_DIR_NAME 72 | ) 73 | self.lbls_map = {v: k for k, v in self.tokenizer.vocab.items()} 74 | 75 | def process_question_natural(self, question): 76 | prompt_text = question.get_natural_prompt() 77 | inputs = self.tokenizer(prompt_text, return_tensors="pt") 78 | outputs = self.model(**inputs) 79 | logits = outputs.logits[0, -1] 80 | probs = logits.softmax(dim=-1) 81 | logprobs_dict = { 82 | self.lbls_map[i]: 83 | np.log(probs[i].item()) for i in range(len(self.lbls_map)) 84 | } 85 | 86 | # Reduce logprobs_dict to only the keys with the top 200 largest values 87 | logprobs_dict = { 88 | k: v for k, v in sorted( 89 | logprobs_dict.items(), 90 | key=lambda item: item[1], 91 | reverse=True 92 | )[:200] 93 | } 94 | 95 | return ModelResponseNatural( 96 | logprobs=logprobs_dict, 97 | response_list=list() 98 | ) 99 | 100 | def process_question_brown(self, question): 101 | raise NotImplementedError 102 | 103 | 104 | class CodeParrot(GPT2Model): 105 | 106 | def __init__(self): 107 | super().__init__(model_name=CP_MODEL_NAME) 108 | 109 | 110 | class GPT2(GPT2Model): 111 | 112 | def __init__(self): 113 | super().__init__(model_name=GPT2_MODEL_NAME) 114 | 115 | 116 | class Jurassic: 117 | 118 | def __init__(self, api_key): 119 | self.key = api_key 120 | 121 | def process_question_natural(self, question): 122 | prompt_text = question.get_natural_prompt() 123 | 124 | while True: 125 | response = requests.post( 126 | f"https://api.ai21.com/studio/v1/{JURASSIC_MODEL_NAME}/complete", 127 | headers={"Authorization": f"Bearer {self.key}"}, 128 | json={ 129 | "prompt": prompt_text, 130 | "numResults": 1, 131 | "maxTokens": 1, 132 | "topKReturn": 64, 133 | "temperature": 1.0, 134 | } 135 | ) 136 | resp_json = response.json() 137 | try: 138 | completion_tokens = resp_json["completions"][0]["data"]["tokens"][0]["topTokens"] 139 | break 140 | except Exception: 141 | # Sleep, then re-issue the request at the top of the loop 142 | print(resp_json) 143 | print(f"Will retry API call in {RETRY_SLEEP_TIME} seconds...") 144 | time.sleep(RETRY_SLEEP_TIME) 145 | 146 | log_probs = {t["token"]: t["logprob"] for t in completion_tokens} 147 | log_probs = {k.replace(JURASSIC_SPACE, " "): v 148 | for k, v in log_probs.items()} 149 | return ModelResponseNatural( 150 | logprobs=log_probs, 151 | response_list=[resp_json] 152 | ) 153 | 154 | def process_question_brown(self, question): 155 | raise NotImplementedError 156 | 157 | 158 | class OpenAIModel: 159 | 160 | def __init__(self, api_key, model_name, add_space=False): 161 | openai.api_key = api_key 162 | self.add_space = add_space 163 | 
self.model_name = model_name 164 | 165 | def process_question_natural(self, question): 166 | prompt_text = question.get_natural_prompt() 167 | response = self._get_response(text=prompt_text, echo=False) 168 | logprobs = dict(response["choices"][0]["logprobs"]["top_logprobs"][0]) 169 | 170 | return ModelResponseNatural( 171 | logprobs=logprobs, 172 | response_list=[ 173 | prep_openai_obj_for_save( 174 | obj=response, 175 | prompt_text=prompt_text 176 | ) 177 | ] 178 | ) 179 | 180 | def _get_response(self, text, echo): 181 | while True: 182 | try: 183 | response = openai.Completion.create( 184 | model=self.model_name, 185 | prompt=text+(" " if self.add_space else ""), 186 | temperature=0, # Doesn't actually matter here 187 | max_tokens=1, # Just need to get letter 188 | logprobs=5, # Get max number of logprobs 189 | echo=echo 190 | ) 191 | return response 192 | except Exception as e: 193 | print(e) 194 | print("Will wait and retry...") 195 | time.sleep(RETRY_SLEEP_TIME) 196 | 197 | def process_question_brown(self, question): 198 | prompt_text = question.get_brown_prompt() 199 | 200 | response_list = list() 201 | logprobs = dict() 202 | unconditional_logprobs = dict() 203 | lens = dict() 204 | 205 | for idx, choice in enumerate(question.choices): 206 | ltr = idx_to_ltr(idx) 207 | 208 | # Get unconditional logprobs 209 | response = self._get_response(text=f"Answer: {choice}", echo=True) 210 | choice_logprobs = ( 211 | response["choices"][0]["logprobs"]["token_logprobs"][2:-1] 212 | ) 213 | 214 | choice_n_tokens = len(choice_logprobs) 215 | unconditional_logprobs[ltr] = sum(choice_logprobs) 216 | lens[ltr] = choice_n_tokens 217 | response_list.append( 218 | prep_openai_obj_for_save( 219 | obj=response, 220 | prompt_text=f"Answer: {choice}" 221 | ) 222 | ) 223 | 224 | # Get conditional logprobs 225 | response = self._get_response( 226 | text=f"{prompt_text} {choice}", echo=True 227 | ) 228 | token_logprobs = ( 229 | response["choices"][0]["logprobs"]["token_logprobs"] 230 | ) 231 | choice_logprobs = token_logprobs[-(choice_n_tokens+1):-1] 232 | 233 | logprobs[ltr] = sum(choice_logprobs) 234 | response_list.append( 235 | prep_openai_obj_for_save( 236 | obj=response, 237 | prompt_text=f"{prompt_text} {choice}" 238 | ) 239 | ) 240 | 241 | return ModelResponseBrown( 242 | logprobs=logprobs, 243 | unconditional_logprobs=unconditional_logprobs, 244 | lens=lens, 245 | response_list=response_list 246 | ) 247 | 248 | 249 | class Codex(OpenAIModel): 250 | 251 | def __init__(self, api_key): 252 | super().__init__( 253 | api_key=api_key, 254 | model_name=CODEX_MODEL_NAME, 255 | add_space=True 256 | ) 257 | 258 | 259 | class GPT3(OpenAIModel): 260 | 261 | def __init__(self, api_key): 262 | super().__init__(api_key=api_key, model_name=GPT3_MODEL_NAME) 263 | 264 | 265 | class Instruct(OpenAIModel): 266 | 267 | def __init__(self, api_key): 268 | super().__init__(api_key=api_key, model_name=INSTRUCT_MODEL_NAME) 269 | 270 | 271 | class Curie(OpenAIModel): 272 | 273 | def __init__(self, api_key): 274 | super().__init__(api_key=api_key, model_name=CURIE_MODEL_NAME) 275 | 276 | 277 | def get_model_by_name(name, api_key): 278 | try: 279 | return { 280 | "codex": Codex, 281 | "gpt3": GPT3, 282 | "instruct": Instruct, 283 | "curie": Curie, 284 | "jurassic": Jurassic 285 | }[name](api_key=api_key) 286 | except KeyError: 287 | return { 288 | "test": Test, 289 | "cp": CodeParrot, 290 | "gpt2": GPT2 291 | }[name]() 292 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /dataset_utils.py: -------------------------------------------------------------------------------- 1 | from constants import HF_CACHE_DIR_NAME, MMLU_NAMES, REPRODUCIBILITY_SEED 2 | from dataclasses import dataclass 3 | from datasets import load_dataset 4 | from itertools import groupby 5 | from prompts import Exemplar, QuestionPart, QuestionWithExemplars 6 | from typing import Callable 7 | from utils import idx_to_ltr, ltr_to_idx 8 | 9 | import random 10 | 11 | 12 | @dataclass 13 | class DatasetInfo: 14 | path: str 15 | exemplar_split: str 16 | eval_split: str 17 | extractor: Callable 18 | name: str = None 19 | data_dir: str = None 20 | 21 | 22 | def load_hf_dataset(path, name, data_dir, split): 23 | if split.endswith(".jsonl"): 24 | # This is for Social IQa test sets 25 | return load_dataset("json", data_files=split)["train"] 26 | else: 27 | return load_dataset( 28 | path=path, 29 | name=name, 30 | data_dir=data_dir, 31 | split=split, 32 | cache_dir=HF_CACHE_DIR_NAME 33 | ) 34 | 35 | 36 | def load_hf_dataset_no_verify(path, name, data_dir, split): 37 | return load_dataset( 38 | path=path, 39 | name=name, 40 | data_dir=data_dir, 41 | split=split, 42 | cache_dir=HF_CACHE_DIR_NAME, 43 | ignore_verifications=True 44 | ) 45 | 46 | 47 | def get_questions_with_exemplars( 48 | info, 49 | n_shots, 50 | do_strong_shuffle, 51 | load_fn=load_hf_dataset 52 | ): 53 | 54 | # If info is callable, the dataset should be loaded 55 | # using that custom function 56 | if callable(info): 57 | return info(n_shots=n_shots, do_strong_shuffle=do_strong_shuffle) 58 | 59 | # Create exemplars 60 | exemplar_ds = load_fn( 61 | path=info.path, 62 | name=info.name, 63 | data_dir=info.data_dir, 64 | split=info.exemplar_split 65 | ) 66 | exemplars = [Exemplar(**info.extractor(row)) for row in exemplar_ds] 67 | random.seed(REPRODUCIBILITY_SEED) 68 | if do_strong_shuffle: 69 | for exemplar in exemplars: 70 | exemplar.strong_shuffle() 71 | 72 | # Create questions with exemplars 73 | eval_ds = load_fn( 74 | path=info.path, 75 | name=info.name, 76 | data_dir=info.data_dir, 77 | split=info.eval_split 78 | ) 79 | 80 | random.seed(REPRODUCIBILITY_SEED) 81 | qwes = list() 82 | for row_idx, row in enumerate(eval_ds): 83 | 84 | # Choose some random exemplars - we are 
careful here 85 | # to avoid choosing an exemplar that is the same as 86 | # the question 87 | if info.exemplar_split == info.eval_split: 88 | possible_idxs = [i for i in range(len(exemplars)) if i != row_idx] 89 | else: 90 | possible_idxs = list(range(len(exemplars))) 91 | row_exemplars = [ 92 | exemplars[i] for i in random.sample(possible_idxs, n_shots) 93 | ] 94 | 95 | row_qwe = QuestionWithExemplars( 96 | **{**info.extractor(row), **{"exemplars": row_exemplars}} 97 | ) 98 | qwes.append(row_qwe) 99 | random.seed(REPRODUCIBILITY_SEED) 100 | if do_strong_shuffle: 101 | for qwe in qwes: 102 | qwe.strong_shuffle() 103 | 104 | return qwes 105 | 106 | 107 | def load_tiny_obqa(n_shots, do_strong_shuffle): 108 | qwes = get_questions_with_exemplars( 109 | info=get_dataset_info("obqa"), 110 | n_shots=n_shots, 111 | do_strong_shuffle=do_strong_shuffle 112 | ) 113 | random.seed(REPRODUCIBILITY_SEED) 114 | return random.sample(qwes, 100) 115 | 116 | 117 | def load_mini_rm(n_shots, do_strong_shuffle): 118 | qwes = get_questions_with_exemplars( 119 | info=get_dataset_info("rm"), 120 | n_shots=n_shots, 121 | do_strong_shuffle=do_strong_shuffle 122 | ) 123 | random.seed(REPRODUCIBILITY_SEED) 124 | return random.sample(qwes, 500) 125 | 126 | 127 | def load_mini_sc(n_shots, do_strong_shuffle): 128 | qwes = get_questions_with_exemplars( 129 | info=get_dataset_info("sc"), 130 | n_shots=n_shots, 131 | do_strong_shuffle=do_strong_shuffle 132 | ) 133 | random.seed(REPRODUCIBILITY_SEED) 134 | return random.sample(qwes, 500) 135 | 136 | 137 | def read_lqa(path, name, data_dir, split): 138 | 139 | with open(f"logiqa_{split}.txt", "r") as f: 140 | lines = f.readlines() 141 | 142 | grouper = groupby(lines, key=lambda x: x in {"\n"}) 143 | ds = dict(enumerate((list(j) for i, j in grouper if not i), 1)).values() 144 | 145 | formatted_ds = list() 146 | for row in ds: 147 | formatted_row = dict() 148 | formatted_row["answer_idx"] = ord(row[0][0]) - ord("a") 149 | formatted_row["question"] = f"{row[1].strip()} {row[2].strip()}" 150 | choices = list() 151 | for choice in row[3:]: 152 | choices.append(choice[2:].strip()) 153 | formatted_row["choices"] = choices 154 | formatted_ds.append(formatted_row) 155 | return formatted_ds 156 | 157 | 158 | def load_lqa(n_shots, do_strong_shuffle): 159 | return get_questions_with_exemplars( 160 | info=DatasetInfo( 161 | path=None, 162 | exemplar_split="train", 163 | eval_split="test", 164 | extractor=lambda row: { 165 | "parts": [ 166 | QuestionPart(text=row["question"], tag="Question") 167 | ], 168 | "choices": row["choices"], 169 | "answer_idx": row["answer_idx"] 170 | } 171 | ), 172 | n_shots=n_shots, 173 | do_strong_shuffle=do_strong_shuffle, 174 | load_fn=read_lqa 175 | ) 176 | 177 | 178 | def load_mmlu(n_shots, do_strong_shuffle): 179 | all_qwes = list() 180 | for name in MMLU_NAMES: 181 | name_qwes = get_questions_with_exemplars( 182 | info=DatasetInfo( 183 | path="hendrycks_test", 184 | name=name, 185 | exemplar_split="dev", 186 | eval_split="test", 187 | extractor=lambda row: { 188 | "parts": [ 189 | QuestionPart(row["question"], tag="Question") 190 | ], 191 | "choices": row["choices"], 192 | "answer_idx": ( 193 | row["answer"] 194 | if isinstance(row["answer"], int) 195 | else ltr_to_idx(row["answer"]) 196 | ), 197 | "task": name 198 | } 199 | ), 200 | n_shots=n_shots, 201 | do_strong_shuffle=do_strong_shuffle, 202 | load_fn=load_hf_dataset_no_verify 203 | ) 204 | all_qwes.extend(name_qwes) 205 | return all_qwes 206 | 207 | 208 | def rm_final_period(text): 209 | return 
text[:-1] if text.endswith(".") else text 210 | 211 | 212 | def get_anli_dataset_info(round): 213 | return DatasetInfo( 214 | path="anli", 215 | exemplar_split=f"train_r{round}", 216 | eval_split=f"test_r{round}", 217 | extractor=lambda row: { 218 | "parts": [ 219 | QuestionPart(text=row["premise"], tag="Premise"), 220 | QuestionPart(text=row["hypothesis"], tag="Hypothesis") 221 | ], 222 | "choices": [ 223 | "Hypothesis is definitely true given premise", 224 | "Hypothesis might be true given premise", 225 | "Hypothesis is definitely not true given premise" 226 | ], 227 | "answer_idx": row["label"] 228 | } 229 | ) 230 | 231 | 232 | def get_anli_shuffled_dataset_info(round): 233 | return DatasetInfo( 234 | path="anli", 235 | exemplar_split=f"train_r{round}", 236 | eval_split=f"test_r{round}", 237 | extractor=lambda row: { 238 | "parts": [ 239 | QuestionPart(text=row["premise"], tag="Premise"), 240 | QuestionPart(text=row["hypothesis"], tag="Hypothesis") 241 | ], 242 | "choices": [ 243 | "Hypothesis is definitely not true given premise", 244 | "Hypothesis is definitely true given premise", 245 | "Hypothesis might be true given premise" 246 | ], 247 | "answer_idx": {0: 1, 1: 2, 2: 0}[row["label"]] 248 | } 249 | ) 250 | 251 | 252 | def get_csqa_dataset_info(test): 253 | return DatasetInfo( 254 | path="commonsense_qa", 255 | exemplar_split="train", 256 | eval_split="test" if test else "validation", 257 | extractor=lambda row: { 258 | "parts": [ 259 | QuestionPart(text=row["question"], tag="Question") 260 | ], 261 | "choices": row["choices"]["text"], 262 | "answer_idx": None if row["answerKey"] == "" else ( 263 | row["choices"]["label"].index(row["answerKey"]) 264 | ) 265 | } 266 | ) 267 | 268 | 269 | def get_siqa_dataset_info(test): 270 | return DatasetInfo( 271 | path="social_i_qa", 272 | exemplar_split="train", 273 | eval_split="socialiqa.jsonl" if test else "validation", 274 | extractor=lambda row: { 275 | "parts": [ 276 | QuestionPart( 277 | text=f"{row['context']} {row['question']}", 278 | tag="Question" 279 | ) 280 | ], 281 | "choices": [row[f"answer{idx_to_ltr(i)}"] for i in range(3)], 282 | "answer_idx": ( 283 | int(row["label"]) - 1 if "label" in row.keys() else None 284 | ) 285 | } 286 | ) 287 | 288 | 289 | def get_copa_dataset_info(test): 290 | return DatasetInfo( 291 | path="super_glue", 292 | name="copa", 293 | exemplar_split="train", 294 | eval_split="test" if test else "validation", 295 | extractor=lambda row: { 296 | "parts": [ 297 | QuestionPart( 298 | text=rm_final_period(row["premise"]) + ( 299 | " because" if row["question"] == "cause" else " so" 300 | ), 301 | tag="Question" 302 | ) 303 | ], 304 | "choices": [row[f"choice{i+1}"] for i in range(2)], 305 | "answer_idx": row["label"] 306 | } 307 | ) 308 | 309 | 310 | def get_piqa_dataset_info(test): 311 | return DatasetInfo( 312 | path="piqa", 313 | exemplar_split="train", 314 | eval_split="test" if test else "validation", 315 | extractor=lambda row: { 316 | "parts": [ 317 | QuestionPart(row["goal"], tag="Question") 318 | ], 319 | "choices": [row[f"sol{i+1}"] for i in range(2)], 320 | "answer_idx": row["label"] 321 | } 322 | ) 323 | 324 | 325 | def get_cqa_dataset_info(test): 326 | return DatasetInfo( 327 | path="cosmos_qa", 328 | exemplar_split="train", 329 | eval_split="test" if test else "validation", 330 | extractor=lambda row: { 331 | "parts": [ 332 | QuestionPart( 333 | text=row["context"], 334 | tag="Passage" 335 | ), 336 | QuestionPart( 337 | text=row["question"], 338 | tag="Question" 339 | ) 340 | ], 341 | "choices": 
[row[f"answer{i}"] for i in range(4)], 342 | "answer_idx": row["label"] 343 | } 344 | ) 345 | 346 | 347 | def get_figqa_dataset_info(test): 348 | return DatasetInfo( 349 | path="nightingal3/fig-qa", 350 | exemplar_split="train", 351 | eval_split="test" if test else "validation", 352 | extractor=lambda row: { 353 | "parts": [ 354 | QuestionPart( 355 | text=f"{rm_final_period(row['startphrase'])}, meaning", 356 | tag="Question" 357 | ) 358 | ], 359 | "choices": [row[f"ending{i+1}"] for i in range(2)], 360 | "answer_idx": row["labels"] 361 | } 362 | ) 363 | 364 | 365 | def get_hs_dataset_info(test): 366 | 367 | return DatasetInfo( 368 | path="hellaswag", 369 | exemplar_split="train", 370 | eval_split="test" if test else "validation", 371 | extractor=lambda row: { 372 | "parts": [ 373 | QuestionPart( 374 | ( 375 | f"({row['activity_label']}) " if 376 | row["source_id"].startswith("activity") 377 | else "" 378 | ) + row["ctx_a"], 379 | tag="Passage" 380 | ), 381 | QuestionPart( 382 | "Which choice best continues the passage?", 383 | tag="Question" 384 | ) 385 | ], 386 | "choices": [ 387 | f"{row['ctx_b']}{' ' if len(row['ctx_b']) else ''}{e}" 388 | for e in row["endings"] 389 | ], 390 | "answer_idx": int(row["label"]) if len(row["label"]) else None 391 | } 392 | ) 393 | 394 | 395 | def get_medmcqa_dataset_info(test): 396 | return DatasetInfo( 397 | path="medmcqa", 398 | exemplar_split="train", 399 | eval_split="test" if test else "validation", 400 | extractor=lambda row: { 401 | "parts": [ 402 | QuestionPart(row["question"], tag="Question") 403 | ], 404 | "choices": [row[f"op{chr(i+ord('a'))}"] for i in range(4)], 405 | "answer_idx": row["cop"] 406 | } 407 | ) 408 | 409 | 410 | def get_rs_dataset_info(test): 411 | return DatasetInfo( 412 | path="riddle_sense", 413 | exemplar_split="train", 414 | eval_split="test" if test else "validation", 415 | extractor=lambda row: { 416 | "parts": [ 417 | QuestionPart( 418 | text=row["question"], 419 | tag="Question" 420 | ) 421 | ], 422 | "choices": row["choices"]["text"], 423 | "answer_idx": ( 424 | None if row["answerKey"] == "" 425 | else row["choices"]["label"].index(row["answerKey"]) 426 | ) 427 | } 428 | ) 429 | 430 | 431 | def get_winogrande_dataset_info(test, xs): 432 | return DatasetInfo( 433 | path="winogrande", 434 | name="winogrande_xs" if xs else "winogrande_xl", 435 | exemplar_split="train", 436 | eval_split="test" if test else "validation", 437 | extractor=lambda row: { 438 | "parts": [ 439 | QuestionPart(row["sentence"], tag="Question") 440 | ], 441 | "choices": [row[f"option{i+1}"] for i in range(2)], 442 | "answer_idx": ( 443 | None if row["answer"] == "" else 444 | int(row["answer"]) - 1 445 | ) 446 | } 447 | ) 448 | 449 | 450 | def do_caps_corrupt(s): 451 | new_s = "" 452 | for c in s: 453 | # If the character is a letter, flip a coin to decide whether to 454 | # capitalize it. 
455 | if c.isalpha(): 456 | new_s += c.upper() if random.random() < 0.5 else c.lower() 457 | else: 458 | new_s += c 459 | return new_s 460 | 461 | 462 | def do_space_corrupt(s): 463 | words = s.split() 464 | new_words = [] 465 | for w in words: 466 | if len(w) > 2: 467 | # Add a space at a random position 468 | pos = random.randint(0, len(w)) 469 | new_words.append(w[:pos] + " " + w[pos:]) 470 | else: 471 | new_words.append(w) 472 | return " ".join(new_words) 473 | 474 | 475 | def get_corrupt_fn_by_name(name): 476 | if name == "caps": 477 | return do_caps_corrupt 478 | elif name == "space": 479 | return do_space_corrupt 480 | else: 481 | raise ValueError(f"Unknown corruption type {name}") 482 | 483 | 484 | def get_obqa_corrupt_dataset_info(corruption_type): 485 | random.seed(REPRODUCIBILITY_SEED) 486 | corrupt_fn = get_corrupt_fn_by_name(name=corruption_type) 487 | return DatasetInfo( 488 | path="openbookqa", 489 | name="main", 490 | exemplar_split="train", 491 | eval_split="test", 492 | extractor=lambda row: { 493 | "parts": [ 494 | QuestionPart(text=row["question_stem"], tag="Question") 495 | ], 496 | "choices": [corrupt_fn(s) for s in row["choices"]["text"]], 497 | "answer_idx": row["choices"]["label"].index(row["answerKey"]) 498 | } 499 | ) 500 | 501 | 502 | def load_mini_rm_caps_corrupt(n_shots, do_strong_shuffle): 503 | random.seed(REPRODUCIBILITY_SEED) 504 | corrupt_fn = get_corrupt_fn_by_name(name="caps") 505 | info = DatasetInfo( 506 | path="race", 507 | name="middle", 508 | exemplar_split="train", 509 | eval_split="test", 510 | extractor=lambda row: { 511 | "parts": [ 512 | QuestionPart(text=row["article"], tag="Passage"), 513 | QuestionPart(text=row["question"], tag="Question") 514 | ], 515 | "choices": [corrupt_fn(c) for c in row["options"]], 516 | "answer_idx": ltr_to_idx(row["answer"]) 517 | } 518 | ) 519 | qwes = get_questions_with_exemplars( 520 | info=info, 521 | n_shots=n_shots, 522 | do_strong_shuffle=do_strong_shuffle 523 | ) 524 | random.seed(REPRODUCIBILITY_SEED) 525 | return random.sample(qwes, 500) 526 | 527 | 528 | def load_mini_rm_space_corrupt(n_shots, do_strong_shuffle): 529 | random.seed(REPRODUCIBILITY_SEED) 530 | corrupt_fn = get_corrupt_fn_by_name(name="space") 531 | info = DatasetInfo( 532 | path="race", 533 | name="middle", 534 | exemplar_split="train", 535 | eval_split="test", 536 | extractor=lambda row: { 537 | "parts": [ 538 | QuestionPart(text=row["article"], tag="Passage"), 539 | QuestionPart(text=row["question"], tag="Question") 540 | ], 541 | "choices": [corrupt_fn(c) for c in row["options"]], 542 | "answer_idx": ltr_to_idx(row["answer"]) 543 | } 544 | ) 545 | qwes = get_questions_with_exemplars( 546 | info=info, 547 | n_shots=n_shots, 548 | do_strong_shuffle=do_strong_shuffle 549 | ) 550 | random.seed(REPRODUCIBILITY_SEED) 551 | return random.sample(qwes, 500) 552 | 553 | 554 | def load_mini_sc_caps_corrupt(n_shots, do_strong_shuffle): 555 | random.seed(REPRODUCIBILITY_SEED) 556 | corrupt_fn = get_corrupt_fn_by_name(name="caps") 557 | info = DatasetInfo( 558 | path="story_cloze", 559 | name="2016", 560 | data_dir="sc_data", 561 | exemplar_split="validation", 562 | eval_split="test", 563 | extractor=lambda row: { 564 | "parts": [ 565 | QuestionPart( 566 | text=" ".join( 567 | [row[f"input_sentence_{i+1}"] for i in range(4)] 568 | ), 569 | tag="Story" 570 | ), 571 | QuestionPart( 572 | text=( 573 | "Which sentence best completes the story?" 


def load_mini_sc_caps_corrupt(n_shots, do_strong_shuffle):
    random.seed(REPRODUCIBILITY_SEED)
    corrupt_fn = get_corrupt_fn_by_name(name="caps")
    info = DatasetInfo(
        path="story_cloze",
        name="2016",
        data_dir="sc_data",
        exemplar_split="validation",
        eval_split="test",
        extractor=lambda row: {
            "parts": [
                QuestionPart(
                    text=" ".join(
                        [row[f"input_sentence_{i+1}"] for i in range(4)]
                    ),
                    tag="Story"
                ),
                QuestionPart(
                    text="Which sentence best completes the story?",
                    tag="Question"
                )
            ],
            "choices": [
                corrupt_fn(row[f"sentence_quiz{i+1}"]) for i in range(2)
            ],
            "answer_idx": row["answer_right_ending"] - 1
        }
    )
    qwes = get_questions_with_exemplars(
        info=info,
        n_shots=n_shots,
        do_strong_shuffle=do_strong_shuffle
    )
    random.seed(REPRODUCIBILITY_SEED)
    return random.sample(qwes, 500)


def load_mini_sc_space_corrupt(n_shots, do_strong_shuffle):
    random.seed(REPRODUCIBILITY_SEED)
    corrupt_fn = get_corrupt_fn_by_name(name="space")
    info = DatasetInfo(
        path="story_cloze",
        name="2016",
        data_dir="sc_data",
        exemplar_split="validation",
        eval_split="test",
        extractor=lambda row: {
            "parts": [
                QuestionPart(
                    text=" ".join(
                        [row[f"input_sentence_{i+1}"] for i in range(4)]
                    ),
                    tag="Story"
                ),
                QuestionPart(
                    text="Which sentence best completes the story?",
                    tag="Question"
                )
            ],
            "choices": [
                corrupt_fn(row[f"sentence_quiz{i+1}"]) for i in range(2)
            ],
            "answer_idx": row["answer_right_ending"] - 1
        }
    )
    qwes = get_questions_with_exemplars(
        info=info,
        n_shots=n_shots,
        do_strong_shuffle=do_strong_shuffle
    )
    random.seed(REPRODUCIBILITY_SEED)
    return random.sample(qwes, 500)


def get_mini_rm_corrupt_dataset_info(corruption_type):
    random.seed(REPRODUCIBILITY_SEED)
    corrupt_fn = get_corrupt_fn_by_name(name=corruption_type)
    return DatasetInfo(
        path="race",
        name="middle",
        exemplar_split="train",
        eval_split="test",
        extractor=lambda row: {
            "parts": [
                QuestionPart(text=row["article"], tag="Passage"),
                QuestionPart(text=row["question"], tag="Question")
            ],
            "choices": [corrupt_fn(c) for c in row["options"]],
            "answer_idx": ltr_to_idx(row["answer"])
        }
    )
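

# get_dataset_info maps a short dataset code to either a DatasetInfo (most
# entries) or a loader callable (e.g. mmlu and the mini_* entries), so
# callers need to handle both cases; a "_ts" suffix selects the test split
# instead of validation. Usage sketch (argument values are illustrative):
#
#     info = get_dataset_info("obqa")
#     qwes = get_questions_with_exemplars(
#         info=info, n_shots=4, do_strong_shuffle=False
#     )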
tag="Question") 703 | ], 704 | "choices": row["options"], 705 | "answer_idx": ltr_to_idx(row["answer"]) 706 | } 707 | ), 708 | "rm": DatasetInfo( 709 | path="race", 710 | name="middle", 711 | exemplar_split="train", 712 | eval_split="test", 713 | extractor=lambda row: { 714 | "parts": [ 715 | QuestionPart(text=row["article"], tag="Passage"), 716 | QuestionPart(text=row["question"], tag="Question") 717 | ], 718 | "choices": row["options"], 719 | "answer_idx": ltr_to_idx(row["answer"]) 720 | } 721 | ), 722 | "siqa": get_siqa_dataset_info(test=False), 723 | "siqa_ts": get_siqa_dataset_info(test=True), 724 | "copa": get_copa_dataset_info(test=False), 725 | "copa_ts": get_copa_dataset_info(test=True), 726 | "figqa": get_figqa_dataset_info(test=False), 727 | "figqa_ts": get_figqa_dataset_info(test=True), 728 | "rs": get_rs_dataset_info(test=False), 729 | "rs_ts": get_rs_dataset_info(test=True), 730 | "agn": DatasetInfo( 731 | path="ag_news", 732 | exemplar_split="train", 733 | eval_split="test", 734 | extractor=lambda row: { 735 | "parts": [ 736 | QuestionPart(row["text"], tag="Article"), 737 | QuestionPart( 738 | text=( 739 | "What is the best classification " 740 | "for this article?" 741 | ), 742 | tag="Question" 743 | ) 744 | ], 745 | "choices": ["World", "Sports", "Business", "Sci/Tech"], 746 | "answer_idx": row["label"] 747 | } 748 | ), 749 | "sc": DatasetInfo( 750 | path="story_cloze", 751 | name="2016", 752 | data_dir="sc_data", 753 | exemplar_split="validation", 754 | eval_split="test", 755 | extractor=lambda row: { 756 | "parts": [ 757 | QuestionPart( 758 | text=" ".join( 759 | [row[f"input_sentence_{i+1}"] for i in range(4)] 760 | ), 761 | tag="Story" 762 | ), 763 | QuestionPart( 764 | text=( 765 | "Which sentence best completes the story?" 766 | ), 767 | tag="Question" 768 | ) 769 | ], 770 | "choices": [row[f"sentence_quiz{i+1}"] for i in range(2)], 771 | "answer_idx": row["answer_right_ending"] - 1 772 | } 773 | ), 774 | "a1": get_anli_dataset_info(1), 775 | "a2": get_anli_dataset_info(2), 776 | "a3": get_anli_dataset_info(3), 777 | "medmcqa": get_medmcqa_dataset_info(test=False), 778 | "medmcqa_ts": get_medmcqa_dataset_info(test=True), 779 | "dream": DatasetInfo( 780 | path="dream", 781 | exemplar_split="train", 782 | eval_split="test", 783 | extractor=lambda row: { 784 | "parts": [ 785 | QuestionPart(" ".join(row["dialogue"]), tag="Dialogue"), 786 | QuestionPart(row["question"], tag="Question") 787 | ], 788 | "choices": row["choice"], 789 | "answer_idx": row["choice"].index(row["answer"]) 790 | } 791 | ), 792 | "codah": DatasetInfo( 793 | path="codah", 794 | name="codah", 795 | exemplar_split="train", 796 | eval_split="train", 797 | extractor=lambda row: { 798 | "parts": [ 799 | QuestionPart(row["question_propmt"], tag="Question") 800 | ], 801 | "choices": row["candidate_answers"], 802 | "answer_idx": row["correct_answer_idx"] 803 | } 804 | ), 805 | "piqa": get_piqa_dataset_info(test=False), 806 | "piqa_ts": get_piqa_dataset_info(test=True), 807 | "w": get_winogrande_dataset_info(test=False, xs=False), 808 | "w_ts": get_winogrande_dataset_info(test=True, xs=False), 809 | "cqa": get_cqa_dataset_info(test=False), 810 | "cqa_ts": get_cqa_dataset_info(test=True), 811 | "mmlu": load_mmlu, 812 | "wxs": get_winogrande_dataset_info(test=False, xs=True), 813 | "wxs_ts": get_winogrande_dataset_info(test=True, xs=True), 814 | "tiny_obqa": load_tiny_obqa, 815 | "lqa": load_lqa, 816 | "hs": get_hs_dataset_info(test=False), 817 | "hs_ts": get_hs_dataset_info(test=True), 818 | "mini_rm": 
        "mini_rm": load_mini_rm,
        "mini_sc": load_mini_sc,
        "a1s": get_anli_shuffled_dataset_info(1),
        "a2s": get_anli_shuffled_dataset_info(2),
        "a3s": get_anli_shuffled_dataset_info(3),
        "obqa_caps_corrupt": get_obqa_corrupt_dataset_info(
            corruption_type="caps"
        ),
        "obqa_space_corrupt": get_obqa_corrupt_dataset_info(
            corruption_type="space"
        ),
        "mini_sc_caps_corrupt": load_mini_sc_caps_corrupt,
        "mini_sc_space_corrupt": load_mini_sc_space_corrupt,
        "mini_rm_caps_corrupt": load_mini_rm_caps_corrupt,
        "mini_rm_space_corrupt": load_mini_rm_space_corrupt
    }[ds_name]
--------------------------------------------------------------------------------