├── .gitignore
├── LICENSE
├── README.md
├── acl2023.jpg
├── custom_test.py
├── custom_train.py
├── data
│   ├── .gitignore
│   ├── few_shot_cot_prompts.json
│   └── split
│       ├── addsub__default.json
│       ├── aqua__default.json
│       ├── coin_flip__default.json
│       ├── commonsense_qa__default.json
│       ├── date_understanding__default.json
│       ├── date_understanding__template.json
│       ├── gsm8k__default.json
│       ├── last_letter_concatenation__default.json
│       ├── multiarith__default.json
│       ├── multiarith__template.json
│       ├── single_eq__default.json
│       ├── strategy_qa__default.json
│       ├── svamp__default.json
│       └── tracking_shuffled_objects__default.json
├── notebooks
│   ├── .gitignore
│   ├── example_load_results.ipynb
│   ├── example_oai_finetune_cot.ipynb
│   ├── old
│   │   ├── All Experiments.ipynb
│   │   ├── Data Preparation.ipynb
│   │   ├── Manual Inspection.ipynb
│   │   └── README.txt
│   └── results.ipynb
├── requirements.txt
├── scripts
│   └── custom
│       ├── example_ft5.sh
│       ├── example_gpt2.sh
│       └── example_t5.sh
├── setup.py
├── setup.sh
└── src
    ├── __init__.py
    ├── custom
    │   ├── __init__.py
    │   ├── data_module.py
    │   ├── model.py
    │   └── utils.py
    ├── data
    │   ├── __init__.py
    │   ├── completion_dataset.py
    │   ├── dataset.py
    │   ├── few_shot_cot_prompt.py
    │   ├── format.py
    │   ├── generate_split.py
    │   └── split.py
    ├── evaluation
    │   ├── __init__.py
    │   ├── evaluator.py
    │   └── summary.py
    ├── oai
    │   ├── __init__.py
    │   ├── finetune.py
    │   ├── inference.py
    │   └── utils
    │       ├── __init__.py
    │       ├── api_wrapper.py
    │       ├── fetch_model_ids.py
    │       ├── metadata.py
    │       └── tokens.py
    └── paths.py

/.gitignore:
--------------------------------------------------------------------------------
1 | #Saved files
2 | saved/
3 |
4 | # Misc
5 | .DS_Store/
6 | .idea/
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | pip-wheel-metadata/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Namgyu Ho
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Large Language Models Are Reasoning Teachers
2 |
3 |
4 | Official repository for [Large Language Models Are Reasoning Teachers](https://arxiv.org/abs/2212.10071), by
5 | Namgyu Ho, Laura Schmid, and Se-young Yun.
6 |
7 | **🚀 Accepted to ACL 2023.**
8 |
9 | This repository contains code for (1) running CoT reasoning on OpenAI models,
10 | and (2) applying Fine-tune-CoT to train students based on OpenAI models *or* custom open-source models such as T5, Flan-T5, and GPT-2 on your own GPUs, using 🤗 Transformers and PyTorch Lightning.
11 |
12 |
13 | ## Getting Started
14 |
15 | ### OpenAI API Experiments
16 |
17 | OpenAI API experiments are implemented in the `oai` module. Refer to `notebooks/example_oai_finetune_cot.ipynb`
18 | for how to run Fine-tune-CoT from start to finish.
19 |
20 | ### Updates to OpenAI API
21 |
22 | Use the [Batch API](https://platform.openai.com/docs/guides/batch) to save costs on reasoning data collection, or use [parallel API requests](https://github.com/openai/openai-cookbook/blob/main/examples/api_request_parallel_processor.py) to accelerate immediate, large-scale data collection.
23 |
24 | ### Custom Experiments (on GPU)
25 |
26 | Custom experiments are implemented in the `custom` module, based on PyTorch Lightning. Refer to `custom_train.py`
27 | and `scripts/custom/*.sh` for how to fine-tune models such as T5, Flan-T5, and GPT-2 using Fine-tune-CoT.
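For example, a training run might look like the following — a minimal sketch, where the `--train_key` value is a hypothetical run name and the remaining flags mirror the argparse defaults in `custom_train.py`:

```
mkdir -p external_lightning_logs  # custom_train.py expects this log directory to exist
python custom_train.py \
    --dataset_key multiarith \
    --model_key t5_base \
    --preset_key ft_cot \
    --train_key my_ft_cot_run \
    --devices 0 1 \
    --batch_size 8 \
    --precision 16
```

Note that `--precision 16` is remapped to `bf16` inside the script, and training runs for a fixed 20 epochs (`min_epochs=20, max_epochs=20`).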
28 |
29 | ## Setup
30 |
31 | ```
32 | pip install -r requirements.txt
33 | python setup.py develop
34 | ```
35 |
36 | ### Environment
37 |
38 | The code has been tested on Python <= 3.10, PyTorch Lightning <= 1.9, and PyTorch >= 2.0.
39 |
40 | ## Data 🚀
41 |
42 | We're proud to share *all* of our raw experimental data! All data is organized in json or jsonl format, for your pleasure :)
43 |
44 | Cloud storage folder links:
45 |
46 | - [Dropbox](https://www.dropbox.com/sh/hwcncpyomx87h20/AACqgVdd-ZzBQ3ncJcKqw0cVa?dl=0)
47 | - [Google Drive](https://drive.google.com/drive/folders/1C6kah3WV36N8omlUl-TeU9tsJADZNaJV?usp=share_link)
48 |
49 | ### File List
50 |
51 | - `dataset.tar.gz`: 12 task datasets compiled in a unified json format
52 |   - Belongs in `PROJECT/data/dataset/`
53 | - `completion_data.tar.gz`: Completion data, i.e., inference data, from all teachers and students, for *all* experiments. About 8GB when uncompressed.
54 |   - Belongs in `PROJECT/saved/completion_data/`
55 | - `teacher_completion_data.tar.gz`: Completion data from Zero-shot-CoT (with diverse reasoning) on the default teacher model `text-davinci-002` using the OpenAI API. About 💰 $1000+ worth of goods, with ❤️ from [OSI LAB](http://osi.kaist.ac.kr) at [KAIST](https://kaist.ac.kr). Subset of `completion_data.tar.gz`.
56 |   - Belongs in `PROJECT/saved/completion_data/`
57 | - `finetune_data.tar.gz`: *All* data used to fine-tune OpenAI students via the fine-tuning API, in jsonl format. These are derived from teacher completion data and can be generated from our code.
58 |   - Belongs in `PROJECT/saved/finetune_data/`
59 |
60 | ### Generate Paper Results
61 |
62 | After downloading the full `completion_data.tar.gz`, you can run `notebooks/results.ipynb` to generate *all* result tables and figures from our paper. The code will (re-)evaluate all raw text model outputs contained in the completion data.
63 |
64 |
65 |
66 | ## Additional Resources
67 |
68 | ### Template-based Split (Paper Appendix E.3)
69 |
70 | Template-based splits for MultiArith and Date Understanding are saved in `/data/split/*__template.json`.
71 |
72 | ### Few-shot Prompts
73 |
74 | Few-shot prompts adapted from Wei et al. (2022) are saved in `/data/few_shot_cot_prompts.json`.
75 |
76 |
77 |
78 | ## Data Structures
79 |
80 | ### `data.dataset.Dataset`
81 |
82 | ```json
83 | {
84 |   "metadata": {
85 |     "dataset_key": "multiarith"
86 |   },
87 |   "data": [
88 |     {
89 |       "sample_index": 0,
90 |       "question": "string",
91 |       "answer": "string",
92 |       "rationale": "string?"
93 |     }
94 |   ]
95 | }
96 | ```
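As a minimal sketch of how a dataset file pairs with its split file (the `data/dataset/multiarith.json` path is an assumption — the actual dataset filenames ship inside `dataset.tar.gz`):

```python
import json

# Hypothetical paths: the dataset file comes from dataset.tar.gz (assumed name),
# while the split file is checked into data/split/ as shown in the tree above.
with open("data/dataset/multiarith.json") as f:
    dataset = json.load(f)
with open("data/split/multiarith__default.json") as f:
    split = json.load(f)

# Index samples by sample_index, then select the default train/test subsets.
samples = {sample["sample_index"]: sample for sample in dataset["data"]}
train = [samples[i] for i in split["train"]]
test = [samples[i] for i in split["test"]]

# "rationale" is marked "string?" above, i.e., presumably optional — use .get().
print(dataset["metadata"]["dataset_key"], len(train), len(test))
print(train[0]["question"], train[0].get("rationale"))
```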
97 |
98 | ### `data.completion.CompletionDataset`
99 |
100 | ```json
101 | {
102 |   "metadata": {
103 |     "dataset_key": "multiarith",
104 |     "base_model": "curie",
105 |     "finetune_key": "zs_cot_multiarith",
106 |     "train_key": "ft_cot",
107 |     "prediction_template": "ft_cot_token"
108 |   },
109 |   "data": {
110 |     "<sample_index>": [
111 |       {
112 |         "sample_index": 0,
113 |         "completion_index": 0,
114 |         "question": "string",
115 |         "answer": "string",
116 |         "prompt": "string",
117 |         "completion": "string",
118 |         "finish_reason": "string",
119 |         "reasoning_prompt": "string?",
120 |         "reasoning_completion": "string?",
121 |         "reasoning_finish_reason": "string?"
122 |       }
123 |     ]
124 |   }
125 | }
126 | ```
127 |
128 |
129 |
130 | ## Data Organization
131 |
132 | *Needs update.*
133 |
134 | - `<model_key>` = `B_<base_model>__T_<train_key>`
135 |
136 | ### File Organization Pattern
137 |
138 | ```
139 | saved/
140 | |–– completion_data/
141 |     |–– B_<base_model>__C_<completion_key>/
142 |         |-- D_<dataset_key>.json                                   # base model inference
143 |         |-- F_<finetune_key>__D_<dataset_key>.json                 # default fine-tuned model inference
144 |         |-- F_<finetune_key>__T_<train_key>__D_<dataset_key>.json  # custom fine-tuned model inference
145 | |–– finetune_data/
146 |     |–– P_<platform_key>/
147 |         |–– F_<finetune_key>{.*|/}
148 | |–– model_metadata/
149 |     |–– B_<base_model>
150 |         |–– F_<finetune_key>__T_<train_key>.json
151 | ```
152 |
153 | ### File Organization Examples
154 |
155 | ```
156 | saved/
157 | |–– completion_data/
158 |     |–– B_text-davinci-002__C_zs_cot/
159 |     |–– B_text-davinci-002__C_zs_cot_long/
160 |     |–– B_text-davinci-002__C_fs_cot/
161 |     |–– B_curie__C_zs_cot/
162 |     |–– B_curie__C_fs_cot/
163 |     |–– B_curie__C_zs/
164 |     |–– B_curie__C_ft_cot/
165 | |–– finetune_data/
166 |     |–– F_zs_cot_multiarith/  # text-davinci-002_zs_cot
167 |     |–– F_zs_cot_long_multiarith/
168 | |–– model_metadata/
169 |     |–– B_curie/
170 |         |–– F_zs_cot_multiarith.json
171 | ```
172 |
173 |
174 | ### Personal Note
175 |
176 | ![accepted](acl2023.jpg)
177 |
178 |

--------------------------------------------------------------------------------
/acl2023.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/itsnamgyu/reasoning-teacher/a2ca4d28d3bbabbd77106b76d06885d1a5eac0d9/acl2023.jpg

--------------------------------------------------------------------------------
/custom_test.py:
--------------------------------------------------------------------------------
1 | """
2 | Run custom inference experiments, i.e., prompting open models such as T5 and GPT-2 on GPUs.
3 | Currently only supports zero-shot prompting and few-shot CoT prompting.
4 |
5 | Note, to check distributed errors, use `TORCH_DISTRIBUTED_DEBUG=DETAIL`
6 | Note, if deepspeed hangs at initialization, use `NCCL_P2P_DISABLE=1`. Though, this seems to slow down training a lot...
7 | Note, to see more NCCL errors, use NCCL_DEBUG=WARN
8 | """
9 | import argparse
10 | import logging
11 | import os
12 |
13 | from custom.data_module import DataModule
14 | from data.completion_dataset import CompletionMetadata
15 |
16 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
17 |
18 | import pytorch_lightning as pl
19 | import torch
20 | from transformers import T5TokenizerFast, T5ForConditionalGeneration
21 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
22 |
23 | from custom.model import Model
24 |
25 | logging.basicConfig(level=logging.INFO)
26 |
27 | torch.set_float32_matmul_precision("high")
28 |
29 | if __name__ == "__main__":
30 |     parser = argparse.ArgumentParser()
31 |     parser.add_argument("--dataset_key", type=str, default="multiarith")
32 |     parser.add_argument("--model_key", type=str, default="t5_base")
33 |     parser.add_argument("--batch_size", type=int, default=64)
34 |     parser.add_argument("--preset_key", type=str, default="zs")
35 |     parser.add_argument("--devices", type=int, nargs="+", default=[0])
36 |     parser.add_argument("--precision", type=int, default=32)
37 |     args = parser.parse_args()
38 |     print("arguments".upper().center(80, "-"))
39 |     print(args)
40 |     print("-" * 80)
41 |
42 |     if args.precision == 16:
43 |         args.precision = "bf16"
44 |         print("Setting precision to bf16")
45 |
46 |     dataset_key = args.dataset_key
47 |     model_key = args.model_key
48 |
49 |     if "flan" in model_key:
50 |         hf_key = "google/{}".format(model_key.replace("_", "-"))
51 |         model = AutoModelForSeq2SeqLM.from_pretrained(hf_key)
52 |         tokenizer = AutoTokenizer.from_pretrained(hf_key, model_max_length=512)
53 |         model_type = "encoder_decoder"
54 |         append_eos = False  # t5 tokenizers already append eos
55 |     elif "t5" in model_key:
56 |         hf_key = model_key.replace("_", "-")
57 |         model = T5ForConditionalGeneration.from_pretrained(hf_key)
58 |         tokenizer = T5TokenizerFast.from_pretrained(hf_key, model_max_length=512)
59 |         model_type = "encoder_decoder"
60 |         append_eos = False
61 |     elif "gpt2" in model_key:
62 |         from transformers import GPT2Tokenizer, GPT2LMHeadModel
63 |
64 |         hf_key = model_key.replace("_", "-")
65 |         tokenizer = GPT2Tokenizer.from_pretrained(hf_key)
66 |         model = GPT2LMHeadModel.from_pretrained(hf_key)
67 |         model_type = "decoder"
68 |         append_eos = True
69 |     else:
70 |         raise NotImplementedError(model_key)
71 |
72 |     if args.preset_key == "zs":
73 |         completion_key = "zs"
74 |     elif args.preset_key == "zs_cot":
75 |         completion_key = "zs_cot"
76 |     elif args.preset_key == "fs_cot":
77 |         completion_key = "fs_cot"
78 |     else:
79 |         raise NotImplementedError(args.preset_key)
80 |
81 |     if tokenizer.pad_token is None:
82 |         tokenizer.pad_token = tokenizer.eos_token
83 |
84 |     batch_size = args.batch_size
85 |     data_module = DataModule(dataset_key, args.preset_key, tokenizer, model_type,
86 |                              inference_batch_size=batch_size, num_workers=8, append_eos=append_eos)
87 |
88 |     cm = CompletionMetadata(model_key, completion_key, dataset_key, prediction_template=data_module.prediction_template)
89 |     lm = Model(model, tokenizer, model_type, completion_metadata=cm, truncate_early=False)
90 |
91 |     if not os.path.exists("external_lightning_logs"):
92 |         raise Exception("external_lightning_logs/ does not exist")
93 |     default_root_dir = os.path.join("external_lightning_logs", "{}_{}".format(model_key, dataset_key))
94 |     trainer = pl.Trainer(accelerator="gpu", devices=args.devices, default_root_dir=default_root_dir, precision=args.precision)
95 |
96 |     trainer.validate(lm, datamodule=data_module)
97 |
--------------------------------------------------------------------------------
/custom_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Run custom fine-tuning-based experiments, i.e., fine-tuning models such as T5 and GPT-2 on GPUs.
3 |
4 | Note, to check distributed errors, use `TORCH_DISTRIBUTED_DEBUG=DETAIL`
5 | Note, if deepspeed hangs at initialization, use `NCCL_P2P_DISABLE=1`. Though, this seems to slow down training a lot...
6 | Note, to see more NCCL errors, use NCCL_DEBUG=WARN
7 | """
8 | import argparse
9 | import logging
10 | import os
11 |
12 | from custom.data_module import DataModule
13 | from data.completion_dataset import CompletionMetadata
14 |
15 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
16 |
17 | import pytorch_lightning as pl
18 | import torch
19 | from transformers import T5TokenizerFast, T5ForConditionalGeneration
20 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
21 |
22 | from custom.model import Model
23 |
24 | logging.basicConfig(level=logging.INFO)
25 |
26 | torch.set_float32_matmul_precision("high")
27 |
28 | if __name__ == "__main__":
29 |     parser = argparse.ArgumentParser()
30 |     parser.add_argument("--dataset_key", type=str, default="multiarith")
31 |     parser.add_argument("--model_key", type=str, default="t5_base")
32 |     parser.add_argument("--train_key", type=str, required=True)
33 |     parser.add_argument("--batch_size", type=int, default=8)
34 |     parser.add_argument("--preset_key", type=str, default="ft_cot")
35 |     parser.add_argument("--inference_batch_size", type=int, default=None)
36 |     parser.add_argument("--devices", type=int, nargs="+", default=[0, 1])
37 |     parser.add_argument("--accumulate", type=int, default=1)
38 |     parser.add_argument("--strategy", type=str, default=None)
39 |     parser.add_argument("--precision", type=int, default=32)
40 |     parser.add_argument("--lr", type=float, default=3e-4)
41 |     parser.add_argument("--disable_checkpointing", action="store_true")
42 |     args = parser.parse_args()
43 |     args.enable_checkpointing = not args.disable_checkpointing
44 |     print("arguments".upper().center(80, "-"))
45 |     print(args)
46 |     print("-" * 80)
47 |
48 |     if args.precision == 16:
49 |         args.precision = "bf16"
50 |         print("Setting precision to bf16")
51 |
52 |     dataset_key = args.dataset_key
53 |     model_key = args.model_key
54 |     train_key = args.train_key
55 |
56 |     if "flan" in model_key:
57 |         hf_key = "google/{}".format(model_key.replace("_", "-"))
58 |         model = AutoModelForSeq2SeqLM.from_pretrained(hf_key)
59 |         tokenizer = AutoTokenizer.from_pretrained(hf_key, model_max_length=512)
60 |         model_type = "encoder_decoder"
61 |         append_eos = False  # t5 tokenizers already append eos
62 |     elif "t5" in model_key:
63 |         hf_key = model_key.replace("_", "-")
64 |         model = T5ForConditionalGeneration.from_pretrained(hf_key)
65 |         tokenizer = T5TokenizerFast.from_pretrained(hf_key, model_max_length=512)
66 |         model_type = "encoder_decoder"
67 |         append_eos = False
68 |     elif "gpt2" in model_key:
69 |         from transformers import GPT2Tokenizer, GPT2LMHeadModel
70 |
71 |         hf_key = model_key.replace("_", "-")
72 |         tokenizer = GPT2Tokenizer.from_pretrained(hf_key)
73 |         model = GPT2LMHeadModel.from_pretrained(hf_key)
74 |         model_type = "decoder"
75 |         append_eos = True
76 |     else:
77 |         raise NotImplementedError(model_key)
78 |
79 |     if "ft_cot" in args.preset_key:
80 |         completion_key = "ft_cot"
81 |     elif args.preset_key == "ft":
82 |         completion_key = "ft"
83 |     elif args.preset_key == "fs_cot":
84 |         raise
NotImplementedError("We don't train models on fs_cot") 85 | else: 86 | raise NotImplementedError(args.preset_key) 87 | 88 | if tokenizer.pad_token is None: 89 | tokenizer.pad_token = tokenizer.eos_token 90 | 91 | batch_size = args.batch_size 92 | if args.inference_batch_size is None: 93 | inference_batch_size = batch_size 94 | else: 95 | inference_batch_size = args.inference_batch_size 96 | data_module = DataModule(dataset_key, args.preset_key, tokenizer, model_type, batch_size=batch_size, 97 | inference_batch_size=inference_batch_size, num_workers=8, append_eos=append_eos) 98 | 99 | cm = CompletionMetadata(model_key, completion_key, dataset_key, data_module.finetune_key, 100 | data_module.prediction_template, train_key=args.train_key) 101 | use_cpu_offload = args.strategy and "offload" in args.strategy 102 | lm = Model(model, tokenizer, model_type, use_cpu_offload=use_cpu_offload, completion_metadata=cm, lr=args.lr) 103 | 104 | if not os.path.exists("external_lightning_logs"): 105 | raise Exception("external_lightning_logs/ does not exist") 106 | default_root_dir = os.path.join("external_lightning_logs", "{}_{}_{}".format(model_key, dataset_key, train_key)) 107 | trainer = pl.Trainer(accelerator="gpu", devices=args.devices, strategy=args.strategy, 108 | default_root_dir=default_root_dir, min_epochs=20, max_epochs=20, 109 | accumulate_grad_batches=args.accumulate, precision=args.precision, 110 | enable_checkpointing=args.enable_checkpointing) 111 | 112 | trainer.fit(lm, datamodule=data_module) 113 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | dataset -------------------------------------------------------------------------------- /data/split/addsub__default.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8, 11 | 10, 12 | 11, 13 | 12, 14 | 13, 15 | 14, 16 | 15, 17 | 16, 18 | 17, 19 | 18, 20 | 19, 21 | 20, 22 | 21, 23 | 22, 24 | 24, 25 | 26, 26 | 27, 27 | 29, 28 | 30, 29 | 33, 30 | 34, 31 | 35, 32 | 37, 33 | 40, 34 | 44, 35 | 45, 36 | 46, 37 | 49, 38 | 51, 39 | 52, 40 | 54, 41 | 55, 42 | 56, 43 | 59, 44 | 60, 45 | 61, 46 | 63, 47 | 64, 48 | 65, 49 | 66, 50 | 67, 51 | 68, 52 | 71, 53 | 73, 54 | 74, 55 | 75, 56 | 76, 57 | 77, 58 | 78, 59 | 79, 60 | 80, 61 | 81, 62 | 83, 63 | 85, 64 | 86, 65 | 89, 66 | 90, 67 | 92, 68 | 93, 69 | 96, 70 | 97, 71 | 100, 72 | 101, 73 | 102, 74 | 103, 75 | 104, 76 | 106, 77 | 107, 78 | 108, 79 | 109, 80 | 110, 81 | 111, 82 | 112, 83 | 113, 84 | 114, 85 | 116, 86 | 118, 87 | 120, 88 | 122, 89 | 124, 90 | 125, 91 | 126, 92 | 129, 93 | 132, 94 | 133, 95 | 134, 96 | 135, 97 | 136, 98 | 137, 99 | 138, 100 | 139, 101 | 140, 102 | 141, 103 | 142, 104 | 144, 105 | 145, 106 | 146, 107 | 149, 108 | 150, 109 | 152, 110 | 153, 111 | 154, 112 | 155, 113 | 156, 114 | 157, 115 | 158, 116 | 159, 117 | 160, 118 | 161, 119 | 162, 120 | 164, 121 | 166, 122 | 167, 123 | 168, 124 | 170, 125 | 171, 126 | 173, 127 | 175, 128 | 176, 129 | 179, 130 | 181, 131 | 182, 132 | 184, 133 | 186, 134 | 188, 135 | 189, 136 | 190, 137 | 191, 138 | 194, 139 | 196, 140 | 198, 141 | 199, 142 | 200, 143 | 204, 144 | 205, 145 | 206, 146 | 208, 147 | 210, 148 | 212, 149 | 213, 150 | 214, 151 | 215, 152 | 216, 153 | 217, 154 | 218, 155 | 219, 156 | 220, 157 | 221, 158 | 223, 159 | 224, 160 | 225, 161 | 226, 162 | 228, 163 | 229, 164 | 230, 165 | 231, 166 | 
232, 167 | 233, 168 | 234, 169 | 235, 170 | 236, 171 | 237, 172 | 238, 173 | 239, 174 | 240, 175 | 241, 176 | 245, 177 | 246, 178 | 247, 179 | 248, 180 | 249, 181 | 250, 182 | 252, 183 | 253, 184 | 254, 185 | 255, 186 | 258, 187 | 259, 188 | 260, 189 | 261, 190 | 263, 191 | 264, 192 | 268, 193 | 269, 194 | 271, 195 | 272, 196 | 274, 197 | 275, 198 | 276, 199 | 278, 200 | 280, 201 | 281, 202 | 282, 203 | 283, 204 | 284, 205 | 286, 206 | 287, 207 | 289, 208 | 293, 209 | 294, 210 | 295, 211 | 296, 212 | 297, 213 | 299, 214 | 300, 215 | 301, 216 | 302, 217 | 303, 218 | 304, 219 | 306, 220 | 307, 221 | 308, 222 | 309, 223 | 310, 224 | 311, 225 | 315, 226 | 316, 227 | 317, 228 | 318, 229 | 319, 230 | 320, 231 | 322, 232 | 324, 233 | 325, 234 | 327, 235 | 328, 236 | 329, 237 | 330, 238 | 331, 239 | 332, 240 | 336, 241 | 339, 242 | 341, 243 | 343, 244 | 344, 245 | 345, 246 | 347, 247 | 349, 248 | 351, 249 | 353, 250 | 354, 251 | 355, 252 | 356, 253 | 358, 254 | 360, 255 | 362, 256 | 366, 257 | 368, 258 | 369, 259 | 370, 260 | 371, 261 | 372, 262 | 373, 263 | 374, 264 | 375, 265 | 376, 266 | 378, 267 | 380, 268 | 383, 269 | 384, 270 | 385, 271 | 386, 272 | 387, 273 | 388, 274 | 389, 275 | 390, 276 | 391, 277 | 392, 278 | 394 279 | ], 280 | "test": [ 281 | 0, 282 | 9, 283 | 23, 284 | 25, 285 | 28, 286 | 31, 287 | 32, 288 | 36, 289 | 38, 290 | 39, 291 | 41, 292 | 42, 293 | 43, 294 | 47, 295 | 48, 296 | 50, 297 | 53, 298 | 57, 299 | 58, 300 | 62, 301 | 69, 302 | 70, 303 | 72, 304 | 82, 305 | 84, 306 | 87, 307 | 88, 308 | 91, 309 | 94, 310 | 95, 311 | 98, 312 | 99, 313 | 105, 314 | 115, 315 | 117, 316 | 119, 317 | 121, 318 | 123, 319 | 127, 320 | 128, 321 | 130, 322 | 131, 323 | 143, 324 | 147, 325 | 148, 326 | 151, 327 | 163, 328 | 165, 329 | 169, 330 | 172, 331 | 174, 332 | 177, 333 | 178, 334 | 180, 335 | 183, 336 | 185, 337 | 187, 338 | 192, 339 | 193, 340 | 195, 341 | 197, 342 | 201, 343 | 202, 344 | 203, 345 | 207, 346 | 209, 347 | 211, 348 | 222, 349 | 227, 350 | 242, 351 | 243, 352 | 244, 353 | 251, 354 | 256, 355 | 257, 356 | 262, 357 | 265, 358 | 266, 359 | 267, 360 | 270, 361 | 273, 362 | 277, 363 | 279, 364 | 285, 365 | 288, 366 | 290, 367 | 291, 368 | 292, 369 | 298, 370 | 305, 371 | 312, 372 | 313, 373 | 314, 374 | 321, 375 | 323, 376 | 326, 377 | 333, 378 | 334, 379 | 335, 380 | 337, 381 | 338, 382 | 340, 383 | 342, 384 | 346, 385 | 348, 386 | 350, 387 | 352, 388 | 357, 389 | 359, 390 | 361, 391 | 363, 392 | 364, 393 | 365, 394 | 367, 395 | 377, 396 | 379, 397 | 381, 398 | 382, 399 | 393 400 | ] 401 | } -------------------------------------------------------------------------------- /data/split/coin_flip__default.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8, 11 | 10, 12 | 12, 13 | 13, 14 | 14, 15 | 15, 16 | 16, 17 | 17, 18 | 18, 19 | 19, 20 | 20, 21 | 21, 22 | 22, 23 | 24, 24 | 26, 25 | 27, 26 | 29, 27 | 30, 28 | 33, 29 | 34, 30 | 35, 31 | 37, 32 | 40, 33 | 44, 34 | 45, 35 | 46, 36 | 49, 37 | 51, 38 | 52, 39 | 54, 40 | 55, 41 | 56, 42 | 59, 43 | 60, 44 | 61, 45 | 63, 46 | 64, 47 | 65, 48 | 66, 49 | 67, 50 | 68, 51 | 71, 52 | 73, 53 | 74, 54 | 75, 55 | 76, 56 | 77, 57 | 78, 58 | 79, 59 | 80, 60 | 81, 61 | 83, 62 | 85, 63 | 89, 64 | 90, 65 | 92, 66 | 93, 67 | 96, 68 | 97, 69 | 100, 70 | 101, 71 | 102, 72 | 103, 73 | 104, 74 | 106, 75 | 107, 76 | 108, 77 | 109, 78 | 110, 79 | 111, 80 | 112, 81 | 113, 82 | 114, 83 | 116, 84 | 118, 85 | 120, 86 | 122, 87 | 124, 88 | 125, 89 | 
126, 90 | 129, 91 | 132, 92 | 133, 93 | 134, 94 | 135, 95 | 136, 96 | 137, 97 | 138, 98 | 139, 99 | 140, 100 | 141, 101 | 142, 102 | 144, 103 | 145, 104 | 146, 105 | 149, 106 | 150, 107 | 152, 108 | 153, 109 | 154, 110 | 155, 111 | 156, 112 | 157, 113 | 158, 114 | 159, 115 | 160, 116 | 161, 117 | 162, 118 | 164, 119 | 166, 120 | 167, 121 | 168, 122 | 170, 123 | 171, 124 | 173, 125 | 175, 126 | 176, 127 | 179, 128 | 181, 129 | 182, 130 | 184, 131 | 186, 132 | 188, 133 | 189, 134 | 190, 135 | 191, 136 | 194, 137 | 196, 138 | 198, 139 | 199, 140 | 200, 141 | 204, 142 | 205, 143 | 206, 144 | 208, 145 | 210, 146 | 212, 147 | 213, 148 | 214, 149 | 215, 150 | 216, 151 | 217, 152 | 218, 153 | 219, 154 | 220, 155 | 221, 156 | 223, 157 | 224, 158 | 225, 159 | 226, 160 | 228, 161 | 229, 162 | 230, 163 | 231, 164 | 232, 165 | 233, 166 | 234, 167 | 235, 168 | 236, 169 | 237, 170 | 238, 171 | 239, 172 | 240, 173 | 241, 174 | 245, 175 | 246, 176 | 247, 177 | 248, 178 | 249, 179 | 250, 180 | 252, 181 | 253, 182 | 254, 183 | 255, 184 | 258, 185 | 259, 186 | 260, 187 | 261, 188 | 263, 189 | 264, 190 | 268, 191 | 269, 192 | 271, 193 | 272, 194 | 274, 195 | 275, 196 | 276, 197 | 278, 198 | 280, 199 | 281, 200 | 282, 201 | 283, 202 | 284, 203 | 286, 204 | 287, 205 | 289, 206 | 293, 207 | 294, 208 | 295, 209 | 296, 210 | 297, 211 | 298, 212 | 299, 213 | 300, 214 | 301, 215 | 302, 216 | 303, 217 | 306, 218 | 307, 219 | 308, 220 | 309, 221 | 310, 222 | 311, 223 | 312, 224 | 313, 225 | 315, 226 | 316, 227 | 317, 228 | 318, 229 | 319, 230 | 320, 231 | 322, 232 | 325, 233 | 326, 234 | 327, 235 | 328, 236 | 329, 237 | 330, 238 | 331, 239 | 332, 240 | 336, 241 | 339, 242 | 340, 243 | 342, 244 | 343, 245 | 344, 246 | 345, 247 | 346, 248 | 347, 249 | 348, 250 | 350, 251 | 351, 252 | 352, 253 | 353, 254 | 354, 255 | 355, 256 | 357, 257 | 360, 258 | 361, 259 | 362, 260 | 363, 261 | 364, 262 | 365, 263 | 366, 264 | 367, 265 | 372, 266 | 374, 267 | 375, 268 | 378, 269 | 379, 270 | 380, 271 | 381, 272 | 382, 273 | 383, 274 | 385, 275 | 386, 276 | 387, 277 | 390, 278 | 391, 279 | 392, 280 | 393, 281 | 394, 282 | 395, 283 | 399, 284 | 400, 285 | 401, 286 | 402, 287 | 403, 288 | 405, 289 | 406, 290 | 407, 291 | 408, 292 | 409, 293 | 411, 294 | 412, 295 | 413, 296 | 414, 297 | 415, 298 | 417, 299 | 418, 300 | 419, 301 | 421, 302 | 422, 303 | 424, 304 | 425, 305 | 427, 306 | 428, 307 | 429, 308 | 430, 309 | 432, 310 | 433, 311 | 435, 312 | 436, 313 | 438, 314 | 439, 315 | 440, 316 | 443, 317 | 445, 318 | 447, 319 | 450, 320 | 451, 321 | 452, 322 | 453, 323 | 455, 324 | 457, 325 | 458, 326 | 459, 327 | 461, 328 | 463, 329 | 468, 330 | 469, 331 | 471, 332 | 473, 333 | 474, 334 | 475, 335 | 476, 336 | 477, 337 | 478, 338 | 479, 339 | 481, 340 | 484, 341 | 485, 342 | 488, 343 | 489, 344 | 490, 345 | 491, 346 | 492, 347 | 493, 348 | 494, 349 | 495, 350 | 496, 351 | 497, 352 | 499 353 | ], 354 | "test": [ 355 | 0, 356 | 9, 357 | 11, 358 | 23, 359 | 25, 360 | 28, 361 | 31, 362 | 32, 363 | 36, 364 | 38, 365 | 39, 366 | 41, 367 | 42, 368 | 43, 369 | 47, 370 | 48, 371 | 50, 372 | 53, 373 | 57, 374 | 58, 375 | 62, 376 | 69, 377 | 70, 378 | 72, 379 | 82, 380 | 84, 381 | 86, 382 | 87, 383 | 88, 384 | 91, 385 | 94, 386 | 95, 387 | 98, 388 | 99, 389 | 105, 390 | 115, 391 | 117, 392 | 119, 393 | 121, 394 | 123, 395 | 127, 396 | 128, 397 | 130, 398 | 131, 399 | 143, 400 | 147, 401 | 148, 402 | 151, 403 | 163, 404 | 165, 405 | 169, 406 | 172, 407 | 174, 408 | 177, 409 | 178, 410 | 180, 411 | 183, 412 | 185, 413 | 187, 414 | 192, 415 | 193, 416 | 
195, 417 | 197, 418 | 201, 419 | 202, 420 | 203, 421 | 207, 422 | 209, 423 | 211, 424 | 222, 425 | 227, 426 | 242, 427 | 243, 428 | 244, 429 | 251, 430 | 256, 431 | 257, 432 | 262, 433 | 265, 434 | 266, 435 | 267, 436 | 270, 437 | 273, 438 | 277, 439 | 279, 440 | 285, 441 | 288, 442 | 290, 443 | 291, 444 | 292, 445 | 304, 446 | 305, 447 | 314, 448 | 321, 449 | 323, 450 | 324, 451 | 333, 452 | 334, 453 | 335, 454 | 337, 455 | 338, 456 | 341, 457 | 349, 458 | 356, 459 | 358, 460 | 359, 461 | 368, 462 | 369, 463 | 370, 464 | 371, 465 | 373, 466 | 376, 467 | 377, 468 | 384, 469 | 388, 470 | 389, 471 | 396, 472 | 397, 473 | 398, 474 | 404, 475 | 410, 476 | 416, 477 | 420, 478 | 423, 479 | 426, 480 | 431, 481 | 434, 482 | 437, 483 | 441, 484 | 442, 485 | 444, 486 | 446, 487 | 448, 488 | 449, 489 | 454, 490 | 456, 491 | 460, 492 | 462, 493 | 464, 494 | 465, 495 | 466, 496 | 467, 497 | 470, 498 | 472, 499 | 480, 500 | 482, 501 | 483, 502 | 486, 503 | 487, 504 | 498 505 | ] 506 | } -------------------------------------------------------------------------------- /data/split/date_understanding__default.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8, 11 | 10, 12 | 11, 13 | 12, 14 | 13, 15 | 14, 16 | 15, 17 | 16, 18 | 17, 19 | 18, 20 | 19, 21 | 20, 22 | 21, 23 | 22, 24 | 24, 25 | 26, 26 | 27, 27 | 29, 28 | 30, 29 | 33, 30 | 34, 31 | 35, 32 | 36, 33 | 37, 34 | 40, 35 | 44, 36 | 45, 37 | 46, 38 | 49, 39 | 51, 40 | 52, 41 | 54, 42 | 55, 43 | 56, 44 | 59, 45 | 60, 46 | 61, 47 | 63, 48 | 64, 49 | 65, 50 | 66, 51 | 67, 52 | 68, 53 | 71, 54 | 73, 55 | 74, 56 | 75, 57 | 76, 58 | 77, 59 | 78, 60 | 79, 61 | 80, 62 | 81, 63 | 83, 64 | 85, 65 | 86, 66 | 89, 67 | 90, 68 | 92, 69 | 93, 70 | 96, 71 | 97, 72 | 100, 73 | 101, 74 | 102, 75 | 103, 76 | 104, 77 | 106, 78 | 107, 79 | 108, 80 | 109, 81 | 110, 82 | 111, 83 | 112, 84 | 113, 85 | 114, 86 | 116, 87 | 118, 88 | 120, 89 | 122, 90 | 124, 91 | 125, 92 | 126, 93 | 129, 94 | 132, 95 | 133, 96 | 134, 97 | 135, 98 | 136, 99 | 137, 100 | 138, 101 | 139, 102 | 140, 103 | 141, 104 | 142, 105 | 144, 106 | 145, 107 | 146, 108 | 149, 109 | 150, 110 | 152, 111 | 153, 112 | 154, 113 | 155, 114 | 156, 115 | 157, 116 | 158, 117 | 159, 118 | 160, 119 | 161, 120 | 162, 121 | 164, 122 | 166, 123 | 167, 124 | 168, 125 | 170, 126 | 171, 127 | 173, 128 | 175, 129 | 176, 130 | 179, 131 | 181, 132 | 182, 133 | 184, 134 | 186, 135 | 188, 136 | 189, 137 | 190, 138 | 191, 139 | 194, 140 | 196, 141 | 198, 142 | 199, 143 | 200, 144 | 204, 145 | 205, 146 | 206, 147 | 208, 148 | 210, 149 | 212, 150 | 213, 151 | 214, 152 | 215, 153 | 216, 154 | 217, 155 | 218, 156 | 219, 157 | 220, 158 | 221, 159 | 223, 160 | 224, 161 | 225, 162 | 226, 163 | 228, 164 | 229, 165 | 230, 166 | 231, 167 | 232, 168 | 233, 169 | 234, 170 | 235, 171 | 236, 172 | 237, 173 | 238, 174 | 239, 175 | 240, 176 | 241, 177 | 245, 178 | 246, 179 | 247, 180 | 248, 181 | 249, 182 | 250, 183 | 252, 184 | 253, 185 | 254, 186 | 255, 187 | 258, 188 | 259, 189 | 260, 190 | 261, 191 | 263, 192 | 264, 193 | 266, 194 | 268, 195 | 269, 196 | 270, 197 | 271, 198 | 272, 199 | 274, 200 | 275, 201 | 276, 202 | 278, 203 | 280, 204 | 281, 205 | 282, 206 | 283, 207 | 284, 208 | 286, 209 | 287, 210 | 293, 211 | 294, 212 | 295, 213 | 296, 214 | 297, 215 | 298, 216 | 299, 217 | 300, 218 | 301, 219 | 302, 220 | 304, 221 | 305, 222 | 306, 223 | 307, 224 | 308, 225 | 309, 226 | 311, 227 | 312, 228 | 313, 229 | 316, 230 | 319, 231 | 320, 
232 | 322, 233 | 324, 234 | 326, 235 | 328, 236 | 329, 237 | 330, 238 | 331, 239 | 334, 240 | 340, 241 | 342, 242 | 343, 243 | 344, 244 | 345, 245 | 346, 246 | 347, 247 | 348, 248 | 349, 249 | 350, 250 | 352, 251 | 354, 252 | 356, 253 | 357, 254 | 360, 255 | 361, 256 | 362, 257 | 363, 258 | 365, 259 | 366, 260 | 368 261 | ], 262 | "test": [ 263 | 0, 264 | 9, 265 | 23, 266 | 25, 267 | 28, 268 | 31, 269 | 32, 270 | 38, 271 | 39, 272 | 41, 273 | 42, 274 | 43, 275 | 47, 276 | 48, 277 | 50, 278 | 53, 279 | 57, 280 | 58, 281 | 62, 282 | 69, 283 | 70, 284 | 72, 285 | 82, 286 | 84, 287 | 87, 288 | 88, 289 | 91, 290 | 94, 291 | 95, 292 | 98, 293 | 99, 294 | 105, 295 | 115, 296 | 117, 297 | 119, 298 | 121, 299 | 123, 300 | 127, 301 | 128, 302 | 130, 303 | 131, 304 | 143, 305 | 147, 306 | 148, 307 | 151, 308 | 163, 309 | 165, 310 | 169, 311 | 172, 312 | 174, 313 | 177, 314 | 178, 315 | 180, 316 | 183, 317 | 185, 318 | 187, 319 | 192, 320 | 193, 321 | 195, 322 | 197, 323 | 201, 324 | 202, 325 | 203, 326 | 207, 327 | 209, 328 | 211, 329 | 222, 330 | 227, 331 | 242, 332 | 243, 333 | 244, 334 | 251, 335 | 256, 336 | 257, 337 | 262, 338 | 265, 339 | 267, 340 | 273, 341 | 277, 342 | 279, 343 | 285, 344 | 288, 345 | 289, 346 | 290, 347 | 291, 348 | 292, 349 | 303, 350 | 310, 351 | 314, 352 | 315, 353 | 317, 354 | 318, 355 | 321, 356 | 323, 357 | 325, 358 | 327, 359 | 332, 360 | 333, 361 | 335, 362 | 336, 363 | 337, 364 | 338, 365 | 339, 366 | 341, 367 | 351, 368 | 353, 369 | 355, 370 | 358, 371 | 359, 372 | 364, 373 | 367 374 | ] 375 | } -------------------------------------------------------------------------------- /data/split/date_understanding__template.json: -------------------------------------------------------------------------------- 1 | {"train": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 180, 181, 182, 183, 184, 185, 186, 187, 188, 198, 199, 200, 201, 202, 203, 204, 205, 206, 225, 226, 227, 228, 229, 230, 231, 232, 233, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 360, 361, 362, 363, 364, 365, 366, 367, 368], "test": [0, 1, 2, 3, 4, 5, 6, 7, 8, 27, 28, 29, 30, 31, 32, 33, 34, 35, 54, 55, 56, 57, 58, 59, 60, 61, 62, 81, 82, 83, 84, 85, 86, 87, 88, 89, 108, 109, 110, 111, 112, 113, 114, 115, 116, 171, 172, 173, 174, 175, 176, 177, 178, 179, 189, 190, 191, 192, 193, 194, 195, 196, 197, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 234, 235, 236, 237, 238, 239, 240, 241, 242, 288, 289, 290, 291, 292, 293, 294, 295, 296, 351, 352, 353, 354, 355, 356, 357, 358, 359]} 
-------------------------------------------------------------------------------- /data/split/last_letter_concatenation__default.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8, 11 | 10, 12 | 12, 13 | 13, 14 | 14, 15 | 15, 16 | 16, 17 | 17, 18 | 18, 19 | 19, 20 | 20, 21 | 21, 22 | 22, 23 | 24, 24 | 26, 25 | 27, 26 | 29, 27 | 30, 28 | 33, 29 | 34, 30 | 35, 31 | 37, 32 | 40, 33 | 44, 34 | 45, 35 | 46, 36 | 49, 37 | 51, 38 | 52, 39 | 54, 40 | 55, 41 | 56, 42 | 59, 43 | 60, 44 | 61, 45 | 63, 46 | 64, 47 | 65, 48 | 66, 49 | 67, 50 | 68, 51 | 71, 52 | 73, 53 | 74, 54 | 75, 55 | 76, 56 | 77, 57 | 78, 58 | 79, 59 | 80, 60 | 81, 61 | 83, 62 | 85, 63 | 89, 64 | 90, 65 | 92, 66 | 93, 67 | 96, 68 | 97, 69 | 100, 70 | 101, 71 | 102, 72 | 103, 73 | 104, 74 | 106, 75 | 107, 76 | 108, 77 | 109, 78 | 110, 79 | 111, 80 | 112, 81 | 113, 82 | 114, 83 | 116, 84 | 118, 85 | 120, 86 | 122, 87 | 124, 88 | 125, 89 | 126, 90 | 129, 91 | 132, 92 | 133, 93 | 134, 94 | 135, 95 | 136, 96 | 137, 97 | 138, 98 | 139, 99 | 140, 100 | 141, 101 | 142, 102 | 144, 103 | 145, 104 | 146, 105 | 149, 106 | 150, 107 | 152, 108 | 153, 109 | 154, 110 | 155, 111 | 156, 112 | 157, 113 | 158, 114 | 159, 115 | 160, 116 | 161, 117 | 162, 118 | 164, 119 | 166, 120 | 167, 121 | 168, 122 | 170, 123 | 171, 124 | 173, 125 | 175, 126 | 176, 127 | 179, 128 | 181, 129 | 182, 130 | 184, 131 | 186, 132 | 188, 133 | 189, 134 | 190, 135 | 191, 136 | 194, 137 | 196, 138 | 198, 139 | 199, 140 | 200, 141 | 204, 142 | 205, 143 | 206, 144 | 208, 145 | 210, 146 | 212, 147 | 213, 148 | 214, 149 | 215, 150 | 216, 151 | 217, 152 | 218, 153 | 219, 154 | 220, 155 | 221, 156 | 223, 157 | 224, 158 | 225, 159 | 226, 160 | 228, 161 | 229, 162 | 230, 163 | 231, 164 | 232, 165 | 233, 166 | 234, 167 | 235, 168 | 236, 169 | 237, 170 | 238, 171 | 239, 172 | 240, 173 | 241, 174 | 245, 175 | 246, 176 | 247, 177 | 248, 178 | 249, 179 | 250, 180 | 252, 181 | 253, 182 | 254, 183 | 255, 184 | 258, 185 | 259, 186 | 260, 187 | 261, 188 | 263, 189 | 264, 190 | 268, 191 | 269, 192 | 271, 193 | 272, 194 | 274, 195 | 275, 196 | 276, 197 | 278, 198 | 280, 199 | 281, 200 | 282, 201 | 283, 202 | 284, 203 | 286, 204 | 287, 205 | 289, 206 | 293, 207 | 294, 208 | 295, 209 | 296, 210 | 297, 211 | 298, 212 | 299, 213 | 300, 214 | 301, 215 | 302, 216 | 303, 217 | 306, 218 | 307, 219 | 308, 220 | 309, 221 | 310, 222 | 311, 223 | 312, 224 | 313, 225 | 315, 226 | 316, 227 | 317, 228 | 318, 229 | 319, 230 | 320, 231 | 322, 232 | 325, 233 | 326, 234 | 327, 235 | 328, 236 | 329, 237 | 330, 238 | 331, 239 | 332, 240 | 336, 241 | 339, 242 | 340, 243 | 342, 244 | 343, 245 | 344, 246 | 345, 247 | 346, 248 | 347, 249 | 348, 250 | 350, 251 | 351, 252 | 352, 253 | 353, 254 | 354, 255 | 355, 256 | 357, 257 | 360, 258 | 361, 259 | 362, 260 | 363, 261 | 364, 262 | 365, 263 | 366, 264 | 367, 265 | 372, 266 | 374, 267 | 375, 268 | 378, 269 | 379, 270 | 380, 271 | 381, 272 | 382, 273 | 383, 274 | 385, 275 | 386, 276 | 387, 277 | 390, 278 | 391, 279 | 392, 280 | 393, 281 | 394, 282 | 395, 283 | 399, 284 | 400, 285 | 401, 286 | 402, 287 | 403, 288 | 405, 289 | 406, 290 | 407, 291 | 408, 292 | 409, 293 | 411, 294 | 412, 295 | 413, 296 | 414, 297 | 415, 298 | 417, 299 | 418, 300 | 419, 301 | 421, 302 | 422, 303 | 424, 304 | 425, 305 | 427, 306 | 428, 307 | 429, 308 | 430, 309 | 432, 310 | 433, 311 | 435, 312 | 436, 313 | 438, 314 | 439, 315 | 440, 316 | 443, 317 | 445, 318 | 447, 319 | 450, 320 | 
451, 321 | 452, 322 | 453, 323 | 455, 324 | 457, 325 | 458, 326 | 459, 327 | 461, 328 | 463, 329 | 468, 330 | 469, 331 | 471, 332 | 473, 333 | 474, 334 | 475, 335 | 476, 336 | 477, 337 | 478, 338 | 479, 339 | 481, 340 | 484, 341 | 485, 342 | 488, 343 | 489, 344 | 490, 345 | 491, 346 | 492, 347 | 493, 348 | 494, 349 | 495, 350 | 496, 351 | 497, 352 | 499 353 | ], 354 | "test": [ 355 | 0, 356 | 9, 357 | 11, 358 | 23, 359 | 25, 360 | 28, 361 | 31, 362 | 32, 363 | 36, 364 | 38, 365 | 39, 366 | 41, 367 | 42, 368 | 43, 369 | 47, 370 | 48, 371 | 50, 372 | 53, 373 | 57, 374 | 58, 375 | 62, 376 | 69, 377 | 70, 378 | 72, 379 | 82, 380 | 84, 381 | 86, 382 | 87, 383 | 88, 384 | 91, 385 | 94, 386 | 95, 387 | 98, 388 | 99, 389 | 105, 390 | 115, 391 | 117, 392 | 119, 393 | 121, 394 | 123, 395 | 127, 396 | 128, 397 | 130, 398 | 131, 399 | 143, 400 | 147, 401 | 148, 402 | 151, 403 | 163, 404 | 165, 405 | 169, 406 | 172, 407 | 174, 408 | 177, 409 | 178, 410 | 180, 411 | 183, 412 | 185, 413 | 187, 414 | 192, 415 | 193, 416 | 195, 417 | 197, 418 | 201, 419 | 202, 420 | 203, 421 | 207, 422 | 209, 423 | 211, 424 | 222, 425 | 227, 426 | 242, 427 | 243, 428 | 244, 429 | 251, 430 | 256, 431 | 257, 432 | 262, 433 | 265, 434 | 266, 435 | 267, 436 | 270, 437 | 273, 438 | 277, 439 | 279, 440 | 285, 441 | 288, 442 | 290, 443 | 291, 444 | 292, 445 | 304, 446 | 305, 447 | 314, 448 | 321, 449 | 323, 450 | 324, 451 | 333, 452 | 334, 453 | 335, 454 | 337, 455 | 338, 456 | 341, 457 | 349, 458 | 356, 459 | 358, 460 | 359, 461 | 368, 462 | 369, 463 | 370, 464 | 371, 465 | 373, 466 | 376, 467 | 377, 468 | 384, 469 | 388, 470 | 389, 471 | 396, 472 | 397, 473 | 398, 474 | 404, 475 | 410, 476 | 416, 477 | 420, 478 | 423, 479 | 426, 480 | 431, 481 | 434, 482 | 437, 483 | 441, 484 | 442, 485 | 444, 486 | 446, 487 | 448, 488 | 449, 489 | 454, 490 | 456, 491 | 460, 492 | 462, 493 | 464, 494 | 465, 495 | 466, 496 | 467, 497 | 470, 498 | 472, 499 | 480, 500 | 482, 501 | 483, 502 | 486, 503 | 487, 504 | 498 505 | ] 506 | } -------------------------------------------------------------------------------- /data/split/multiarith__default.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | 0, 4 | 1, 5 | 2, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8, 11 | 10, 12 | 12, 13 | 14, 14 | 15, 15 | 17, 16 | 18, 17 | 20, 18 | 21, 19 | 22, 20 | 26, 21 | 27, 22 | 29, 23 | 30, 24 | 31, 25 | 33, 26 | 34, 27 | 35, 28 | 37, 29 | 38, 30 | 39, 31 | 44, 32 | 45, 33 | 46, 34 | 48, 35 | 49, 36 | 50, 37 | 51, 38 | 52, 39 | 54, 40 | 55, 41 | 56, 42 | 59, 43 | 60, 44 | 61, 45 | 62, 46 | 63, 47 | 64, 48 | 65, 49 | 66, 50 | 68, 51 | 69, 52 | 71, 53 | 73, 54 | 74, 55 | 75, 56 | 76, 57 | 78, 58 | 81, 59 | 85, 60 | 88, 61 | 89, 62 | 90, 63 | 92, 64 | 93, 65 | 96, 66 | 97, 67 | 100, 68 | 101, 69 | 102, 70 | 103, 71 | 104, 72 | 105, 73 | 106, 74 | 107, 75 | 108, 76 | 112, 77 | 113, 78 | 114, 79 | 116, 80 | 118, 81 | 120, 82 | 122, 83 | 124, 84 | 126, 85 | 127, 86 | 132, 87 | 133, 88 | 134, 89 | 135, 90 | 136, 91 | 137, 92 | 138, 93 | 140, 94 | 141, 95 | 142, 96 | 144, 97 | 145, 98 | 146, 99 | 150, 100 | 153, 101 | 154, 102 | 155, 103 | 156, 104 | 157, 105 | 158, 106 | 159, 107 | 160, 108 | 162, 109 | 163, 110 | 164, 111 | 165, 112 | 167, 113 | 168, 114 | 170, 115 | 171, 116 | 172, 117 | 173, 118 | 175, 119 | 176, 120 | 178, 121 | 179, 122 | 181, 123 | 185, 124 | 186, 125 | 187, 126 | 188, 127 | 189, 128 | 190, 129 | 191, 130 | 193, 131 | 194, 132 | 195, 133 | 196, 134 | 198, 135 | 199, 136 | 200, 137 | 202, 138 | 204, 139 
| 205, 140 | 206, 141 | 208, 142 | 210, 143 | 211, 144 | 212, 145 | 213, 146 | 214, 147 | 215, 148 | 217, 149 | 218, 150 | 219, 151 | 220, 152 | 221, 153 | 222, 154 | 224, 155 | 225, 156 | 229, 157 | 230, 158 | 231, 159 | 232, 160 | 233, 161 | 234, 162 | 235, 163 | 236, 164 | 238, 165 | 239, 166 | 240, 167 | 241, 168 | 242, 169 | 243, 170 | 245, 171 | 246, 172 | 247, 173 | 249, 174 | 250, 175 | 252, 176 | 253, 177 | 254, 178 | 255, 179 | 261, 180 | 262, 181 | 263, 182 | 264, 183 | 266, 184 | 268, 185 | 271, 186 | 272, 187 | 276, 188 | 278, 189 | 279, 190 | 281, 191 | 282, 192 | 283, 193 | 284, 194 | 285, 195 | 289, 196 | 290, 197 | 293, 198 | 295, 199 | 297, 200 | 298, 201 | 299, 202 | 301, 203 | 302, 204 | 303, 205 | 304, 206 | 306, 207 | 308, 208 | 309, 209 | 310, 210 | 311, 211 | 312, 212 | 313, 213 | 315, 214 | 316, 215 | 318, 216 | 319, 217 | 320, 218 | 322, 219 | 325, 220 | 329, 221 | 330, 222 | 332, 223 | 333, 224 | 334, 225 | 336, 226 | 337, 227 | 338, 228 | 339, 229 | 340, 230 | 341, 231 | 342, 232 | 343, 233 | 344, 234 | 345, 235 | 346, 236 | 347, 237 | 348, 238 | 350, 239 | 351, 240 | 353, 241 | 354, 242 | 355, 243 | 356, 244 | 357, 245 | 358, 246 | 361, 247 | 362, 248 | 364, 249 | 365, 250 | 366, 251 | 367, 252 | 369, 253 | 372, 254 | 375, 255 | 376, 256 | 378, 257 | 379, 258 | 380, 259 | 382, 260 | 384, 261 | 385, 262 | 386, 263 | 389, 264 | 390, 265 | 391, 266 | 392, 267 | 393, 268 | 394, 269 | 395, 270 | 397, 271 | 399, 272 | 400, 273 | 401, 274 | 402, 275 | 403, 276 | 404, 277 | 405, 278 | 406, 279 | 407, 280 | 408, 281 | 409, 282 | 412, 283 | 414, 284 | 415, 285 | 417, 286 | 418, 287 | 419, 288 | 420, 289 | 421, 290 | 422, 291 | 424, 292 | 425, 293 | 426, 294 | 427, 295 | 428, 296 | 429, 297 | 432, 298 | 433, 299 | 434, 300 | 435, 301 | 436, 302 | 437, 303 | 438, 304 | 439, 305 | 440, 306 | 441, 307 | 443, 308 | 446, 309 | 447, 310 | 449, 311 | 452, 312 | 454, 313 | 455, 314 | 456, 315 | 457, 316 | 458, 317 | 463, 318 | 464, 319 | 466, 320 | 467, 321 | 468, 322 | 469, 323 | 470, 324 | 471, 325 | 473, 326 | 474, 327 | 475, 328 | 476, 329 | 477, 330 | 478, 331 | 479, 332 | 480, 333 | 481, 334 | 482, 335 | 484, 336 | 485, 337 | 487, 338 | 489, 339 | 490, 340 | 491, 341 | 492, 342 | 493, 343 | 494, 344 | 495, 345 | 496, 346 | 499, 347 | 500, 348 | 501, 349 | 502, 350 | 503, 351 | 504, 352 | 505, 353 | 506, 354 | 507, 355 | 508, 356 | 514, 357 | 515, 358 | 517, 359 | 518, 360 | 519, 361 | 520, 362 | 521, 363 | 523, 364 | 524, 365 | 525, 366 | 526, 367 | 527, 368 | 529, 369 | 531, 370 | 532, 371 | 533, 372 | 534, 373 | 536, 374 | 538, 375 | 539, 376 | 540, 377 | 541, 378 | 545, 379 | 546, 380 | 547, 381 | 548, 382 | 549, 383 | 552, 384 | 553, 385 | 554, 386 | 555, 387 | 556, 388 | 557, 389 | 558, 390 | 560, 391 | 561, 392 | 562, 393 | 564, 394 | 565, 395 | 568, 396 | 569, 397 | 570, 398 | 571, 399 | 572, 400 | 573, 401 | 575, 402 | 576, 403 | 577, 404 | 578, 405 | 579, 406 | 580, 407 | 581, 408 | 583, 409 | 584, 410 | 585, 411 | 586, 412 | 588, 413 | 589, 414 | 590, 415 | 591, 416 | 592, 417 | 593, 418 | 594, 419 | 595, 420 | 596, 421 | 597, 422 | 598 423 | ], 424 | "test": [ 425 | 3, 426 | 9, 427 | 11, 428 | 13, 429 | 16, 430 | 19, 431 | 23, 432 | 24, 433 | 25, 434 | 28, 435 | 32, 436 | 36, 437 | 40, 438 | 41, 439 | 42, 440 | 43, 441 | 47, 442 | 53, 443 | 57, 444 | 58, 445 | 67, 446 | 70, 447 | 72, 448 | 77, 449 | 79, 450 | 80, 451 | 82, 452 | 83, 453 | 84, 454 | 86, 455 | 87, 456 | 91, 457 | 94, 458 | 95, 459 | 98, 460 | 99, 461 | 109, 462 | 110, 463 | 111, 464 | 115, 465 | 
117, 466 | 119, 467 | 121, 468 | 123, 469 | 125, 470 | 128, 471 | 129, 472 | 130, 473 | 131, 474 | 139, 475 | 143, 476 | 147, 477 | 148, 478 | 149, 479 | 151, 480 | 152, 481 | 161, 482 | 166, 483 | 169, 484 | 174, 485 | 177, 486 | 180, 487 | 182, 488 | 183, 489 | 184, 490 | 192, 491 | 197, 492 | 201, 493 | 203, 494 | 207, 495 | 209, 496 | 216, 497 | 223, 498 | 226, 499 | 227, 500 | 228, 501 | 237, 502 | 244, 503 | 248, 504 | 251, 505 | 256, 506 | 257, 507 | 258, 508 | 259, 509 | 260, 510 | 265, 511 | 267, 512 | 269, 513 | 270, 514 | 273, 515 | 274, 516 | 275, 517 | 277, 518 | 280, 519 | 286, 520 | 287, 521 | 288, 522 | 291, 523 | 292, 524 | 294, 525 | 296, 526 | 300, 527 | 305, 528 | 307, 529 | 314, 530 | 317, 531 | 321, 532 | 323, 533 | 324, 534 | 326, 535 | 327, 536 | 328, 537 | 331, 538 | 335, 539 | 349, 540 | 352, 541 | 359, 542 | 360, 543 | 363, 544 | 368, 545 | 370, 546 | 371, 547 | 373, 548 | 374, 549 | 377, 550 | 381, 551 | 383, 552 | 387, 553 | 388, 554 | 396, 555 | 398, 556 | 410, 557 | 411, 558 | 413, 559 | 416, 560 | 423, 561 | 430, 562 | 431, 563 | 442, 564 | 444, 565 | 445, 566 | 448, 567 | 450, 568 | 451, 569 | 453, 570 | 459, 571 | 460, 572 | 461, 573 | 462, 574 | 465, 575 | 472, 576 | 483, 577 | 486, 578 | 488, 579 | 497, 580 | 498, 581 | 509, 582 | 510, 583 | 511, 584 | 512, 585 | 513, 586 | 516, 587 | 522, 588 | 528, 589 | 530, 590 | 535, 591 | 537, 592 | 542, 593 | 543, 594 | 544, 595 | 550, 596 | 551, 597 | 559, 598 | 563, 599 | 566, 600 | 567, 601 | 574, 602 | 582, 603 | 587, 604 | 599 605 | ] 606 | } -------------------------------------------------------------------------------- /data/split/multiarith__template.json: -------------------------------------------------------------------------------- 1 | {"train": [0, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 40, 42, 43, 44, 45, 47, 48, 49, 50, 52, 53, 54, 55, 57, 58, 59, 60, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 79, 80, 81, 82, 83, 84, 85, 86, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 102, 103, 104, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 130, 131, 133, 134, 135, 137, 139, 143, 144, 145, 147, 149, 150, 152, 153, 154, 156, 158, 159, 160, 162, 163, 164, 165, 166, 168, 169, 170, 172, 173, 174, 175, 178, 179, 180, 181, 182, 184, 186, 188, 189, 190, 191, 193, 194, 195, 197, 198, 200, 201, 202, 203, 205, 208, 210, 211, 213, 214, 215, 216, 217, 218, 220, 223, 224, 225, 226, 227, 228, 229, 230, 232, 233, 234, 235, 236, 237, 238, 240, 241, 242, 243, 244, 247, 249, 250, 253, 254, 255, 258, 259, 260, 261, 262, 264, 265, 268, 269, 270, 271, 272, 273, 274, 275, 277, 278, 280, 281, 282, 284, 285, 286, 289, 290, 291, 292, 294, 295, 296, 297, 300, 301, 302, 303, 306, 308, 310, 312, 314, 315, 317, 318, 319, 323, 324, 325, 328, 329, 330, 331, 332, 333, 334, 335, 336, 339, 342, 343, 344, 345, 349, 352, 353, 354, 355, 359, 361, 362, 367, 368, 370, 371, 372, 374, 375, 376, 380, 381, 384, 389, 390, 393, 394, 395, 396, 397, 398, 404, 405, 406, 409, 411, 412, 416, 417, 418, 419, 421, 422, 423, 425, 427, 428, 431, 433, 434, 435, 436, 437, 438, 439, 441, 442, 443, 445, 446, 447, 448, 449, 453, 454, 457, 458, 459, 460, 461, 462, 463, 464, 465, 469, 471, 472, 473, 474, 475, 476, 478, 479, 482, 483, 484, 485, 487, 489, 490, 491, 492, 493, 494, 495, 496, 498, 500, 501, 502, 504, 507, 509, 510, 511, 513, 514, 515, 516, 517, 520, 522, 524, 525, 526, 527, 529, 530, 532, 535, 
536, 538, 540, 543, 545, 546, 548, 549, 550, 551, 553, 554, 555, 557, 559, 564, 565, 566, 567, 568, 569, 570, 573, 574, 575, 577, 581, 584, 585, 586, 588, 589, 590, 593, 595, 598], "test": [9, 10, 15, 25, 38, 41, 46, 51, 56, 61, 77, 78, 87, 100, 101, 105, 116, 128, 129, 132, 136, 138, 140, 141, 142, 146, 148, 151, 155, 157, 161, 167, 171, 176, 177, 183, 185, 187, 192, 196, 199, 204, 206, 207, 209, 212, 219, 221, 222, 231, 239, 245, 246, 248, 251, 252, 256, 257, 263, 266, 267, 276, 279, 283, 287, 288, 293, 298, 299, 304, 305, 307, 309, 311, 313, 316, 320, 321, 322, 326, 327, 337, 338, 340, 341, 346, 347, 348, 350, 351, 356, 357, 358, 360, 363, 364, 365, 366, 369, 373, 377, 378, 379, 382, 383, 385, 386, 387, 388, 391, 392, 399, 400, 401, 402, 403, 407, 408, 410, 413, 414, 415, 420, 424, 426, 429, 430, 432, 440, 444, 450, 451, 452, 455, 456, 466, 467, 468, 470, 477, 480, 481, 486, 488, 497, 499, 503, 505, 506, 508, 512, 518, 519, 521, 523, 528, 531, 533, 534, 537, 539, 541, 542, 544, 547, 552, 556, 558, 560, 561, 562, 563, 571, 572, 576, 578, 579, 580, 582, 583, 587, 591, 592, 594, 596, 597, 599]} -------------------------------------------------------------------------------- /data/split/single_eq__default.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | 1, 4 | 2, 5 | 3, 6 | 4, 7 | 5, 8 | 6, 9 | 7, 10 | 8, 11 | 10, 12 | 11, 13 | 12, 14 | 13, 15 | 14, 16 | 15, 17 | 16, 18 | 17, 19 | 18, 20 | 19, 21 | 20, 22 | 21, 23 | 22, 24 | 24, 25 | 26, 26 | 27, 27 | 29, 28 | 30, 29 | 33, 30 | 34, 31 | 35, 32 | 37, 33 | 40, 34 | 44, 35 | 45, 36 | 46, 37 | 49, 38 | 51, 39 | 52, 40 | 54, 41 | 55, 42 | 56, 43 | 59, 44 | 60, 45 | 61, 46 | 63, 47 | 64, 48 | 65, 49 | 66, 50 | 67, 51 | 68, 52 | 71, 53 | 73, 54 | 74, 55 | 75, 56 | 76, 57 | 77, 58 | 78, 59 | 79, 60 | 80, 61 | 81, 62 | 83, 63 | 85, 64 | 89, 65 | 90, 66 | 92, 67 | 93, 68 | 96, 69 | 97, 70 | 100, 71 | 101, 72 | 102, 73 | 103, 74 | 104, 75 | 106, 76 | 107, 77 | 108, 78 | 109, 79 | 110, 80 | 111, 81 | 112, 82 | 113, 83 | 114, 84 | 116, 85 | 118, 86 | 120, 87 | 122, 88 | 124, 89 | 125, 90 | 126, 91 | 129, 92 | 132, 93 | 133, 94 | 134, 95 | 135, 96 | 136, 97 | 137, 98 | 138, 99 | 139, 100 | 140, 101 | 141, 102 | 142, 103 | 144, 104 | 145, 105 | 146, 106 | 149, 107 | 150, 108 | 152, 109 | 153, 110 | 154, 111 | 155, 112 | 156, 113 | 157, 114 | 158, 115 | 159, 116 | 160, 117 | 161, 118 | 162, 119 | 164, 120 | 166, 121 | 167, 122 | 168, 123 | 170, 124 | 171, 125 | 173, 126 | 175, 127 | 176, 128 | 179, 129 | 181, 130 | 182, 131 | 184, 132 | 186, 133 | 188, 134 | 189, 135 | 190, 136 | 191, 137 | 194, 138 | 196, 139 | 198, 140 | 199, 141 | 200, 142 | 204, 143 | 205, 144 | 206, 145 | 208, 146 | 210, 147 | 212, 148 | 213, 149 | 214, 150 | 215, 151 | 216, 152 | 217, 153 | 218, 154 | 219, 155 | 220, 156 | 221, 157 | 223, 158 | 224, 159 | 225, 160 | 226, 161 | 228, 162 | 229, 163 | 230, 164 | 231, 165 | 232, 166 | 233, 167 | 234, 168 | 235, 169 | 236, 170 | 237, 171 | 238, 172 | 239, 173 | 240, 174 | 241, 175 | 245, 176 | 246, 177 | 247, 178 | 248, 179 | 249, 180 | 250, 181 | 252, 182 | 253, 183 | 254, 184 | 255, 185 | 258, 186 | 259, 187 | 260, 188 | 261, 189 | 263, 190 | 264, 191 | 268, 192 | 269, 193 | 271, 194 | 272, 195 | 274, 196 | 275, 197 | 276, 198 | 278, 199 | 280, 200 | 281, 201 | 282, 202 | 283, 203 | 284, 204 | 286, 205 | 287, 206 | 289, 207 | 293, 208 | 294, 209 | 295, 210 | 296, 211 | 297, 212 | 298, 213 | 299, 214 | 300, 215 | 301, 216 | 302, 217 | 303, 218 | 306, 219 | 307, 220 | 308, 221 
| 309, 222 | 310, 223 | 311, 224 | 312, 225 | 313, 226 | 315, 227 | 316, 228 | 317, 229 | 318, 230 | 319, 231 | 320, 232 | 322, 233 | 325, 234 | 326, 235 | 327, 236 | 328, 237 | 329, 238 | 330, 239 | 331, 240 | 332, 241 | 336, 242 | 339, 243 | 340, 244 | 342, 245 | 343, 246 | 344, 247 | 345, 248 | 346, 249 | 347, 250 | 348, 251 | 350, 252 | 351, 253 | 352, 254 | 353, 255 | 354, 256 | 355, 257 | 357, 258 | 360, 259 | 361, 260 | 362, 261 | 363, 262 | 364, 263 | 365, 264 | 366, 265 | 367, 266 | 371, 267 | 372, 268 | 374, 269 | 375, 270 | 377, 271 | 378, 272 | 379, 273 | 380, 274 | 381, 275 | 383, 276 | 385, 277 | 386, 278 | 387, 279 | 390, 280 | 391, 281 | 392, 282 | 393, 283 | 394, 284 | 395, 285 | 399, 286 | 400, 287 | 401, 288 | 405, 289 | 406, 290 | 407, 291 | 408, 292 | 409, 293 | 410, 294 | 411, 295 | 412, 296 | 413, 297 | 414, 298 | 415, 299 | 416, 300 | 417, 301 | 418, 302 | 419, 303 | 421, 304 | 424, 305 | 426, 306 | 427, 307 | 428, 308 | 429, 309 | 430, 310 | 432, 311 | 434, 312 | 435, 313 | 436, 314 | 437, 315 | 438, 316 | 439, 317 | 440, 318 | 442, 319 | 443, 320 | 445, 321 | 446, 322 | 447, 323 | 450, 324 | 452, 325 | 454, 326 | 456, 327 | 457, 328 | 459, 329 | 460, 330 | 462, 331 | 465, 332 | 466, 333 | 468, 334 | 474, 335 | 475, 336 | 476, 337 | 478, 338 | 479, 339 | 480, 340 | 481, 341 | 482, 342 | 483, 343 | 484, 344 | 485, 345 | 487, 346 | 489, 347 | 493, 348 | 496, 349 | 497, 350 | 498, 351 | 499, 352 | 500, 353 | 501, 354 | 502, 355 | 503, 356 | 504, 357 | 505, 358 | 507 359 | ], 360 | "test": [ 361 | 0, 362 | 9, 363 | 23, 364 | 25, 365 | 28, 366 | 31, 367 | 32, 368 | 36, 369 | 38, 370 | 39, 371 | 41, 372 | 42, 373 | 43, 374 | 47, 375 | 48, 376 | 50, 377 | 53, 378 | 57, 379 | 58, 380 | 62, 381 | 69, 382 | 70, 383 | 72, 384 | 82, 385 | 84, 386 | 86, 387 | 87, 388 | 88, 389 | 91, 390 | 94, 391 | 95, 392 | 98, 393 | 99, 394 | 105, 395 | 115, 396 | 117, 397 | 119, 398 | 121, 399 | 123, 400 | 127, 401 | 128, 402 | 130, 403 | 131, 404 | 143, 405 | 147, 406 | 148, 407 | 151, 408 | 163, 409 | 165, 410 | 169, 411 | 172, 412 | 174, 413 | 177, 414 | 178, 415 | 180, 416 | 183, 417 | 185, 418 | 187, 419 | 192, 420 | 193, 421 | 195, 422 | 197, 423 | 201, 424 | 202, 425 | 203, 426 | 207, 427 | 209, 428 | 211, 429 | 222, 430 | 227, 431 | 242, 432 | 243, 433 | 244, 434 | 251, 435 | 256, 436 | 257, 437 | 262, 438 | 265, 439 | 266, 440 | 267, 441 | 270, 442 | 273, 443 | 277, 444 | 279, 445 | 285, 446 | 288, 447 | 290, 448 | 291, 449 | 292, 450 | 304, 451 | 305, 452 | 314, 453 | 321, 454 | 323, 455 | 324, 456 | 333, 457 | 334, 458 | 335, 459 | 337, 460 | 338, 461 | 341, 462 | 349, 463 | 356, 464 | 358, 465 | 359, 466 | 368, 467 | 369, 468 | 370, 469 | 373, 470 | 376, 471 | 382, 472 | 384, 473 | 388, 474 | 389, 475 | 396, 476 | 397, 477 | 398, 478 | 402, 479 | 403, 480 | 404, 481 | 420, 482 | 422, 483 | 423, 484 | 425, 485 | 431, 486 | 433, 487 | 441, 488 | 444, 489 | 448, 490 | 449, 491 | 451, 492 | 453, 493 | 455, 494 | 458, 495 | 461, 496 | 463, 497 | 464, 498 | 467, 499 | 469, 500 | 470, 501 | 471, 502 | 472, 503 | 473, 504 | 477, 505 | 486, 506 | 488, 507 | 490, 508 | 491, 509 | 492, 510 | 494, 511 | 495, 512 | 506 513 | ] 514 | } -------------------------------------------------------------------------------- /data/split/svamp__default.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | 0, 4 | 1, 5 | 2, 6 | 3, 7 | 4, 8 | 5, 9 | 6, 10 | 7, 11 | 8, 12 | 10, 13 | 12, 14 | 13, 15 | 14, 16 | 15, 17 | 16, 18 | 17, 19 | 18, 20 | 20, 21 
| 21, 22 | 22, 23 | 26, 24 | 27, 25 | 30, 26 | 31, 27 | 34, 28 | 35, 29 | 37, 30 | 38, 31 | 39, 32 | 40, 33 | 45, 34 | 46, 35 | 48, 36 | 49, 37 | 50, 38 | 51, 39 | 52, 40 | 54, 41 | 55, 42 | 56, 43 | 59, 44 | 60, 45 | 62, 46 | 63, 47 | 64, 48 | 65, 49 | 66, 50 | 68, 51 | 69, 52 | 71, 53 | 74, 54 | 75, 55 | 76, 56 | 77, 57 | 78, 58 | 79, 59 | 81, 60 | 83, 61 | 85, 62 | 89, 63 | 90, 64 | 92, 65 | 93, 66 | 96, 67 | 97, 68 | 100, 69 | 101, 70 | 102, 71 | 103, 72 | 104, 73 | 105, 74 | 106, 75 | 107, 76 | 108, 77 | 109, 78 | 112, 79 | 113, 80 | 116, 81 | 118, 82 | 120, 83 | 122, 84 | 124, 85 | 125, 86 | 126, 87 | 127, 88 | 132, 89 | 133, 90 | 134, 91 | 135, 92 | 137, 93 | 140, 94 | 141, 95 | 142, 96 | 144, 97 | 145, 98 | 146, 99 | 150, 100 | 153, 101 | 154, 102 | 155, 103 | 156, 104 | 157, 105 | 158, 106 | 159, 107 | 160, 108 | 161, 109 | 162, 110 | 163, 111 | 164, 112 | 165, 113 | 167, 114 | 170, 115 | 171, 116 | 172, 117 | 173, 118 | 175, 119 | 178, 120 | 179, 121 | 181, 122 | 185, 123 | 186, 124 | 187, 125 | 188, 126 | 190, 127 | 191, 128 | 193, 129 | 194, 130 | 195, 131 | 196, 132 | 198, 133 | 200, 134 | 202, 135 | 204, 136 | 205, 137 | 206, 138 | 208, 139 | 210, 140 | 211, 141 | 213, 142 | 214, 143 | 215, 144 | 218, 145 | 219, 146 | 220, 147 | 221, 148 | 222, 149 | 223, 150 | 224, 151 | 225, 152 | 229, 153 | 230, 154 | 231, 155 | 233, 156 | 235, 157 | 236, 158 | 237, 159 | 238, 160 | 239, 161 | 240, 162 | 241, 163 | 242, 164 | 243, 165 | 245, 166 | 246, 167 | 247, 168 | 249, 169 | 250, 170 | 251, 171 | 252, 172 | 253, 173 | 255, 174 | 258, 175 | 259, 176 | 261, 177 | 262, 178 | 263, 179 | 264, 180 | 266, 181 | 267, 182 | 268, 183 | 270, 184 | 271, 185 | 272, 186 | 276, 187 | 278, 188 | 279, 189 | 281, 190 | 282, 191 | 283, 192 | 284, 193 | 285, 194 | 293, 195 | 294, 196 | 295, 197 | 298, 198 | 299, 199 | 300, 200 | 301, 201 | 302, 202 | 303, 203 | 304, 204 | 306, 205 | 308, 206 | 309, 207 | 310, 208 | 311, 209 | 312, 210 | 313, 211 | 315, 212 | 316, 213 | 317, 214 | 318, 215 | 319, 216 | 320, 217 | 322, 218 | 325, 219 | 327, 220 | 329, 221 | 330, 222 | 331, 223 | 332, 224 | 333, 225 | 334, 226 | 336, 227 | 337, 228 | 338, 229 | 339, 230 | 340, 231 | 342, 232 | 343, 233 | 344, 234 | 345, 235 | 346, 236 | 348, 237 | 350, 238 | 351, 239 | 352, 240 | 353, 241 | 354, 242 | 355, 243 | 356, 244 | 357, 245 | 358, 246 | 360, 247 | 361, 248 | 362, 249 | 363, 250 | 364, 251 | 365, 252 | 366, 253 | 367, 254 | 369, 255 | 372, 256 | 374, 257 | 375, 258 | 378, 259 | 379, 260 | 380, 261 | 382, 262 | 384, 263 | 385, 264 | 386, 265 | 389, 266 | 390, 267 | 391, 268 | 392, 269 | 395, 270 | 397, 271 | 399, 272 | 400, 273 | 401, 274 | 402, 275 | 403, 276 | 404, 277 | 405, 278 | 406, 279 | 407, 280 | 408, 281 | 409, 282 | 412, 283 | 413, 284 | 414, 285 | 415, 286 | 416, 287 | 417, 288 | 418, 289 | 419, 290 | 420, 291 | 422, 292 | 424, 293 | 425, 294 | 426, 295 | 427, 296 | 428, 297 | 432, 298 | 433, 299 | 434, 300 | 435, 301 | 436, 302 | 437, 303 | 439, 304 | 440, 305 | 441, 306 | 443, 307 | 447, 308 | 449, 309 | 451, 310 | 452, 311 | 453, 312 | 454, 313 | 455, 314 | 456, 315 | 457, 316 | 458, 317 | 460, 318 | 462, 319 | 463, 320 | 465, 321 | 466, 322 | 467, 323 | 468, 324 | 470, 325 | 471, 326 | 473, 327 | 474, 328 | 475, 329 | 476, 330 | 477, 331 | 478, 332 | 479, 333 | 480, 334 | 481, 335 | 482, 336 | 483, 337 | 484, 338 | 485, 339 | 487, 340 | 489, 341 | 490, 342 | 491, 343 | 492, 344 | 493, 345 | 494, 346 | 495, 347 | 496, 348 | 497, 349 | 498, 350 | 499, 351 | 500, 352 | 501, 353 | 502, 354 | 503, 355 | 
504, 356 | 505, 357 | 506, 358 | 509, 359 | 511, 360 | 513, 361 | 514, 362 | 516, 363 | 517, 364 | 518, 365 | 519, 366 | 520, 367 | 521, 368 | 522, 369 | 523, 370 | 524, 371 | 526, 372 | 527, 373 | 529, 374 | 530, 375 | 531, 376 | 532, 377 | 533, 378 | 534, 379 | 535, 380 | 536, 381 | 538, 382 | 540, 383 | 542, 384 | 545, 385 | 546, 386 | 548, 387 | 549, 388 | 553, 389 | 556, 390 | 557, 391 | 558, 392 | 561, 393 | 563, 394 | 564, 395 | 566, 396 | 567, 397 | 568, 398 | 569, 399 | 570, 400 | 571, 401 | 572, 402 | 573, 403 | 575, 404 | 576, 405 | 577, 406 | 578, 407 | 580, 408 | 582, 409 | 583, 410 | 584, 411 | 585, 412 | 586, 413 | 587, 414 | 588, 415 | 590, 416 | 592, 417 | 593, 418 | 596, 419 | 597, 420 | 601, 421 | 602, 422 | 603, 423 | 604, 424 | 605, 425 | 608, 426 | 609, 427 | 612, 428 | 613, 429 | 614, 430 | 615, 431 | 618, 432 | 619, 433 | 620, 434 | 622, 435 | 624, 436 | 625, 437 | 626, 438 | 627, 439 | 628, 440 | 630, 441 | 631, 442 | 632, 443 | 634, 444 | 635, 445 | 636, 446 | 638, 447 | 640, 448 | 641, 449 | 642, 450 | 643, 451 | 644, 452 | 646, 453 | 647, 454 | 648, 455 | 649, 456 | 650, 457 | 651, 458 | 652, 459 | 653, 460 | 654, 461 | 655, 462 | 656, 463 | 657, 464 | 658, 465 | 661, 466 | 662, 467 | 664, 468 | 665, 469 | 666, 470 | 667, 471 | 669, 472 | 670, 473 | 671, 474 | 672, 475 | 674, 476 | 676, 477 | 678, 478 | 679, 479 | 680, 480 | 681, 481 | 682, 482 | 683, 483 | 685, 484 | 686, 485 | 687, 486 | 688, 487 | 689, 488 | 691, 489 | 692, 490 | 693, 491 | 695, 492 | 698, 493 | 700, 494 | 702, 495 | 703, 496 | 704, 497 | 706, 498 | 708, 499 | 710, 500 | 711, 501 | 712, 502 | 713, 503 | 715, 504 | 717, 505 | 718, 506 | 720, 507 | 721, 508 | 722, 509 | 724, 510 | 725, 511 | 728, 512 | 729, 513 | 731, 514 | 732, 515 | 733, 516 | 736, 517 | 737, 518 | 738, 519 | 740, 520 | 741, 521 | 742, 522 | 743, 523 | 744, 524 | 745, 525 | 746, 526 | 748, 527 | 750, 528 | 751, 529 | 752, 530 | 753, 531 | 757, 532 | 758, 533 | 759, 534 | 760, 535 | 761, 536 | 762, 537 | 765, 538 | 766, 539 | 767, 540 | 768, 541 | 769, 542 | 771, 543 | 773, 544 | 775, 545 | 776, 546 | 780, 547 | 781, 548 | 783, 549 | 784, 550 | 785, 551 | 786, 552 | 787, 553 | 788, 554 | 789, 555 | 790, 556 | 792, 557 | 793, 558 | 794, 559 | 795, 560 | 796, 561 | 798, 562 | 799, 563 | 801, 564 | 803, 565 | 805, 566 | 807, 567 | 808, 568 | 809, 569 | 811, 570 | 812, 571 | 813, 572 | 814, 573 | 815, 574 | 817, 575 | 818, 576 | 819, 577 | 820, 578 | 821, 579 | 822, 580 | 823, 581 | 825, 582 | 826, 583 | 827, 584 | 829, 585 | 830, 586 | 831, 587 | 832, 588 | 834, 589 | 836, 590 | 837, 591 | 843, 592 | 844, 593 | 848, 594 | 851, 595 | 852, 596 | 855, 597 | 856, 598 | 857, 599 | 858, 600 | 859, 601 | 860, 602 | 862, 603 | 863, 604 | 864, 605 | 865, 606 | 867, 607 | 873, 608 | 874, 609 | 875, 610 | 876, 611 | 877, 612 | 878, 613 | 879, 614 | 880, 615 | 883, 616 | 884, 617 | 885, 618 | 886, 619 | 887, 620 | 890, 621 | 891, 622 | 892, 623 | 893, 624 | 894, 625 | 895, 626 | 898, 627 | 899, 628 | 900, 629 | 902, 630 | 903, 631 | 904, 632 | 905, 633 | 906, 634 | 907, 635 | 908, 636 | 909, 637 | 911, 638 | 912, 639 | 913, 640 | 914, 641 | 915, 642 | 918, 643 | 919, 644 | 920, 645 | 921, 646 | 922, 647 | 923, 648 | 924, 649 | 925, 650 | 926, 651 | 927, 652 | 928, 653 | 929, 654 | 931, 655 | 933, 656 | 934, 657 | 938, 658 | 939, 659 | 940, 660 | 942, 661 | 944, 662 | 945, 663 | 946, 664 | 948, 665 | 950, 666 | 951, 667 | 952, 668 | 955, 669 | 956, 670 | 957, 671 | 958, 672 | 959, 673 | 960, 674 | 963, 675 | 964, 676 | 967, 677 | 968, 678 | 
969, 679 | 970, 680 | 971, 681 | 973, 682 | 975, 683 | 977, 684 | 978, 685 | 979, 686 | 981, 687 | 982, 688 | 984, 689 | 985, 690 | 986, 691 | 987, 692 | 988, 693 | 989, 694 | 990, 695 | 991, 696 | 992, 697 | 993, 698 | 995, 699 | 996, 700 | 997, 701 | 998, 702 | 999 703 | ], 704 | "test": [ 705 | 9, 706 | 11, 707 | 19, 708 | 23, 709 | 24, 710 | 25, 711 | 28, 712 | 29, 713 | 32, 714 | 33, 715 | 36, 716 | 41, 717 | 42, 718 | 43, 719 | 44, 720 | 47, 721 | 53, 722 | 57, 723 | 58, 724 | 61, 725 | 67, 726 | 70, 727 | 72, 728 | 73, 729 | 80, 730 | 82, 731 | 84, 732 | 86, 733 | 87, 734 | 88, 735 | 91, 736 | 94, 737 | 95, 738 | 98, 739 | 99, 740 | 110, 741 | 111, 742 | 114, 743 | 115, 744 | 117, 745 | 119, 746 | 121, 747 | 123, 748 | 128, 749 | 129, 750 | 130, 751 | 131, 752 | 136, 753 | 138, 754 | 139, 755 | 143, 756 | 147, 757 | 148, 758 | 149, 759 | 151, 760 | 152, 761 | 166, 762 | 168, 763 | 169, 764 | 174, 765 | 176, 766 | 177, 767 | 180, 768 | 182, 769 | 183, 770 | 184, 771 | 189, 772 | 192, 773 | 197, 774 | 199, 775 | 201, 776 | 203, 777 | 207, 778 | 209, 779 | 212, 780 | 216, 781 | 217, 782 | 226, 783 | 227, 784 | 228, 785 | 232, 786 | 234, 787 | 244, 788 | 248, 789 | 254, 790 | 256, 791 | 257, 792 | 260, 793 | 265, 794 | 269, 795 | 273, 796 | 274, 797 | 275, 798 | 277, 799 | 280, 800 | 286, 801 | 287, 802 | 288, 803 | 289, 804 | 290, 805 | 291, 806 | 292, 807 | 296, 808 | 297, 809 | 305, 810 | 307, 811 | 314, 812 | 321, 813 | 323, 814 | 324, 815 | 326, 816 | 328, 817 | 335, 818 | 341, 819 | 347, 820 | 349, 821 | 359, 822 | 368, 823 | 370, 824 | 371, 825 | 373, 826 | 376, 827 | 377, 828 | 381, 829 | 383, 830 | 387, 831 | 388, 832 | 393, 833 | 394, 834 | 396, 835 | 398, 836 | 410, 837 | 411, 838 | 421, 839 | 423, 840 | 429, 841 | 430, 842 | 431, 843 | 438, 844 | 442, 845 | 444, 846 | 445, 847 | 446, 848 | 448, 849 | 450, 850 | 459, 851 | 461, 852 | 464, 853 | 469, 854 | 472, 855 | 486, 856 | 488, 857 | 507, 858 | 508, 859 | 510, 860 | 512, 861 | 515, 862 | 525, 863 | 528, 864 | 537, 865 | 539, 866 | 541, 867 | 543, 868 | 544, 869 | 547, 870 | 550, 871 | 551, 872 | 552, 873 | 554, 874 | 555, 875 | 559, 876 | 560, 877 | 562, 878 | 565, 879 | 574, 880 | 579, 881 | 581, 882 | 589, 883 | 591, 884 | 594, 885 | 595, 886 | 598, 887 | 599, 888 | 600, 889 | 606, 890 | 607, 891 | 610, 892 | 611, 893 | 616, 894 | 617, 895 | 621, 896 | 623, 897 | 629, 898 | 633, 899 | 637, 900 | 639, 901 | 645, 902 | 659, 903 | 660, 904 | 663, 905 | 668, 906 | 673, 907 | 675, 908 | 677, 909 | 684, 910 | 690, 911 | 694, 912 | 696, 913 | 697, 914 | 699, 915 | 701, 916 | 705, 917 | 707, 918 | 709, 919 | 714, 920 | 716, 921 | 719, 922 | 723, 923 | 726, 924 | 727, 925 | 730, 926 | 734, 927 | 735, 928 | 739, 929 | 747, 930 | 749, 931 | 754, 932 | 755, 933 | 756, 934 | 763, 935 | 764, 936 | 770, 937 | 772, 938 | 774, 939 | 777, 940 | 778, 941 | 779, 942 | 782, 943 | 791, 944 | 797, 945 | 800, 946 | 802, 947 | 804, 948 | 806, 949 | 810, 950 | 816, 951 | 824, 952 | 828, 953 | 833, 954 | 835, 955 | 838, 956 | 839, 957 | 840, 958 | 841, 959 | 842, 960 | 845, 961 | 846, 962 | 847, 963 | 849, 964 | 850, 965 | 853, 966 | 854, 967 | 861, 968 | 866, 969 | 868, 970 | 869, 971 | 870, 972 | 871, 973 | 872, 974 | 881, 975 | 882, 976 | 888, 977 | 889, 978 | 896, 979 | 897, 980 | 901, 981 | 910, 982 | 916, 983 | 917, 984 | 930, 985 | 932, 986 | 935, 987 | 936, 988 | 937, 989 | 941, 990 | 943, 991 | 947, 992 | 949, 993 | 953, 994 | 954, 995 | 961, 996 | 962, 997 | 965, 998 | 966, 999 | 972, 1000 | 974, 1001 | 976, 1002 | 980, 1003 | 983, 1004 
| 994 1005 | ] 1006 | } -------------------------------------------------------------------------------- /data/split/tracking_shuffled_objects__default.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | 0, 4 | 1, 5 | 2, 6 | 3, 7 | 4, 8 | 5, 9 | 6, 10 | 7, 11 | 8, 12 | 10, 13 | 12, 14 | 13, 15 | 14, 16 | 15, 17 | 16, 18 | 17, 19 | 18, 20 | 20, 21 | 21, 22 | 22, 23 | 26, 24 | 27, 25 | 29, 26 | 30, 27 | 31, 28 | 34, 29 | 35, 30 | 37, 31 | 38, 32 | 39, 33 | 40, 34 | 45, 35 | 46, 36 | 48, 37 | 49, 38 | 50, 39 | 51, 40 | 52, 41 | 54, 42 | 55, 43 | 56, 44 | 59, 45 | 60, 46 | 62, 47 | 63, 48 | 64, 49 | 65, 50 | 66, 51 | 68, 52 | 69, 53 | 71, 54 | 74, 55 | 75, 56 | 76, 57 | 77, 58 | 78, 59 | 79, 60 | 81, 61 | 83, 62 | 85, 63 | 89, 64 | 90, 65 | 92, 66 | 93, 67 | 96, 68 | 97, 69 | 100, 70 | 101, 71 | 102, 72 | 103, 73 | 104, 74 | 105, 75 | 106, 76 | 107, 77 | 108, 78 | 109, 79 | 112, 80 | 113, 81 | 114, 82 | 116, 83 | 118, 84 | 120, 85 | 122, 86 | 124, 87 | 125, 88 | 126, 89 | 127, 90 | 132, 91 | 133, 92 | 134, 93 | 135, 94 | 137, 95 | 140, 96 | 141, 97 | 142, 98 | 144, 99 | 145, 100 | 146, 101 | 150, 102 | 153, 103 | 154, 104 | 155, 105 | 156, 106 | 157, 107 | 158, 108 | 159, 109 | 160, 110 | 161, 111 | 162, 112 | 163, 113 | 164, 114 | 165, 115 | 167, 116 | 170, 117 | 171, 118 | 172, 119 | 173, 120 | 175, 121 | 176, 122 | 178, 123 | 179, 124 | 181, 125 | 185, 126 | 186, 127 | 187, 128 | 188, 129 | 190, 130 | 191, 131 | 193, 132 | 194, 133 | 195, 134 | 196, 135 | 198, 136 | 200, 137 | 202, 138 | 204, 139 | 205, 140 | 206, 141 | 208, 142 | 210, 143 | 211, 144 | 213, 145 | 214, 146 | 215, 147 | 218, 148 | 219, 149 | 220, 150 | 221, 151 | 222, 152 | 223, 153 | 224, 154 | 225, 155 | 229, 156 | 230, 157 | 231, 158 | 233, 159 | 235, 160 | 236, 161 | 237, 162 | 238, 163 | 239, 164 | 240, 165 | 241, 166 | 242, 167 | 243, 168 | 245, 169 | 246, 170 | 247, 171 | 249, 172 | 250, 173 | 251, 174 | 252, 175 | 253, 176 | 255, 177 | 258, 178 | 259, 179 | 261, 180 | 262, 181 | 263, 182 | 264, 183 | 266, 184 | 267, 185 | 268, 186 | 270, 187 | 271, 188 | 272, 189 | 276, 190 | 278, 191 | 279, 192 | 281, 193 | 282, 194 | 283, 195 | 284, 196 | 285, 197 | 289, 198 | 293, 199 | 294, 200 | 295, 201 | 298, 202 | 299, 203 | 300, 204 | 301, 205 | 302, 206 | 303, 207 | 304, 208 | 306, 209 | 308, 210 | 309, 211 | 310, 212 | 311, 213 | 312, 214 | 313, 215 | 315, 216 | 316, 217 | 317, 218 | 318, 219 | 319, 220 | 320, 221 | 322, 222 | 325, 223 | 327, 224 | 329, 225 | 330, 226 | 331, 227 | 332, 228 | 333, 229 | 334, 230 | 336, 231 | 337, 232 | 338, 233 | 339, 234 | 340, 235 | 342, 236 | 343, 237 | 344, 238 | 345, 239 | 346, 240 | 348, 241 | 350, 242 | 351, 243 | 352, 244 | 353, 245 | 354, 246 | 355, 247 | 356, 248 | 357, 249 | 358, 250 | 360, 251 | 361, 252 | 362, 253 | 363, 254 | 364, 255 | 365, 256 | 366, 257 | 367, 258 | 369, 259 | 372, 260 | 374, 261 | 375, 262 | 378, 263 | 379, 264 | 380, 265 | 382, 266 | 384, 267 | 385, 268 | 386, 269 | 389, 270 | 390, 271 | 391, 272 | 392, 273 | 395, 274 | 397, 275 | 399, 276 | 400, 277 | 401, 278 | 402, 279 | 403, 280 | 404, 281 | 405, 282 | 406, 283 | 407, 284 | 408, 285 | 409, 286 | 412, 287 | 413, 288 | 414, 289 | 415, 290 | 416, 291 | 417, 292 | 418, 293 | 419, 294 | 420, 295 | 422, 296 | 424, 297 | 425, 298 | 426, 299 | 427, 300 | 428, 301 | 432, 302 | 433, 303 | 434, 304 | 435, 305 | 436, 306 | 437, 307 | 439, 308 | 440, 309 | 441, 310 | 443, 311 | 447, 312 | 449, 313 | 451, 314 | 452, 315 | 453, 316 | 454, 317 | 455, 318 | 
456, 319 | 457, 320 | 458, 321 | 460, 322 | 462, 323 | 463, 324 | 465, 325 | 466, 326 | 467, 327 | 468, 328 | 470, 329 | 471, 330 | 473, 331 | 474, 332 | 475, 333 | 476, 334 | 477, 335 | 478, 336 | 479, 337 | 480, 338 | 481, 339 | 482, 340 | 483, 341 | 484, 342 | 485, 343 | 487, 344 | 489, 345 | 490, 346 | 491, 347 | 492, 348 | 493, 349 | 494, 350 | 495, 351 | 496, 352 | 497, 353 | 498, 354 | 499, 355 | 500, 356 | 501, 357 | 502, 358 | 503, 359 | 504, 360 | 505, 361 | 506, 362 | 508, 363 | 509, 364 | 511, 365 | 513, 366 | 514, 367 | 516, 368 | 517, 369 | 518, 370 | 519, 371 | 520, 372 | 521, 373 | 522, 374 | 523, 375 | 524, 376 | 526, 377 | 527, 378 | 529, 379 | 530, 380 | 531, 381 | 532, 382 | 533, 383 | 534, 384 | 535, 385 | 536, 386 | 538, 387 | 540, 388 | 542, 389 | 545, 390 | 546, 391 | 548, 392 | 549, 393 | 553, 394 | 556, 395 | 557, 396 | 558, 397 | 563, 398 | 564, 399 | 566, 400 | 567, 401 | 568, 402 | 570, 403 | 571, 404 | 572, 405 | 573, 406 | 576, 407 | 577, 408 | 578, 409 | 579, 410 | 580, 411 | 582, 412 | 583, 413 | 584, 414 | 585, 415 | 586, 416 | 587, 417 | 588, 418 | 590, 419 | 591, 420 | 592, 421 | 593, 422 | 595, 423 | 596, 424 | 597, 425 | 602, 426 | 603, 427 | 604, 428 | 605, 429 | 608, 430 | 609, 431 | 612, 432 | 613, 433 | 614, 434 | 615, 435 | 616, 436 | 618, 437 | 619, 438 | 621, 439 | 622, 440 | 623, 441 | 624, 442 | 626, 443 | 627, 444 | 628, 445 | 634, 446 | 636, 447 | 638, 448 | 642, 449 | 643, 450 | 644, 451 | 645, 452 | 646, 453 | 647, 454 | 648, 455 | 649, 456 | 651, 457 | 652, 458 | 654, 459 | 657, 460 | 658, 461 | 661, 462 | 663, 463 | 664, 464 | 665, 465 | 666, 466 | 667, 467 | 668, 468 | 669, 469 | 670, 470 | 671, 471 | 672, 472 | 674, 473 | 676, 474 | 678, 475 | 679, 476 | 680, 477 | 681, 478 | 682, 479 | 683, 480 | 685, 481 | 686, 482 | 687, 483 | 689, 484 | 690, 485 | 691, 486 | 692, 487 | 693, 488 | 694, 489 | 695, 490 | 696, 491 | 698, 492 | 700, 493 | 701, 494 | 704, 495 | 708, 496 | 710, 497 | 712, 498 | 714, 499 | 715, 500 | 716, 501 | 718, 502 | 719, 503 | 720, 504 | 721, 505 | 724, 506 | 725, 507 | 726, 508 | 727, 509 | 728, 510 | 729, 511 | 731, 512 | 732, 513 | 733, 514 | 734, 515 | 735, 516 | 736, 517 | 738, 518 | 739, 519 | 740, 520 | 741, 521 | 742, 522 | 743, 523 | 744, 524 | 745, 525 | 746, 526 | 748, 527 | 749 528 | ], 529 | "test": [ 530 | 9, 531 | 11, 532 | 19, 533 | 23, 534 | 24, 535 | 25, 536 | 28, 537 | 32, 538 | 33, 539 | 36, 540 | 41, 541 | 42, 542 | 43, 543 | 44, 544 | 47, 545 | 53, 546 | 57, 547 | 58, 548 | 61, 549 | 67, 550 | 70, 551 | 72, 552 | 73, 553 | 80, 554 | 82, 555 | 84, 556 | 86, 557 | 87, 558 | 88, 559 | 91, 560 | 94, 561 | 95, 562 | 98, 563 | 99, 564 | 110, 565 | 111, 566 | 115, 567 | 117, 568 | 119, 569 | 121, 570 | 123, 571 | 128, 572 | 129, 573 | 130, 574 | 131, 575 | 136, 576 | 138, 577 | 139, 578 | 143, 579 | 147, 580 | 148, 581 | 149, 582 | 151, 583 | 152, 584 | 166, 585 | 168, 586 | 169, 587 | 174, 588 | 177, 589 | 180, 590 | 182, 591 | 183, 592 | 184, 593 | 189, 594 | 192, 595 | 197, 596 | 199, 597 | 201, 598 | 203, 599 | 207, 600 | 209, 601 | 212, 602 | 216, 603 | 217, 604 | 226, 605 | 227, 606 | 228, 607 | 232, 608 | 234, 609 | 244, 610 | 248, 611 | 254, 612 | 256, 613 | 257, 614 | 260, 615 | 265, 616 | 269, 617 | 273, 618 | 274, 619 | 275, 620 | 277, 621 | 280, 622 | 286, 623 | 287, 624 | 288, 625 | 290, 626 | 291, 627 | 292, 628 | 296, 629 | 297, 630 | 305, 631 | 307, 632 | 314, 633 | 321, 634 | 323, 635 | 324, 636 | 326, 637 | 328, 638 | 335, 639 | 341, 640 | 347, 641 | 349, 642 | 359, 643 | 368, 644 | 
370, 645 | 371, 646 | 373, 647 | 376, 648 | 377, 649 | 381, 650 | 383, 651 | 387, 652 | 388, 653 | 393, 654 | 394, 655 | 396, 656 | 398, 657 | 410, 658 | 411, 659 | 421, 660 | 423, 661 | 429, 662 | 430, 663 | 431, 664 | 438, 665 | 442, 666 | 444, 667 | 445, 668 | 446, 669 | 448, 670 | 450, 671 | 459, 672 | 461, 673 | 464, 674 | 469, 675 | 472, 676 | 486, 677 | 488, 678 | 507, 679 | 510, 680 | 512, 681 | 515, 682 | 525, 683 | 528, 684 | 537, 685 | 539, 686 | 541, 687 | 543, 688 | 544, 689 | 547, 690 | 550, 691 | 551, 692 | 552, 693 | 554, 694 | 555, 695 | 559, 696 | 560, 697 | 561, 698 | 562, 699 | 565, 700 | 569, 701 | 574, 702 | 575, 703 | 581, 704 | 589, 705 | 594, 706 | 598, 707 | 599, 708 | 600, 709 | 601, 710 | 606, 711 | 607, 712 | 610, 713 | 611, 714 | 617, 715 | 620, 716 | 625, 717 | 629, 718 | 630, 719 | 631, 720 | 632, 721 | 633, 722 | 635, 723 | 637, 724 | 639, 725 | 640, 726 | 641, 727 | 650, 728 | 653, 729 | 655, 730 | 656, 731 | 659, 732 | 660, 733 | 662, 734 | 673, 735 | 675, 736 | 677, 737 | 684, 738 | 688, 739 | 697, 740 | 699, 741 | 702, 742 | 703, 743 | 705, 744 | 706, 745 | 707, 746 | 709, 747 | 711, 748 | 713, 749 | 717, 750 | 722, 751 | 723, 752 | 730, 753 | 737, 754 | 747 755 | ] 756 | } -------------------------------------------------------------------------------- /notebooks/.gitignore: -------------------------------------------------------------------------------- 1 | figures/ -------------------------------------------------------------------------------- /notebooks/example_load_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e91d42f0", 6 | "metadata": {}, 7 | "source": [ 8 | "# OpenAI Experimental Results\n", 9 | "\n", 10 | "Load results for experiments in original paper (make sure to download result files)." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 3, 16 | "id": "af4e76b3", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from data.completion_dataset import CompletionIdentifier, CompletionDataset\n", 21 | "from data.split import load_train_test_split\n", 22 | "from evaluation.evaluator import Evaluator\n", 23 | "from evaluation.summary import summarize_evaluation" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 9, 29 | "id": "5033babe", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "base_model = \"curie\"\n", 34 | "dataset_key = \"multiarith\"\n", 35 | "train_key = \"zs_cot\"\n", 36 | "e = None\n", 37 | "\n", 38 | "ci = CompletionIdentifier(base_model, \"ft_cot\", dataset_key, train_key, e)\n", 39 | "cd = CompletionDataset.load(ci)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 10, 45 | "id": "284e9156", 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "data": { 50 | "text/plain": [ 51 | "{'accuracy': 0.3333333333333333,\n", 52 | " 'contains_answer': 0.3333333333333333,\n", 53 | " 'correct_format': 1.0,\n", 54 | " 'complete': 1.0}" 55 | ] 56 | }, 57 | "execution_count": 10, 58 | "metadata": {}, 59 | "output_type": "execute_result" 60 | } 61 | ], 62 | "source": [ 63 | "train, test = load_train_test_split(dataset_key)\n", 64 | "evaluation = Evaluator.evaluate_completion_dataset(cd, test,)\n", 65 | "summarize_evaluation(evaluation)" 66 | ] 67 | } 68 | ], 69 | "metadata": { 70 | "kernelspec": { 71 | "display_name": "Python 3 (ipykernel)", 72 | "language": "python", 73 | "name": "python3" 74 | }, 75 | "language_info": { 76 | "codemirror_mode": { 77 | "name": "ipython", 78 | "version": 3 79 | }, 80 | "file_extension": ".py", 81 | "mimetype": "text/x-python", 82 | "name": "python", 83 | "nbconvert_exporter": "python", 84 | "pygments_lexer": "ipython3", 85 | "version": "3.8.13" 86 | } 87 | }, 88 | "nbformat": 4, 89 | "nbformat_minor": 5 90 | } 91 | -------------------------------------------------------------------------------- /notebooks/example_oai_finetune_cot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a616c3c1", 6 | "metadata": {}, 7 | "source": [ 8 | "# Run Fine-tune CoT on OpenAI using our `oai` module\n", 9 | "\n", 10 | "This notebook contains code to (1) generate reasoning samples from teacher models (e.g., GPT-3 175B `text-davinci-002`), (2) fine-tune student models (e.g., GPT-3 0.3B `ada`) and (3) generate and evaluate samples from fine-tuned student models.\n", 11 | "\n", 12 | "- To run from scratch, first download and save original benchmark data (see README).\n", 13 | "- To use existing teacher-generated samples, first download and save original benchmark data and teacher completion data (see README). Then, replace the completion_key `zs_cot_test` with `zs_cot` in the code below." 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "id": "a10a6ccc", 19 | "metadata": {}, 20 | "source": [ 21 | "### TODO: Set OpenAI Key\n", 22 | "\n", 23 | "Create an account on OpenAI and retrieve your API key. Experiments will incur fees on your OpenAI account."
24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "id": "2587ed99", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "import openai\n", 34 | "openai.api_key = \"\"" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "id": "86fa4cf8", 40 | "metadata": {}, 41 | "source": [ 42 | "### Imports and Parameters" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "id": "4a11f958", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "from data.completion_dataset import CompletionMetadata, CompletionDataset\n", 53 | "from oai.inference import infer_completion_data" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "id": "be7d870e", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "teacher_base_model = \"text-davinci-002\" # GPT-3 (175B)\n", 64 | "base_model = \"ada\" # GPT-3 (0.3B)\n", 65 | "# base_model = \"babbage\" # GPT-3 (1.3B)\n", 66 | "# base_model = \"curie\" # GPT-3 (6.7B)\n", 67 | "dataset_key = \"date_understanding\"" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "5043037d", 73 | "metadata": {}, 74 | "source": [ 75 | "## Infer teacher completions using OpenAI (generate CompletionDataset)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "id": "3fa53308", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# Note, completion_key identifies the method used to generate completions\n", 86 | "# Note, prediction_template selects the prediction template from those pre-defined in\n", 87 | "# `oai.data.format.Formatter`.\n", 88 | "completion_metadata = CompletionMetadata(base_model=teacher_base_model, completion_key=\"zs_cot_test\",\n", 89 | " dataset_key=dataset_key, prediction_template=\"zs_cot\")" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 5, 95 | "id": "16d1d824", 96 | "metadata": { 97 | "scrolled": true 98 | }, 99 | "outputs": [ 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "Loaded 369 samples from:\n", 105 | "/Users/itsnamgyu/code/temp/reasoning-teacher/saved/completion_data/B_text-davinci-002__C_zs_cot_test/D_date_understanding.json\n", 106 | "All 369 samples have been completed.\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "# Run Zero-shot-CoT step 1 (rationale generation)\n", 112 | "# Note, sample_indices=None means we want to infer on all samples\n", 113 | "completion_dataset = infer_completion_data(completion_metadata, zs_cot_step=1,\n", 114 | " sample_indices=None, augs=1, temperature=0,\n", 115 | " max_tokens=128)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 6, 121 | "id": "042067fa", 122 | "metadata": { 123 | "scrolled": true 124 | }, 125 | "outputs": [ 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | "Loaded 369 samples from:\n", 131 | "/Users/itsnamgyu/code/temp/reasoning-teacher/saved/completion_data/B_text-davinci-002__C_zs_cot_test/D_date_understanding.json\n", 132 | "All 369 samples have been completed.\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "# Run Zero-shot-CoT step 2 (answer)\n", 138 | "completion_dataset = infer_completion_data(completion_metadata, zs_cot_step=2,\n", 139 | " sample_indices=None, augs=1, temperature=0,\n", 140 | " max_tokens=128)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "id": "6b19e754", 146 | "metadata": {}, 147 | "source": [ 148 | "## Load CompletionDataset and evaluate test set" 149 | ] 150 
| }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 7, 154 | "id": "995c950d", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "from data.completion_dataset import CompletionIdentifier\n", 159 | "from data.split import load_train_test_split \n", 160 | "from evaluation.evaluator import Evaluator\n", 161 | "from evaluation.summary import summarize_evaluation " 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 8, 167 | "id": "5cf87485", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "completion_identifier = CompletionIdentifier(teacher_base_model, \"zs_cot_test\", dataset_key)\n", 172 | "completion_dataset = CompletionDataset.load(completion_identifier)\n", 173 | "# Note, completion_metadata can be used instead of completion_identifier such as below\n", 174 | "# completion_dataset = CompletionDataset.load(completion_metadata)\n", 175 | "train, test = load_train_test_split(dataset_key)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 9, 181 | "id": "c80ce146", 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "evaluator = Evaluator.for_completion_dataset(completion_dataset)\n", 186 | "evaluation = evaluator.evaluate_completion_dataset(completion_dataset, test)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 10, 192 | "id": "c36c1760", 193 | "metadata": { 194 | "scrolled": true 195 | }, 196 | "outputs": [ 197 | { 198 | "data": { 199 | "text/html": [ 200 | "
\n", 201 | "\n", 214 | "\n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | "
sample_indexcompletion_indexcorrectcontains_answercorrect_formatcomplete
000TrueTrueTrueTrue
190TrueTrueTrueTrue
2230TrueTrueTrueTrue
3250TrueTrueTrueTrue
4280TrueTrueTrueTrue
\n", 274 | "
" 275 | ], 276 | "text/plain": [ 277 | " sample_index completion_index correct contains_answer correct_format \n", 278 | "0 0 0 True True True \\\n", 279 | "1 9 0 True True True \n", 280 | "2 23 0 True True True \n", 281 | "3 25 0 True True True \n", 282 | "4 28 0 True True True \n", 283 | "\n", 284 | " complete \n", 285 | "0 True \n", 286 | "1 True \n", 287 | "2 True \n", 288 | "3 True \n", 289 | "4 True " 290 | ] 291 | }, 292 | "execution_count": 10, 293 | "metadata": {}, 294 | "output_type": "execute_result" 295 | } 296 | ], 297 | "source": [ 298 | "evaluation.head()" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 11, 304 | "id": "f0d58e46", 305 | "metadata": { 306 | "scrolled": false 307 | }, 308 | "outputs": [ 309 | { 310 | "data": { 311 | "text/plain": [ 312 | "{'accuracy': 0.7477477477477478,\n", 313 | " 'contains_answer': 0.7477477477477478,\n", 314 | " 'correct_format': 1.0,\n", 315 | " 'complete': 1.0}" 316 | ] 317 | }, 318 | "execution_count": 11, 319 | "metadata": {}, 320 | "output_type": "execute_result" 321 | } 322 | ], 323 | "source": [ 324 | "summarize_evaluation(evaluation)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "id": "24187a50", 330 | "metadata": {}, 331 | "source": [ 332 | "## Create fine-tune `File` and `Finetune` using training set" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 12, 338 | "id": "bd27aebf", 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "from oai.finetune import init_finetune, generate_finetune_data_from_completion_dataset\n", 343 | "from oai.utils.api_wrapper import fetch_model_ids" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 13, 349 | "id": "e4edefd6", 350 | "metadata": { 351 | "scrolled": true 352 | }, 353 | "outputs": [], 354 | "source": [ 355 | "# Replace \"zs_cot_test\" with \"zs_cot\" to use our teacher-generated completions (see README for how to download).\n", 356 | "completion_identifier = CompletionIdentifier(teacher_base_model, \"zs_cot_test\", dataset_key)\n", 357 | "completion_dataset = CompletionDataset.load(completion_identifier)\n", 358 | "train, test = load_train_test_split(dataset_key)" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 14, 364 | "id": "ee88849d", 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "finetune_key = \"zs_cot_test_{}\".format(dataset_key)\n", 369 | "train_key = \"ft_cot_test\"" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 15, 375 | "id": "55825a80", 376 | "metadata": { 377 | "scrolled": false 378 | }, 379 | "outputs": [ 380 | { 381 | "name": "stdout", 382 | "output_type": "stream", 383 | "text": [ 384 | "Saving 171 fine-tuning samples to /Users/itsnamgyu/code/temp/reasoning-teacher/saved/finetune_data/P_openai/F_zs_cot_test_date_understanding.jsonl\n" 385 | ] 386 | } 387 | ], 388 | "source": [ 389 | "# Note, finetune_key is a unique identifier for the finetuning data and should contain the source dataset\n", 390 | "generate_finetune_data_from_completion_dataset(completion_dataset=completion_dataset,\n", 391 | " prediction_template=\"ft_cot_token\",\n", 392 | " finetune_key=finetune_key,\n", 393 | " sample_indices=train,\n", 394 | " only_correct=True, # default\n", 395 | " )" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 16, 401 | "id": "c98c87e3", 402 | "metadata": { 403 | "scrolled": true 404 | }, 405 | "outputs": [ 406 | { 407 | "name": "stdout", 408 | 
"output_type": "stream", 409 | "text": [ 410 | "{\n", 411 | " \"prompt\": \"Yesterday was April 30, 2021. What is the date one year ago from today in MM/DD/YYYY?\\nWhich choice is true? Answer choices: (A) 05/01/1971, (B) 04/01/2020, (C) 05/15/2020, (D) 05/01/2020, (E) 05/08/2020. ###\",\n", 412 | " \"completion\": \" One year ago from today would be 2020. Today is 2021. 2020 is two years ago. Two years ago from today would be 05/01/2019. --> D END\"\n", 413 | "}\n" 414 | ] 415 | } 416 | ], 417 | "source": [ 418 | "# Inspect finetune data\n", 419 | "import json\n", 420 | "from paths import get_finetune_data_path\n", 421 | "with open(get_finetune_data_path(\"openai\", finetune_key)) as f:\n", 422 | " print(json.dumps(json.loads(f.readline()), indent=4))" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 17, 428 | "id": "03bf8d56", 429 | "metadata": { 430 | "scrolled": false 431 | }, 432 | "outputs": [ 433 | { 434 | "name": "stdout", 435 | "output_type": "stream", 436 | "text": [ 437 | "Warning: OpenAI File `zs_cot_test_date_understanding` already exists (likely already uploaded). Skipping.\n", 438 | "Warning: OpenAI Finetune for `B_ada__D_date_understanding__T_ft_cot_test` already exists. Skipping.\n" 439 | ] 440 | }, 441 | { 442 | "data": { 443 | "text/plain": [ 444 | "'B_ada__D_date_understanding__T_ft_cot_test'" 445 | ] 446 | }, 447 | "execution_count": 17, 448 | "metadata": {}, 449 | "output_type": "execute_result" 450 | } 451 | ], 452 | "source": [ 453 | "# Note, train_key identifies the method used to train the model, i.e., the method used to fine-tune the base model.\n", 454 | "init_finetune(finetune_key, base_model, dataset_key, train_key)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "markdown", 459 | "id": "ec628615", 460 | "metadata": {}, 461 | "source": [ 462 | "### Fetch fine-tuned `Model` id\n", 463 | "\n", 464 | "You need to keep calling this function to check if your `Finetune` is finished. Fine-tuning typically take about 5 minutes to 1 hour." 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 19, 470 | "id": "d83504c2", 471 | "metadata": { 472 | "scrolled": false 473 | }, 474 | "outputs": [ 475 | { 476 | "name": "stdout", 477 | "output_type": "stream", 478 | "text": [ 479 | "No model ids to fetch\n" 480 | ] 481 | }, 482 | { 483 | "data": { 484 | "text/plain": [ 485 | "True" 486 | ] 487 | }, 488 | "execution_count": 19, 489 | "metadata": {}, 490 | "output_type": "execute_result" 491 | } 492 | ], 493 | "source": [ 494 | "fetch_model_ids()" 495 | ] 496 | }, 497 | { 498 | "cell_type": "markdown", 499 | "id": "ea2ef518", 500 | "metadata": {}, 501 | "source": [ 502 | "### Access OpenAI metadata\n", 503 | "\n", 504 | "We use metadata files to map our identifiers (keys) to the identifier (ids) used by OpenAI objects.\n", 505 | "These can be accessed manually, as follows." 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": 20, 511 | "id": "01d9f0b2", 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "from oai.utils.metadata import get_file_id, get_finetune_id, get_model_id, get_model_key" 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": 21, 521 | "id": "48975b97", 522 | "metadata": {}, 523 | "outputs": [], 524 | "source": [ 525 | "# Note that `base_model`, `dataset_key`, `train_key` are joined together to form a `model_key` which\n", 526 | "# identifies fine-tuned models. 
There is a one-to-one-to-one mapping between a model_key, Finetune object,\n", 527 | "# and Model object.\n", 528 | "model_key = get_model_key(base_model, dataset_key, train_key)" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": 22, 534 | "id": "e1a46489", 535 | "metadata": {}, 536 | "outputs": [ 537 | { 538 | "data": { 539 | "text/plain": [ 540 | "'file-3lwlV7lJRebTr0JTniuZ7lCX'" 541 | ] 542 | }, 543 | "execution_count": 22, 544 | "metadata": {}, 545 | "output_type": "execute_result" 546 | } 547 | ], 548 | "source": [ 549 | "# Note that our `finetune_key` identifies the fine-tuning \"data\", therefore is mapped to a File object\n", 550 | "# rather than a Finetune object.\n", 551 | "get_file_id(finetune_key)" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": 23, 557 | "id": "95162d8c", 558 | "metadata": {}, 559 | "outputs": [ 560 | { 561 | "data": { 562 | "text/plain": [ 563 | "'ft-ord6Qs8vmXQI8VjVWrfNTrTs'" 564 | ] 565 | }, 566 | "execution_count": 23, 567 | "metadata": {}, 568 | "output_type": "execute_result" 569 | } 570 | ], 571 | "source": [ 572 | "get_finetune_id(model_key)" 573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": 24, 578 | "id": "71570d68", 579 | "metadata": { 580 | "scrolled": true 581 | }, 582 | "outputs": [ 583 | { 584 | "data": { 585 | "text/plain": [ 586 | "'ada:ft-namgyu-ho-2023-06-11-04-37-04'" 587 | ] 588 | }, 589 | "execution_count": 24, 590 | "metadata": {}, 591 | "output_type": "execute_result" 592 | } 593 | ], 594 | "source": [ 595 | "get_model_id(model_key) # fetched by `fetch_model_ids()`" 596 | ] 597 | }, 598 | { 599 | "cell_type": "markdown", 600 | "id": "17c0460e", 601 | "metadata": {}, 602 | "source": [ 603 | "## Infer student completions\n", 604 | "\n", 605 | "We only infer test set samples for evaluation." 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": 25, 611 | "id": "df8e3ede", 612 | "metadata": {}, 613 | "outputs": [], 614 | "source": [ 615 | "# Note, completion_key and train_key are both \"ft_cot_test\". 
Recall that completion_key refers to\n", 616 | "# the method used to generate completions by the student model, and train_key refers to the method\n", 617 | "# used to train the student model.\n", 618 | "completion_metadata = CompletionMetadata(base_model=base_model, completion_key=\"ft_cot_test\",\n", 619 | " dataset_key=dataset_key, finetune_key=finetune_key,\n", 620 | " prediction_template=\"ft_cot_token\",\n", 621 | " train_key=train_key, epoch=None)\n", 622 | "train, test = load_train_test_split(dataset_key)" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 26, 628 | "id": "aa1a4a0e", 629 | "metadata": { 630 | "scrolled": false 631 | }, 632 | "outputs": [ 633 | { 634 | "name": "stdout", 635 | "output_type": "stream", 636 | "text": [ 637 | "Initializing new CompletionDataset at:\n", 638 | "/Users/itsnamgyu/code/temp/reasoning-teacher/saved/completion_data/B_ada__C_ft_cot_test/D_date_understanding__T_ft_cot_test.json\n", 639 | "Inferring completions for 111 remaining samples (total=111)\n" 640 | ] 641 | }, 642 | { 643 | "name": "stderr", 644 | "output_type": "stream", 645 | "text": [ 646 | "Inferring completions via OpenAI: 100%|███████████████████████████| 111/111 [00:12<00:00, 8.72it/s]\n" 647 | ] 648 | } 649 | ], 650 | "source": [ 651 | "# Note, `infer_completion_data` will find our new student model (that we fetched above) by using\n", 652 | "# `base_model`, `dataset_key`, and `train_key` which is specified in `completion_metadata`.\n", 653 | "completion_dataset = infer_completion_data(completion_metadata, zs_cot_step=None,\n", 654 | " sample_indices=test, augs=1, temperature=0,\n", 655 | " max_tokens=1024) # note, we use 1024 tokens for student inference" 656 | ] 657 | }, 658 | { 659 | "cell_type": "markdown", 660 | "id": "75c623f5", 661 | "metadata": {}, 662 | "source": [ 663 | "## Evaluate student completions" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": 27, 669 | "id": "afc851ee", 670 | "metadata": {}, 671 | "outputs": [], 672 | "source": [ 673 | "completion_identifier = CompletionIdentifier(base_model, completion_key=\"ft_cot_test\", dataset_key=dataset_key,\n", 674 | " train_key=\"ft_cot_test\")\n", 675 | "completion_dataset = CompletionDataset.load(completion_identifier)\n", 676 | "train, test = load_train_test_split(dataset_key)" 677 | ] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "execution_count": 28, 682 | "id": "21eac841", 683 | "metadata": { 684 | "scrolled": true 685 | }, 686 | "outputs": [], 687 | "source": [ 688 | "evaluator = Evaluator(dataset_key, \"ft_cot_token\")\n", 689 | "evaluation = evaluator.evaluate_completion_dataset(completion_dataset, test)" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 29, 695 | "id": "a4181fbb", 696 | "metadata": { 697 | "scrolled": false 698 | }, 699 | "outputs": [ 700 | { 701 | "data": { 702 | "text/plain": [ 703 | "{'accuracy': 0.12612612612612611,\n", 704 | " 'contains_answer': 0.12612612612612611,\n", 705 | " 'correct_format': 0.9819819819819819,\n", 706 | " 'complete': 0.9819819819819819}" 707 | ] 708 | }, 709 | "execution_count": 29, 710 | "metadata": {}, 711 | "output_type": "execute_result" 712 | } 713 | ], 714 | "source": [ 715 | "summarize_evaluation(evaluation)" 716 | ] 717 | } 718 | ], 719 | "metadata": { 720 | "kernelspec": { 721 | "display_name": "Python 3 (ipykernel)", 722 | "language": "python", 723 | "name": "python3" 724 | }, 725 | "language_info": { 726 | "codemirror_mode": { 727 | "name": "ipython", 728 | "version": 3 
729 | }, 730 | "file_extension": ".py", 731 | "mimetype": "text/x-python", 732 | "name": "python", 733 | "nbconvert_exporter": "python", 734 | "pygments_lexer": "ipython3", 735 | "version": "3.9.12" 736 | } 737 | }, 738 | "nbformat": 4, 739 | "nbformat_minor": 5 740 | } 741 | -------------------------------------------------------------------------------- /notebooks/old/README.txt: -------------------------------------------------------------------------------- 1 | These notebooks were used to run and evaluate all experiments on our ACL initial submission. 2 | The notebooks reference old code modules, so they will not run as is. However, they can be used as a reference for 3 | experimental details (especially for analysis-related experiments). 4 | 5 | "Data Preparation.ipynb" was used to preprocess the original datasets into the json files that go in `data/dataset/` (refer to main README for download links). -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai 2 | pandas 3 | numpy 4 | matplotlib 5 | easydict 6 | pytorch-lightning~=1.9 7 | transformers 8 | datasets 9 | sentencepiece 10 | deepspeed 11 | scipy 12 | seaborn 13 | -------------------------------------------------------------------------------- /scripts/custom/example_ft5.sh: -------------------------------------------------------------------------------- 1 | TARGETS=("tracking_shuffled_objects" "date_understanding" "coin_flip" "last_letter_concatenation" "commonsense_qa" "strategy_qa" 2 | "single_eq" "addsub" "multiarith" "svamp" "gsm8k" "aqua") 3 | MODELS=("flan_t5_small" "flan_t5_base" "flan_t5_large" "flan_t5_xl") 4 | DEVICES="0" 5 | 6 | 7 | for MODEL in ${MODELS[@]}; do 8 | for TARGET in ${TARGETS[@]}; do 9 | python custom_train.py --dataset_key $TARGET --model_key $MODEL --train_key "ft_cot" --devices $DEVICES --batch_size 8 --inference_batch_size 32 10 | done 11 | done 12 | -------------------------------------------------------------------------------- /scripts/custom/example_gpt2.sh: -------------------------------------------------------------------------------- 1 | TARGETS=("tracking_shuffled_objects" "date_understanding" "coin_flip" "last_letter_concatenation" "commonsense_qa" "strategy_qa" 2 | "single_eq" "addsub" "multiarith" "svamp" "gsm8k" "aqua") 3 | MODELS=("gpt2" "gpt2_medium" "gpt2_large") 4 | DEVICES="0" 5 | 6 | 7 | for MODEL in ${MODELS[@]}; do 8 | for TARGET in ${TARGETS[@]}; do 9 | python custom_train.py --dataset_key $TARGET --model_key $MODEL --train_key "ft_cot" --devices $DEVICES --batch_size 8 --inference_batch_size 32 10 | done 11 | done 12 | -------------------------------------------------------------------------------- /scripts/custom/example_t5.sh: -------------------------------------------------------------------------------- 1 | TARGETS=("tracking_shuffled_objects" "date_understanding" "coin_flip" "last_letter_concatenation" "commonsense_qa" "strategy_qa" 2 | "single_eq" "addsub" "multiarith" "svamp" "gsm8k" "aqua") 3 | MODELS=("t5_small" "t5_base" "t5_large" "t5_3b") 4 | DEVICES="0" 5 | 6 | for MODEL in ${MODELS[@]}; do 7 | for TARGET in ${TARGETS[@]}; do 8 | python custom_train.py --dataset_key $TARGET --model_key $MODEL --train_key "ft_cot" --devices $DEVICES --batch_size 8 --inference_batch_size 32 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='reasoning_teacher', 5 | version='0.1', 6 | packages=[''], 7 | package_dir={'': 'src'}, 8 | url='', 9 | license='', 10 | author='Namgyu Ho', 11 | author_email='itsnamgyu@kaist.ac.kr', 12 | description='' 13 | ) -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | pip install -r requirements.txt 2 | python setup.py develop -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/itsnamgyu/reasoning-teacher/a2ca4d28d3bbabbd77106b76d06885d1a5eac0d9/src/__init__.py -------------------------------------------------------------------------------- /src/custom/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom experiment source code (using GPUs, not OpenAI APIs) 3 | """ -------------------------------------------------------------------------------- /src/custom/data_module.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import pytorch_lightning as pl 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | from custom.utils import save_finetune_data, load_finetune_data, list_of_dicts_to_dict_of_lists 7 | from data.completion_dataset import CompletionDataset, CompletionIdentifier 8 | from data.dataset import DATASET_KEYS, Dataset 9 | from data.format import Formatter 10 | from data.split import load_train_test_split 11 | 12 | SUPPORTED_KEYS = ["zs", "fs_cot", "ft", "ft_cot"] 13 | SUPPORTED_KEYS += ["ft_cot_t70_{}aug".format(aug) for aug in [1, 2, 4, 8, 16, 32, 64]] 14 | SUPPORTED_MODEL_TYPES = ["decoder", "encoder_decoder"] 15 | 16 | 17 | class DataModule(pl.LightningDataModule): 18 | train_dataset: Dataset 19 | test_dataset: Dataset 20 | 21 | def __init__(self, dataset_key: str, preset_key: str, tokenizer, model_type: str, batch_size: int = 32, 22 | inference_batch_size=None, num_workers: int = 8, append_eos=False): 23 | """ 24 | Note that padding is applied manually on the left side when `model_type=decoder`, therefore 25 | `tokenizer.padding_side` is irrelevant. 26 | 27 | - model_type: `decoder` or `encoder_decoder`. Used as `platform_key` for saving fine-tune data. 28 | - append_eos: manually append eos token to the end of the label. 
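(Presumably needed for decoder models such as GPT-2, whose tokenizer does not append an EOS token on its own, so that the fine-tuned model learns where to stop generating.)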
29 | """ 30 | super().__init__() 31 | if dataset_key not in DATASET_KEYS: 32 | raise NotImplementedError("dataset_key={}".format(dataset_key)) 33 | if preset_key not in SUPPORTED_KEYS: 34 | raise NotImplementedError("Not implemented: key={}".format(preset_key)) 35 | if model_type not in SUPPORTED_MODEL_TYPES: 36 | raise NotImplementedError("Not implemented: model_type={}".format(model_type)) 37 | 38 | self.dataset_key = dataset_key 39 | self.preset_key = preset_key 40 | self.tokenizer = tokenizer 41 | self.model_type = model_type 42 | self.batch_size = batch_size 43 | if inference_batch_size is None: 44 | self.inference_batch_size = batch_size 45 | else: 46 | self.inference_batch_size = inference_batch_size 47 | self.num_workers = num_workers 48 | self.append_eos = append_eos 49 | 50 | if self.preset_key == "zs": 51 | self.finetune_key = None 52 | self.prediction_template = "zs" 53 | if self.preset_key == "zs_cot": 54 | self.finetune_key = None 55 | self.prediction_template = "zs_cot" 56 | if self.preset_key == "fs_cot": 57 | self.finetune_key = None 58 | self.prediction_template = "fs_cot" 59 | if self.preset_key == "ft": 60 | self.finetune_key = self.dataset_key 61 | self.prediction_template = "ft_token" 62 | if self.preset_key == "ft_cot": 63 | self.finetune_key = "zs_cot_{}".format(self.dataset_key) 64 | self.prediction_template = "ft_cot_token" 65 | for aug in [1, 2, 4, 8, 16, 32, 64]: 66 | if self.preset_key == "ft_cot_t70_{}aug".format(aug): 67 | self.finetune_key = "zs_cot_t70_{}_{}aug".format(self.dataset_key, aug) 68 | self.prediction_template = "ft_cot_token" 69 | 70 | def save_finetune_data(self, train_data): 71 | save_finetune_data(train_data, platform_key=self.model_type, finetune_key=self.finetune_key) 72 | 73 | def load_finetune_data(self): 74 | return load_finetune_data(platform_key=self.model_type, finetune_key=self.finetune_key) 75 | 76 | def setup(self, stage: str = None): 77 | """ 78 | Compile train/test data. Train data (i.e., finetune_data) is saved to disk according to finetune_key. 
79 | """ 80 | if stage == "fit": 81 | train_data = self.load_finetune_data() 82 | if train_data is None: 83 | train_data = self._compile_data(train=True) 84 | self.save_finetune_data(train_data) 85 | dataset = datasets.Dataset.from_dict(train_data) 86 | dataset = dataset.map(self.tokenize, batched=True, batch_size=len(dataset)) 87 | if self.model_type == "decoder": 88 | dataset.set_format(type="torch", columns=["sample_index", "input_ids", "attention_mask", "labels"]) 89 | elif self.model_type == "encoder_decoder": 90 | dataset.set_format(type="torch", 91 | columns=["sample_index", "input_ids", "attention_mask", "decoder_attention_mask", 92 | "labels"]) 93 | else: 94 | raise NotImplementedError(self.model_type) 95 | self.train_dataset = dataset 96 | 97 | test_data = self._compile_data(train=False) 98 | dataset = datasets.Dataset.from_dict(test_data) 99 | dataset = dataset.map(self.tokenize, batched=True, batch_size=len(dataset)) 100 | dataset.set_format(type="torch", columns=["sample_index", "input_ids", "attention_mask"]) 101 | self.test_dataset = dataset 102 | 103 | def tokenize(self, example) -> dict: 104 | if self.append_eos and "label" in example: 105 | for i in range(len(example["label"])): 106 | example["label"][i] += self.tokenizer.eos_token 107 | 108 | if self.preset_key == "fs_cot": 109 | input_max_length = 768 110 | else: 111 | input_max_length = 512 112 | 113 | if self.model_type == "encoder_decoder": 114 | it = self.tokenizer( 115 | example["input"], 116 | padding="longest", 117 | max_length=input_max_length, 118 | truncation=True, 119 | return_tensors="pt", 120 | ) 121 | input_ids = it["input_ids"] 122 | attention_mask = it["attention_mask"] 123 | result = { 124 | "input_ids": input_ids, 125 | "attention_mask": attention_mask 126 | } 127 | if "label" in example: 128 | lt = self.tokenizer( 129 | example["label"], 130 | padding="longest", 131 | max_length=512, 132 | truncation=True, 133 | return_tensors="pt", 134 | ) 135 | label_ids = lt["input_ids"] 136 | decoder_attention_mask = lt["attention_mask"] 137 | label_ids[~decoder_attention_mask.bool()] = -100 138 | result.update({ 139 | "decoder_attention_mask": decoder_attention_mask, 140 | "labels": label_ids, 141 | }) 142 | return result 143 | elif self.model_type == "decoder": 144 | # Tokenize and apply left side padding manually 145 | 146 | # Tokenize in vanilla Python list form 147 | it = self.tokenizer( 148 | example["input"], 149 | max_length=input_max_length, 150 | truncation=True 151 | ) 152 | iids = it["input_ids"] 153 | if "label" in example: 154 | lids = self.tokenizer( 155 | example["label"], 156 | max_length=512, 157 | truncation=True 158 | )["input_ids"] 159 | else: 160 | lids = [list() for _ in range(len(iids))] 161 | 162 | lengths = [] 163 | input_ids = [] 164 | attention_mask = [] 165 | label_ids = [] 166 | for iid, lid in zip(iids, lids): 167 | lengths.append(len(iid) + len(lid)) 168 | input_ids.append(iid + lid) 169 | attention_mask.append([1] * (len(iid) + len(lid))) 170 | label_ids.append([-100] * len(iid) + lid) 171 | 172 | # Pad full sequences 173 | lengths = torch.tensor(lengths) 174 | pad_lengths = (lengths.max() - lengths).tolist() 175 | for i, l in enumerate(pad_lengths): 176 | # Apply left side padding 177 | # Why? 
https://github.com/huggingface/transformers/issues/3021#issuecomment-1231526631 178 | input_ids[i] = [self.tokenizer.pad_token_id] * l + input_ids[i] 179 | attention_mask[i] = [0] * l + attention_mask[i] 180 | label_ids[i] = [-100] * l + label_ids[i] 181 | return { 182 | "input_ids": torch.tensor(input_ids, dtype=torch.long), 183 | "attention_mask": torch.tensor(attention_mask, dtype=torch.long), 184 | "labels": torch.tensor(label_ids, dtype=torch.long), 185 | } 186 | else: 187 | raise NotImplementedError(self.model_type) 188 | 189 | def train_dataloader(self): 190 | return DataLoader( 191 | self.train_dataset, 192 | batch_size=self.batch_size, 193 | num_workers=self.num_workers, 194 | shuffle=True, 195 | ) 196 | 197 | def val_dataloader(self): 198 | return DataLoader( 199 | self.test_dataset, 200 | batch_size=self.inference_batch_size, 201 | num_workers=self.num_workers, 202 | shuffle=False, 203 | ) 204 | 205 | def test_dataloader(self): 206 | return DataLoader( 207 | self.test_dataset, 208 | batch_size=self.inference_batch_size, 209 | num_workers=self.num_workers, 210 | shuffle=False, 211 | ) 212 | 213 | def _compile_data(self, train=False): 214 | train_indices, test_indices = load_train_test_split(self.dataset_key) 215 | sample_indices = train_indices if train else test_indices 216 | 217 | if self.preset_key == "fs_cot" and train: 218 | raise NotImplementedError("fs_cot is not implemented for training.") 219 | 220 | # Get list of samples based on preset key. 221 | samples = None 222 | if self.preset_key in ["zs", "fs_cot", "ft"] or not train: 223 | dataset = Dataset.load(self.dataset_key) 224 | samples = dataset.select_samples(sample_indices) 225 | elif "ft_cot" in self.preset_key: 226 | if self.preset_key == "ft_cot": 227 | completion_identifier = CompletionIdentifier("text-davinci-002", "zs_cot", self.dataset_key) 228 | completion_indices = [0] 229 | else: 230 | for aug in [1, 2, 4, 8, 16, 32, 64]: 231 | if self.preset_key == "ft_cot_t70_{}aug".format(aug): 232 | completion_identifier = CompletionIdentifier("text-davinci-002", "zs_cot_t70", self.dataset_key) 233 | completion_indices = list(range(aug)) 234 | break 235 | else: 236 | raise NotImplementedError(self.preset_key) 237 | completion_dataset = CompletionDataset.load(completion_identifier) 238 | samples = completion_dataset.select_samples(sample_indices, completion_indices, only_correct=True) 239 | 240 | if samples is None: 241 | raise NotImplementedError(self.preset_key) 242 | 243 | # Format samples for model input/output. 
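# (Formatter is expected to yield dicts with "sample_index" and "input" keys, plus a "label" key when
# include_labels=True; these become the columns tokenized in `tokenize` and consumed in `setup` above.)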
244 | formatter = Formatter(self.model_type, self.prediction_template, dataset_key=self.dataset_key) 245 | formatted = formatter.format_samples(samples, include_labels=train) 246 | data = list_of_dicts_to_dict_of_lists(formatted) 247 | 248 | return data 249 | -------------------------------------------------------------------------------- /src/custom/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | LightningModule for training decoder OR encoder_decoder models which provides: 3 | - Saving intermediate validation predictions as CompletionDataset 4 | - Logging intermediate validation metrics (from the CompletionDataset) 5 | """ 6 | import copy 7 | import json 8 | import logging 9 | from typing import List, Dict 10 | 11 | import pytorch_lightning as pl 12 | import torch 13 | from deepspeed.ops.adam import DeepSpeedCPUAdam 14 | from transformers import PreTrainedTokenizerBase 15 | 16 | from data.completion_dataset import CompletionDataset, CompletionMetadata 17 | from data.dataset import Dataset 18 | from evaluation.evaluator import Evaluator 19 | from evaluation.summary import summarize_evaluation 20 | 21 | 22 | class Model(pl.LightningModule): 23 | validation_predictions: Dict 24 | 25 | def __init__(self, model, tokenizer: PreTrainedTokenizerBase, model_type: str, use_cpu_offload=False, 26 | completion_metadata: CompletionMetadata = None, lr=3e-4, truncate_early=True, max_length=1024): 27 | """ 28 | - completion_metadata: metadata used to save completions. If None, completions are not saved. 29 | `epoch_N` is appended to the `train_key` when saving intermediate validation completions. 30 | """ 31 | super().__init__() 32 | self.model = model 33 | self.tokenizer = tokenizer 34 | self.model_type = model_type 35 | self.use_cpu_offload = use_cpu_offload 36 | self.completion_metadata = completion_metadata 37 | self.lr = lr 38 | self.max_length = max_length 39 | self.truncate_early = truncate_early 40 | 41 | def training_step(self, batch, batch_idx): 42 | kwargs = { 43 | "input_ids": batch["input_ids"], 44 | "attention_mask": batch["attention_mask"], 45 | "labels": batch["labels"], 46 | } 47 | if self.model_type == "encoder_decoder": 48 | kwargs["decoder_attention_mask"] = batch["decoder_attention_mask"] 49 | return self.model(**kwargs)["loss"] 50 | 51 | def validation_step(self, batch, batch_idx) -> Dict[str, torch.Tensor]: 52 | """ 53 | Returns outputs in dictionary format, since it's the only way that seems to work with `all_gather` 54 | """ 55 | if self.current_epoch < 2 and self.truncate_early: 56 | max_length = 256 57 | else: 58 | max_length = self.max_length 59 | 60 | if self.model_type == "encoder_decoder": 61 | output = self.model.generate(batch["input_ids"], max_length=max_length).detach() 62 | elif self.model_type == "decoder": 63 | output = self.model.generate(batch["input_ids"], max_length=max_length, 64 | pad_token_id=self.tokenizer.pad_token_id, 65 | eos_token_id=self.tokenizer.eos_token_id).detach() 66 | else: 67 | raise NotImplementedError("model_type='{}' not supported".format(self.model_type)) 68 | 69 | return { 70 | "sample_index": batch["sample_index"], 71 | "input": batch["input_ids"], 72 | "output": output, 73 | } 74 | 75 | def validation_epoch_end(self, outputs: List[Dict]) -> None: 76 | """ 77 | Gather outputs from all GPUs and save validation predictions as a CompletionDataset and 78 | log validation metrics. 79 | 80 | Note, `all_gather` *stacks* tensors from all GPUs along a new leading dimension (when world_size > 1).
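Gathered tensors must therefore have identical shapes on every rank, which is why local outputs are first padded to the global max input/output lengths computed below.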
81 | """ 82 | # Determine total sample count and local max input/output length 83 | local_max_output_length = 0 84 | local_max_input_length = 0 85 | total_samples = 0 86 | for batch in outputs: 87 | local_max_input_length = max(local_max_input_length, batch["input"].shape[-1]) 88 | local_max_output_length = max(local_max_output_length, batch["output"].shape[-1]) 89 | total_samples += batch["sample_index"].shape[0] 90 | 91 | # Determine global max input/output length 92 | max_input_length = self.all_gather(torch.tensor(local_max_input_length, dtype=torch.long)).max() 93 | max_output_length = self.all_gather(torch.tensor(local_max_output_length, dtype=torch.long)).max() 94 | 95 | # Create local padded tensors 96 | local_outputs: dict = { 97 | "sample_index": torch.ones((total_samples,), dtype=torch.long) * self.tokenizer.pad_token_id, 98 | "input": torch.ones((total_samples, max_input_length), dtype=torch.long) * self.tokenizer.pad_token_id, 99 | "output": torch.ones((total_samples, max_output_length), dtype=torch.long) * self.tokenizer.pad_token_id, 100 | } 101 | 102 | # Populate local tensors 103 | start_index = 0 104 | for i, batch in enumerate(outputs): 105 | batch_size = batch["sample_index"].shape[0] 106 | end_index = start_index + batch_size 107 | local_outputs["sample_index"][start_index:end_index] = batch["sample_index"] 108 | input_width = batch["input"].shape[-1] 109 | output_width = batch["output"].shape[-1] 110 | if self.model_type == "encoder_decoder": 111 | local_outputs["input"][start_index:end_index, :input_width] = batch["input"] 112 | local_outputs["output"][start_index:end_index, :output_width] = batch["output"] 113 | elif self.model_type == "decoder": 114 | output_only_width = output_width - input_width 115 | local_outputs["input"][start_index:end_index, :input_width] = batch["input"] 116 | local_outputs["output"][start_index:end_index, :output_only_width] = batch["output"][:, input_width:] 117 | else: 118 | raise NotImplementedError("model_type='{}' not supported".format(self.model_type)) 119 | 120 | start_index = end_index 121 | 122 | global_outputs = self.all_gather(local_outputs) 123 | if self.global_rank == 0: 124 | if global_outputs["sample_index"].dim() == 2: # world_size > 1 125 | global_outputs["sample_index"] = global_outputs["sample_index"].flatten(start_dim=0, end_dim=1) 126 | global_outputs["output"] = global_outputs["output"].flatten(start_dim=0, end_dim=1) 127 | global_outputs["input"] = global_outputs["input"].flatten(start_dim=0, end_dim=1) 128 | 129 | final_output = { 130 | "sample_index": global_outputs["sample_index"].tolist(), 131 | "input": self.tokenizer.batch_decode(global_outputs["input"], skip_special_tokens=True), 132 | "output": self.tokenizer.batch_decode(global_outputs["output"], skip_special_tokens=True), 133 | } 134 | 135 | if self.completion_metadata is not None: 136 | # Save outputs as CompletionDataset 137 | cd = self._generate_completion_dataset(self.completion_metadata, final_output, epoch=self.current_epoch) 138 | cd.save() 139 | 140 | # Log validation examples 141 | examples = [] 142 | for i in cd.indices[:5]: 143 | examples.append(cd[i]) 144 | logging.info("VALIDATION_EXAMPLES".center(80, "-")) 145 | logging.info(json.dumps(examples, indent=4)) 146 | 147 | # Log metrics 148 | evaluation = Evaluator.evaluate_completion_dataset(cd) 149 | summary = summarize_evaluation(evaluation) 150 | if summary: 151 | for key, value in summary.items(): 152 | if key == "accuracy": 153 | self.log(key, value, prog_bar=True, logger=True) 154 | else: 
155 | self.log(key, value, prog_bar=False, logger=True) 156 | 157 | def configure_optimizers(self): 158 | if self.use_cpu_offload: 159 | optimizer = DeepSpeedCPUAdam(self.parameters(), lr=self.lr) 160 | else: 161 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 162 | return optimizer 163 | 164 | @staticmethod 165 | def _generate_completion_dataset(completion_metadata, output: Dict[str, List], epoch=None, 166 | completions_per_sample=1) -> CompletionDataset: 167 | """ 168 | Initialize and populate a CompletionDataset from model output. 169 | 170 | - output: { 171 | sample_index: List[int], 172 | input: List[str], 173 | output: List[str], 174 | } 175 | - completions_per_sample: limit the number of completions used per sample. This is useful when model output 176 | is obtained from distributed inference, where some samples may be duplicated to match batch sizes. Will use 177 | all completions if None. Existing completions count towards the limit. 178 | """ 179 | if completions_per_sample is not None and completions_per_sample < 1: 180 | raise ValueError("completions_per_sample must be at least 1") 181 | 182 | # Add/assign epoch to train key of completion_identifier 183 | completion_metadata = copy.deepcopy(completion_metadata) 184 | if epoch is not None: 185 | completion_metadata.epoch = epoch 186 | 187 | # Initialize completion dataset 188 | cd = CompletionDataset.init(completion_metadata) 189 | 190 | # Populate completion dataset with model output 191 | dataset = Dataset.load(cd.dataset_key) 192 | for sample_index, input, output in zip(output["sample_index"], output["input"], output["output"]): 193 | if len(dataset) <= sample_index: 194 | raise KeyError( 195 | "Sample index {} not found in dataset {}".format(sample_index, cd.dataset_key)) 196 | 197 | if sample_index in cd.data: 198 | completions = cd.data[sample_index] 199 | else: 200 | completions = list() 201 | cd.data[sample_index] = completions 202 | 203 | completion_index = len(completions) 204 | if completions_per_sample is None or completion_index < completions_per_sample: 205 | completions.append({ 206 | "sample_index": sample_index, 207 | "completion_index": completion_index, 208 | "question": dataset[sample_index]["question"], 209 | "answer": dataset[sample_index]["answer"], 210 | "prompt": input, 211 | "completion": output, 212 | }) 213 | cd.data[sample_index] = completions 214 | 215 | return cd 216 | -------------------------------------------------------------------------------- /src/custom/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import warnings 4 | from typing import Optional 5 | 6 | from paths import get_finetune_data_path 7 | 8 | 9 | def save_finetune_data(data, platform_key: str, finetune_key: str, strict=True) -> str: 10 | path = get_finetune_data_path(platform_key, finetune_key) 11 | print("Saving finetune data") 12 | print("-" * 80) 13 | print("Path: {}".format(path)) 14 | print("Samples: {}".format(len(data["input"]))) 15 | print("-" * 80) 16 | 17 | if os.path.exists(path): 18 | with open(path) as f: 19 | data_string = json.dumps(data, indent=4) 20 | existing_data_string = f.read() 21 | if data_string != existing_data_string: 22 | message = "Finetune data file already exists but is different at: {}".format(path) 23 | if strict: 24 | raise Exception(message) 25 | else: 26 | warnings.warn(message) 27 | else: 28 | os.makedirs(os.path.dirname(path), exist_ok=True) 29 | with open(path, "w") as f: 30 | json.dump(data, f, 
indent=4) 31 | 32 | return path 33 | 34 | 35 | def load_finetune_data(platform_key: str, finetune_key: str) -> Optional[dict]: 36 | path = get_finetune_data_path(platform_key, finetune_key) 37 | if os.path.exists(path): 38 | with open(path) as f: 39 | data = json.load(f) 40 | return data 41 | else: 42 | return None 43 | 44 | 45 | def list_of_dicts_to_dict_of_lists(list_of_dict): 46 | dict_of_lists = {} 47 | for key in list_of_dict[0].keys(): 48 | dict_of_lists[key] = [d[key] for d in list_of_dict] 49 | return dict_of_lists 50 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/itsnamgyu/reasoning-teacher/a2ca4d28d3bbabbd77106b76d06885d1a5eac0d9/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/completion_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | CompletionDataset class and related functions to load and save generated completions from openai or custom models. 3 | 4 | Format: { 5 | "metadata": { 6 | ... 7 | }, 8 | "data": { 9 | sample_index: [ # ← list of "completion samples" (prediction) for each "sample" (question in the dataset). 10 | { # ← completion sample dict 11 | "sample_index": int, 12 | "completion_index": int, 13 | "question": str, 14 | "answer": str, 15 | "reasoning_prompt": str, # for zero-shot-cot 16 | "reasoning_completion": str, # for zero-shot-cot 17 | "reasoning_finish_reason": str, # for openai 18 | "prompt": str, 19 | "completion": str, 20 | "finish_reason": str, # for openai 21 | }, 22 | ... 23 | ], 24 | ... 25 | } 26 | } 27 | """ 28 | 29 | import json 30 | import os 31 | from typing import Dict, List 32 | 33 | from easydict import EasyDict 34 | 35 | from paths import get_completion_data_path 36 | 37 | 38 | class CompletionIdentifier: 39 | """ 40 | Shorthand for CompletionDatasetIdentifier. Contains all information to identify a completion dataset, i.e., the 41 | path of the finetune data file. 42 | """ 43 | 44 | def __init__(self, base_model: str, completion_key: str, dataset_key: str, 45 | train_key: str = None, epoch: int = None): 46 | self.base_model = base_model 47 | self.completion_key = completion_key 48 | self.dataset_key = dataset_key 49 | self.train_key = train_key 50 | self.epoch = epoch 51 | 52 | def __repr__(self): 53 | return "{}_{}_{}_{}_{}".format(self.base_model, self.completion_key, self.dataset_key, 54 | "NAN" if self.train_key is None else self.train_key, 55 | "NAN" if self.epoch is None else self.epoch) 56 | 57 | def __str__(self): 58 | return self.__repr__() 59 | 60 | def __eq__(self, other: "CompletionIdentifier"): 61 | return self.base_model == other.base_model and self.completion_key == other.completion_key and \ 62 | self.dataset_key == other.dataset_key and self.train_key == other.train_key and \ 63 | self.epoch == other.epoch 64 | 65 | @property 66 | def data_path(self): 67 | return get_completion_data_path(self.base_model, self.completion_key, self.dataset_key, 68 | self.train_key, self.epoch) 69 | 70 | 71 | class CompletionMetadata(CompletionIdentifier): 72 | """ 73 | Contains the minimum metadata that constitutes a valid CompletionDataset. 
Uses include 74 | - passing completion information to a LightningModule to generate validation data 75 | """ 76 | 77 | def __init__(self, base_model: str, completion_key: str, dataset_key: str, finetune_key: str = None, 78 | prediction_template: str = None, train_key: str = None, epoch: int = None): 79 | super().__init__(base_model, completion_key, dataset_key, train_key, epoch) 80 | self.finetune_key = finetune_key 81 | self.prediction_template = prediction_template 82 | 83 | 84 | class CompletionDataset: 85 | def __init__(self, raw_data: Dict): 86 | self.metadata: EasyDict = EasyDict(raw_data["metadata"]) 87 | self.data: Dict[int, List] = {int(k): v for k, v in raw_data["data"].items()} 88 | 89 | @property 90 | def base_model(self): 91 | return self.metadata["base_model"] 92 | 93 | @property 94 | def dataset_key(self): 95 | return self.metadata["dataset_key"] 96 | 97 | @property 98 | def completion_key(self): 99 | return self.metadata["completion_key"] 100 | 101 | @property 102 | def finetune_key(self): 103 | return self.metadata["finetune_key"] 104 | 105 | @property 106 | def train_key(self): 107 | return self.metadata["train_key"] 108 | 109 | @property 110 | def epoch(self): 111 | return self.metadata["epoch"] 112 | 113 | @property 114 | def prediction_template(self): 115 | return self.metadata["prediction_template"] 116 | 117 | def __len__(self): 118 | return len(self.data) 119 | 120 | @property 121 | def total_samples(self): 122 | return sum([len(v) for v in self.data.values()]) 123 | 124 | def __getitem__(self, sample_index: int) -> List: 125 | return self.data[sample_index] 126 | 127 | @property 128 | def indices(self): 129 | return list(self.data.keys()) 130 | 131 | def select_samples(self, sample_indices: List[int] = None, completion_indices: List[int] = None, 132 | only_correct=False) -> List[dict]: 133 | """ 134 | Filter and retrieve completions based on given indices and correctness. 135 | 136 | - sample_indices: List of sample indices to select completions from. 137 | - completion_indices: List of completion indices to select for each sample 138 | - only_correct: If True, only correct completions are returned. Used for CoT dataset generation. 139 | 140 | Return: List of completions. 
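Example (illustrative usage, following the completion data format above): CompletionDataset.load(CompletionIdentifier("text-davinci-002", "zs_cot", "multiarith")).select_samples(sample_indices=[0, 1, 2], completion_indices=[0], only_correct=True) returns at most three completion dicts, keeping only those whose parsed answer matches the gold answer.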
141 | """ 142 | if sample_indices is None: 143 | sample_indices = list(self.data.keys()) 144 | else: 145 | unavailable = list(set(sample_indices) - set(self.data.keys())) 146 | if unavailable: 147 | raise Exception("Unavailable sample indices including {}".format(unavailable[:5])) 148 | 149 | if only_correct: 150 | evaluator = self.get_evaluator() 151 | 152 | completions = [] 153 | for s in sample_indices: 154 | candidates = [] 155 | samples = self.data[s] 156 | if completion_indices is None: # add all completions 157 | candidates += samples 158 | else: # check if completions exist and add 159 | for c in completion_indices: 160 | if len(samples) <= c: 161 | raise ValueError("Completion #{} for sample #{} not found.".format(c, s)) 162 | candidates += [samples[c]] 163 | for c in candidates: 164 | if not only_correct or evaluator.check_answer(c["completion"], c["answer"]): 165 | completions.append(c) 166 | 167 | return completions 168 | 169 | @property 170 | def path(self): 171 | return get_completion_data_path(self.base_model, self.completion_key, self.dataset_key, self.train_key, 172 | self.epoch) 173 | 174 | def save(self): 175 | raw_data = { 176 | "metadata": self.metadata, 177 | "data": self.data 178 | } 179 | os.makedirs(os.path.dirname(self.path), exist_ok=True) 180 | with open(self.path, "w") as f: 181 | json.dump(raw_data, f, indent=4) 182 | 183 | @staticmethod 184 | def init(completion_metadata: CompletionMetadata, additional_metadata: Dict = None): 185 | raw_data = { 186 | "metadata": { 187 | "base_model": completion_metadata.base_model, 188 | "dataset_key": completion_metadata.dataset_key, 189 | "completion_key": completion_metadata.completion_key, 190 | "finetune_key": completion_metadata.finetune_key, 191 | "train_key": completion_metadata.train_key, 192 | "epoch": completion_metadata.epoch, 193 | "prediction_template": completion_metadata.prediction_template, 194 | }, 195 | "data": {} 196 | } 197 | if additional_metadata: 198 | raw_data["metadata"].update(additional_metadata) 199 | return CompletionDataset(raw_data) 200 | 201 | @staticmethod 202 | def load(completion_identifier: CompletionIdentifier): 203 | base_model = completion_identifier.base_model 204 | dataset_key = completion_identifier.dataset_key 205 | completion_key = completion_identifier.completion_key 206 | train_key = completion_identifier.train_key 207 | epoch = completion_identifier.epoch 208 | 209 | with open(get_completion_data_path(base_model, completion_key, dataset_key, train_key, epoch)) as f: 210 | raw_data = json.load(f) 211 | completions = CompletionDataset(raw_data) 212 | 213 | if completions.base_model != base_model: 214 | raise Exception("Base model mismatch.") 215 | if completions.dataset_key != dataset_key: 216 | raise Exception("Dataset key mismatch.") 217 | if completions.completion_key != completion_key: 218 | raise Exception("Completion key mismatch.") 219 | if completions.train_key != train_key: 220 | raise Exception("Train key mismatch.") 221 | if completions.epoch != epoch: 222 | raise Exception("Epoch mismatch.") 223 | 224 | return completions 225 | 226 | def get_evaluator(self): 227 | from evaluation.evaluator import Evaluator 228 | return Evaluator(self.dataset_key, self.metadata["prediction_template"]) 229 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset class to load original benchmark datasets 3 | 4 | Format: { 5 | "metadata": { 6 
| "dataset_key": str, 7 | }, 8 | "data": [ 9 | { 10 | "sample_index": int, 11 | "question": str, 12 | "answer": str, 13 | }, 14 | ... 15 | ] 16 | } 17 | """ 18 | 19 | import json 20 | from typing import Dict, List 21 | 22 | from easydict import EasyDict 23 | 24 | from paths import get_dataset_path 25 | 26 | DATASET_SIZES = { 27 | "single_eq": 508, 28 | "addsub": 395, 29 | "multiarith": 600, 30 | "gsm8k": 8792, 31 | "aqua": 97721, 32 | "svamp": 1000, 33 | 34 | "tracking_shuffled_objects": 750, 35 | "date_understanding": 369, 36 | "coin_flip": 500, 37 | "last_letter_concatenation": 500, 38 | 39 | "commonsense_qa": 10962, 40 | "strategy_qa": 2290, 41 | } 42 | 43 | DATASET_KEYS = list(DATASET_SIZES.keys()) 44 | 45 | 46 | class Dataset: 47 | def __init__(self, raw_data: Dict): 48 | self.raw_data = raw_data 49 | self.metadata: EasyDict = EasyDict(raw_data["metadata"]) 50 | self.data: List[Dict] = raw_data["data"] 51 | 52 | @property 53 | def dataset_key(self): 54 | return self.metadata["dataset_key"] 55 | 56 | def __len__(self): 57 | return len(self.data) 58 | 59 | def __getitem__(self, item: int) -> Dict: 60 | return self.data[item] 61 | 62 | @staticmethod 63 | def load(dataset_key): 64 | with open(get_dataset_path(dataset_key), "r") as f: 65 | raw_data = json.load(f) 66 | dataset = Dataset(raw_data) 67 | 68 | if dataset.metadata["dataset_key"] != dataset_key: 69 | raise Exception("Dataset key mismatch.") 70 | 71 | return dataset 72 | 73 | def select_samples(self, sample_indices: List[int]) -> List[Dict]: 74 | selected = [] 75 | for i in sample_indices: 76 | if len(self.data) <= i: 77 | raise IndexError("Sample index {} out of range [0, {}).".format(i, len(self.data))) 78 | selected.append(self.data[i]) 79 | return selected 80 | -------------------------------------------------------------------------------- /src/data/few_shot_cot_prompt.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | 4 | from paths import FEW_SHOT_COT_PROMPTS_PATH 5 | 6 | 7 | def load_few_shot_cot_prompts(): 8 | with open(FEW_SHOT_COT_PROMPTS_PATH) as f: 9 | return json.load(f) 10 | 11 | 12 | def get_few_shot_cot_prompt(dataset_key) -> str: 13 | data = load_few_shot_cot_prompts() 14 | if dataset_key not in data: 15 | raise KeyError("Few-shot-CoT prompts are not available for dataset `{}`".format(dataset_key)) 16 | return data[dataset_key]["prompt"] 17 | 18 | 19 | def get_few_shot_cot_sample_indices(dataset_key) -> List[int]: 20 | data = load_few_shot_cot_prompts() 21 | if dataset_key not in data: 22 | raise KeyError("Few-shot-CoT prompts are not available for dataset `{}`".format(dataset_key)) 23 | return data[dataset_key]["sample_indices"] 24 | -------------------------------------------------------------------------------- /src/data/format.py: -------------------------------------------------------------------------------- 1 | """ 2 | Format Dataset or CompletionDataset samples for (CoT) inference and fine-tuning. 
3 | """ 4 | import copy 5 | import sys 6 | from typing import Dict, List 7 | 8 | from data.few_shot_cot_prompt import get_few_shot_cot_prompt 9 | 10 | SUPPORTED_MODEL_TYPES = ["decoder", "encoder_decoder"] 11 | SUPPORTED_PREDICTION_TEMPLATES = ["ft_token", "ft_cot_natural", "ft_cot_token", "zs", "zs_cot", "fs_cot"] 12 | 13 | 14 | class Formatter: 15 | def __init__(self, model_type: str, prediction_template: str, zs_cot_step: int = None, dataset_key: str = None, 16 | stop_phrase: str = None): 17 | """ 18 | Parameters 19 | 20 | - model_type 21 | - prediction_template 22 | - zs_cot_step 23 | - dataset_key 24 | - stop_phrase: used as `stop` string for OpenAI API. YOU MUST SET THIS VALUE FOR OPENAI FINE-TUNING or the model 25 | will not learn to stop. our experiments. 26 | """ 27 | if model_type not in SUPPORTED_MODEL_TYPES: 28 | raise NotImplementedError("model_type={}".format(model_type)) 29 | if prediction_template not in SUPPORTED_PREDICTION_TEMPLATES: 30 | raise NotImplementedError("prediction_template={}".format(prediction_template)) 31 | 32 | self.model_type = model_type 33 | self.prediction_template = prediction_template 34 | self.zs_cot_step = zs_cot_step 35 | self.dataset_key = dataset_key 36 | self.stop_phrase = stop_phrase # used as `stop` string for OpenAI API 37 | 38 | if prediction_template == "ft_natural": 39 | if model_type == "decoder": 40 | self.input_format = "Q: {sample[question]}\n\nA:" 41 | self.label_format = " {sample[answer]}" 42 | # REPRODUCTION NOTE - GPT3 experiments in the paper use the following: 43 | # self.input_format = "{sample[question]}\n\n ### \n\n" 44 | # self.label_format = " {sample[answer]}" 45 | elif model_type == "encoder_decoder": 46 | self.input_format = "Q: {sample[question]}" 47 | self.label_format = "{sample[answer]}" 48 | 49 | if prediction_template == "ft_token": 50 | if model_type == "decoder": 51 | self.input_format = "{sample[question]} ###" 52 | self.label_format = " {sample[answer]}" 53 | # REPRODUCTION NOTE - GPT3 experiments in the paper use the following: 54 | # self.input_format = "{sample[question]}\n\n ### \n\n" 55 | # self.label_format = " {sample[answer]}" 56 | elif model_type == "encoder_decoder": 57 | self.input_format = "{sample[question]}" 58 | self.label_format = "{sample[answer]}" 59 | 60 | if prediction_template == "ft_cot_token": 61 | if model_type == "decoder": 62 | self.input_format = "{sample[question]} ###" 63 | self.label_format = " {sample[reasoning_completion]} --> {sample[answer]}" 64 | # REPRODUCTION NOTE - GPT3 experiments in the paper use the following: 65 | # self.input_format = "{sample[question]}\n\n###\n\n" 66 | # self.label_format = " {sample[reasoning_completion]}\n\n-->\n\n{sample[answer]}" 67 | elif model_type == "encoder_decoder": 68 | self.input_format = "{sample[question]}" 69 | self.label_format = "{sample[reasoning_completion]} --> {sample[answer]}" 70 | 71 | if prediction_template == "ft_cot_natural": 72 | if model_type == "decoder": 73 | self.input_format = "Q: {sample[question]}\n\nA: Let's think step by step.\n\n" 74 | self.label_format = " {sample[reasoning_completion]}\n\nTherefore the answer is {sample[answer]}" 75 | # REPRODUCTION NOTE - GPT3 experiments in the paper use the following: 76 | # self.input_format = "Q: {sample[question]}\n\nA: Let's think step by step.\n\n" 77 | # self.label_format = " {sample[reasoning_completion]}\n\nTherefore the answer is\n\n{sample[answer]}" 78 | else: 79 | raise NotImplementedError("model_type={} not supported for prediction_template={}".format( 80 | 
model_type, prediction_template)) 81 | 82 | if prediction_template == "zs": 83 | self.label_format = None 84 | if model_type == "decoder": 85 | self.input_format = "Q: {sample[question]}\n\nA:" 86 | elif model_type == "encoder_decoder": # following SQuAD format used to train T5 87 | self.input_format = "question: {sample[question]}" 88 | else: 89 | raise NotImplementedError("model_type={} not supported for zs_cot".format(model_type)) 90 | 91 | if prediction_template == "zs_cot": 92 | self.label_format = None 93 | if model_type == "decoder": 94 | if zs_cot_step == 1: 95 | self.input_format = "Q: {sample[question]}\nA: Let's think step by step." 96 | elif zs_cot_step == 2: 97 | self.input_format = "{sample[reasoning_prompt]}{sample[reasoning_completion]}\nTherefore, the answer is" 98 | else: 99 | raise ValueError("step {} not supported for zs_cot".format(zs_cot_step)) 100 | else: 101 | raise NotImplementedError("model_type={} not supported for zs_cot".format(model_type)) 102 | 103 | if prediction_template == "fs_cot": 104 | self.label_format = None 105 | if dataset_key is None: 106 | raise ValueError("dataset_key must be specified for fs_cot") 107 | self.few_shot_prompt = get_few_shot_cot_prompt(dataset_key) 108 | 109 | if model_type == "decoder": 110 | self.input_format = self.few_shot_prompt + "\nQ: {sample[question]}\nA:" 111 | elif model_type == "encoder_decoder": 112 | self.input_format = self.few_shot_prompt + "\nQ: {sample[question]}\nA:" 113 | else: 114 | raise NotImplementedError("model_type={} not supported for fs_cot".format(model_type)) 115 | 116 | if not hasattr(self, "input_format"): 117 | raise NotImplementedError(f"{model_type}, {prediction_template}") 118 | if not hasattr(self, "label_format"): 119 | # should be set to None if not supported 120 | raise NotImplementedError(f"{model_type}, {prediction_template}") 121 | 122 | def __call__(self, sample: Dict, include_label: bool = False): 123 | """ 124 | Sample can either be a dataset sample (from Dataset) or completion sample (from CompletionDataset). 125 | Samples should contain all necessary keys needed for self.prediction_template, e.g., zs_cot step 2 requires 126 | reasoning_completion. 127 | """ 128 | sample = copy.deepcopy(sample) 129 | 130 | # REPRODUCTION NOTE - stripping is not applied to GPT3 experiments in the paper. 131 | if "question" in sample: 132 | sample["question"] = sample["question"].strip() 133 | if "answer" in sample: 134 | sample["answer"] = sample["answer"].strip() 135 | if "reasoning_completion" in sample: 136 | sample["reasoning_completion"] = sample["reasoning_completion"].strip() 137 | if "reasoning_prompt" in sample: 138 | sample["reasoning_prompt"] = sample["reasoning_prompt"].strip() 139 | 140 | result = { 141 | "sample_index": sample["sample_index"], 142 | "input": self.input_format.format(sample=sample), 143 | } 144 | 145 | if include_label: 146 | if self.label_format is None: 147 | raise ValueError("label formatting is not supported for prediction_template={}".format( 148 | self.prediction_template)) 149 | result["label"] = self.label_format.format(sample=sample) 150 | 151 | if self.stop_phrase is not None: 152 | result["label"] += " " + self.stop_phrase 153 | 154 | return result 155 | 156 | def format_samples(self, samples: List[Dict], include_labels: bool = False) -> List[Dict]: 157 | """ 158 | - samples: List of samples from Dataset or CompletionDataset. Use the `select_samples()` method 159 | provided by either class. 
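Example (illustrative usage, assuming the `multiarith` dataset has been prepared): formatter = Formatter("encoder_decoder", "ft_cot_token"); formatter.format_samples(Dataset.load("multiarith").select_samples([0, 1]), include_labels=False) yields entries like {"sample_index": 0, "input": "<question text>"}.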
160 | """ 161 | formatted_samples = [] 162 | 163 | errors = 0 164 | for sample in samples: 165 | try: 166 | formatted_sample = self(sample, include_label=include_labels) 167 | except ValueError as e: 168 | errors += 1 169 | continue 170 | formatted_samples.append(formatted_sample) 171 | 172 | if errors > 0: 173 | print("ERROR: {}/{} samples could not be formatted".format(errors, len(samples)), file=sys.stderr) 174 | print("Raising last Exception", file=sys.stderr) 175 | raise e 176 | 177 | return formatted_samples 178 | -------------------------------------------------------------------------------- /src/data/generate_split.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is used to generate splits. This should only be used for initial generation (by the author, or contributors 3 | who add new dataset or splits). Subsequent runs should use the splits saved in `data/splits/` to ensure consistency. 4 | 5 | Note, `np.random.RandomState` is said to guarantee consistency across environments, but we do this just to be safe. 6 | """ 7 | import json 8 | import os 9 | from typing import Optional, Tuple, List 10 | 11 | import numpy as np 12 | 13 | from data.dataset import DATASET_SIZES 14 | from paths import get_split_path 15 | 16 | 17 | def get_default_train_test_split(dataset_key) -> Optional[Tuple[List[int], List[int]]]: 18 | predefined = get_predefined_train_test_split(dataset_key) 19 | if predefined is not None: 20 | return predefined 21 | 22 | return get_random_train_test_indices(dataset_key) 23 | 24 | 25 | def get_predefined_train_test_split(dataset_key) -> Optional[Tuple[List[int], List[int]]]: 26 | if dataset_key == "aqua": 27 | train_size = 97467 28 | test_size = 254 29 | train_subsample_seed = 0 30 | train_subsample_size = 10000 31 | train_indices = list(range(train_size)) 32 | state = np.random.RandomState(train_subsample_seed) 33 | train_indices = sorted(state.permutation(train_indices)[:train_subsample_size].tolist()) 34 | test_indices = list(range(train_size, train_size + test_size)) 35 | return train_indices, test_indices 36 | if dataset_key == "gsm8k": 37 | train_size = 7473 38 | test_size = 1319 39 | indices = list(range(train_size + test_size)) 40 | return indices[:train_size], indices[train_size:] 41 | if dataset_key == "commonsense_qa": 42 | train_size = 9741 43 | test_size = 1221 44 | indices = list(range(train_size + test_size)) 45 | return indices[:train_size], indices[train_size:] 46 | 47 | return None 48 | 49 | 50 | def get_random_train_test_indices(dataset_key: str, train_ratio=0.7, split_seed=0) -> Tuple[ 51 | List[int], List[int]]: 52 | dataset_size = DATASET_SIZES[dataset_key] 53 | 54 | indices = list(range(dataset_size)) 55 | state = np.random.RandomState(split_seed) 56 | indices = state.permutation(indices) 57 | train_n = round(dataset_size * train_ratio) 58 | train_indices = sorted(indices[:train_n].tolist()) 59 | test_indices = sorted(indices[train_n:].tolist()) 60 | 61 | return train_indices, test_indices 62 | 63 | 64 | if __name__ == "__main__": 65 | for dataset_key in DATASET_SIZES: 66 | path = get_split_path(dataset_key, "default") 67 | if os.path.exists(path): 68 | print("Skipping `{}`. 
Split file already exists at: {}".format(dataset_key, path)) 69 | continue 70 | 71 | train, test = get_default_train_test_split(dataset_key) 72 | os.makedirs(os.path.dirname(path), exist_ok=True) 73 | with open(path, "w") as f: 74 | json.dump({"train": train, "test": test}, f, indent=4) 75 | print("Saved split for `{}` to: {}".format(dataset_key, path)) 76 | -------------------------------------------------------------------------------- /src/data/split.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List, Tuple 4 | 5 | import numpy as np 6 | 7 | from paths import get_split_path 8 | 9 | 10 | def load_train_test_split(dataset_key: str) -> Tuple[List[int], List[int]]: 11 | split = load_split(dataset_key) 12 | return split["train"], split["test"] 13 | 14 | 15 | def load_split(dataset_key: str, split_key="default"): 16 | path = get_split_path(dataset_key, split_key) 17 | if os.path.exists(path): 18 | with open(path) as f: 19 | return json.load(f) 20 | else: 21 | raise ValueError("Split {} for dataset {} does not exist.".format(split_key, dataset_key)) 22 | 23 | 24 | def subsample_indices(indices: List[int], n: int, split_seed=0): 25 | """ 26 | Permute and select the first `n` indices, then sort. 27 | Used for the 8-, 32-, and 128-shot ablations in the paper. 28 | """ 29 | if n > len(indices): 30 | print("Warning: n == {} > len(indices) == {}".format(n, len(indices))) 31 | 32 | state = np.random.RandomState(split_seed) 33 | indices = state.permutation(indices) 34 | return sorted(indices[:n].tolist()) 35 | -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/itsnamgyu/reasoning-teacher/a2ca4d28d3bbabbd77106b76d06885d1a5eac0d9/src/evaluation/__init__.py -------------------------------------------------------------------------------- /src/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Tuple, Union, Optional, Dict 3 | 4 | import pandas as pd 5 | 6 | from data.completion_dataset import CompletionDataset 7 | 8 | PREDICTION_PREFIXES = { 9 | None: None, 10 | "zs": None, 11 | "ft_natural": None, 12 | "ft_token": None, 13 | "fs_cot": "The answer is", 14 | "zs_cot": None, 15 | "ft_cot_natural": "Therefore, the answer is", 16 | "ft_cot_token": "-->", 17 | } 18 | 19 | 20 | class Evaluator: 21 | dataset_key: str 22 | prediction_template: Optional[str] 23 | prediction_prefix: Optional[str] 24 | 25 | def __init__(self, dataset_key: str, prediction_template: str): 26 | """ 27 | Set prediction_template=None if you are only using the evaluator to parse answers. 
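Example (illustrative): Evaluator("gsm8k", "zs_cot").check_answer("The answer is 42.", "42") returns True; "zs_cot" has no prediction prefix, so the first numeric candidate found in the completion is compared against the cleansed gold answer.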
28 | """ 29 | self.dataset_key = dataset_key 30 | self.prediction_template = prediction_template 31 | if prediction_template not in PREDICTION_PREFIXES: 32 | raise ValueError("Invalid prediction template: {}".format(prediction_template)) 33 | else: 34 | self.prediction_prefix = PREDICTION_PREFIXES[prediction_template] 35 | 36 | @staticmethod 37 | def for_completion_dataset(completion_dataset: CompletionDataset) -> "Evaluator": 38 | return Evaluator(completion_dataset.dataset_key, completion_dataset.prediction_template) 39 | 40 | @staticmethod 41 | def evaluate_completion_dataset(completion_dataset: CompletionDataset, sample_indices: List[int] = None, 42 | completion_indices: List[int] = None) -> pd.DataFrame: 43 | """ 44 | Evaluate a set of completions (i.e. a CompletionData object). 45 | 46 | - indices: If not None, only evaluate completions for the given sample indices. 47 | - completion_indices: If not None, only evaluate the completions with the specified indices for each sample, 48 | e.g., to evaluate repeated completions with temperature sampling. 49 | """ 50 | evaluator = Evaluator.for_completion_dataset(completion_dataset) 51 | completions = completion_dataset.select_samples(sample_indices, completion_indices) 52 | evaluations = [] 53 | for completion in completions: 54 | evaluations.append(evaluator.evaluate_completion(completion)) 55 | 56 | return pd.DataFrame(evaluations) 57 | 58 | def evaluate_completion(self, completion: Dict) -> Dict: 59 | """ 60 | Evaluate a single prediction. 61 | """ 62 | completion_string = completion["completion"] 63 | correct_format = self.prediction_prefix is None or completion_string.find(self.prediction_prefix) != -1 64 | prediction, candidates = self.cleanse_prediction(completion_string, return_all=True) 65 | answer = self.cleanse_answer(completion["answer"]) 66 | return { 67 | "sample_index": completion["sample_index"], 68 | "completion_index": completion["completion_index"], 69 | "correct": self._compare_prediction_and_answer(prediction, answer), 70 | "contains_answer": any(self._compare_prediction_and_answer(p, answer) for p in candidates), 71 | "correct_format": correct_format, 72 | "complete": completion.get("finish_reason") == "stop", 73 | } 74 | 75 | def check_answer(self, completion_string: str, answer: str) -> bool: 76 | """ 77 | Check if a single prediction is correct. 
78 | """ 79 | prediction = self.cleanse_prediction(completion_string, return_all=False) 80 | answer = self.cleanse_answer(answer) 81 | return self._compare_prediction_and_answer(prediction, answer) 82 | 83 | def cleanse_prediction(self, completion: str, return_all: bool) -> Union[str, Tuple[str, List[str]]]: 84 | if self.prediction_prefix is None: 85 | # If no prefix, use first candidate 86 | predictions = self._extract_prediction_candidates(completion) 87 | first = True 88 | else: 89 | index = completion.find(self.prediction_prefix) 90 | if index == -1: 91 | # If prefix not found, use *last* candidate 92 | predictions = self._extract_prediction_candidates(completion) 93 | first = False 94 | else: 95 | # If prefix found, use *first* candidate after prefix 96 | start_of_answer = index + len(self.prediction_prefix) 97 | predictions = self._extract_prediction_candidates(completion[start_of_answer:]) 98 | first = True 99 | 100 | answer = None 101 | if predictions: 102 | answer = (predictions[0] if first else predictions[-1]) 103 | 104 | return (answer, predictions) if return_all else answer 105 | 106 | def cleanse_answer(self, answer: str) -> str: 107 | if self.dataset_key in ["gsm8k", "addsub", "multiarith", "svamp", "single_eq"]: 108 | answer = answer.replace(",", "") 109 | if self.dataset_key == "strategy_qa": 110 | answer = answer.lower() 111 | if self.dataset_key in ["addsub", "svamp", "single_eq"]: 112 | answer = float(answer) 113 | 114 | return answer 115 | 116 | def _extract_prediction_candidates(self, prediction: str) -> List[str]: 117 | """ 118 | Extracts all potential answer predictions which satisfy the dataset's answer format from the 119 | prediction string 120 | """ 121 | if self.dataset_key in ("aqua", "commonsense_qa"): 122 | prediction = re.findall(r'[ABCDE]', prediction) 123 | elif self.dataset_key == "date_understanding": 124 | prediction = re.findall(r'[ABCDEF]', prediction) 125 | elif self.dataset_key in ("tracking_shuffled_objects"): 126 | prediction = re.findall(r'[ABC]', prediction) 127 | elif self.dataset_key in ("gsm8k", "addsub", "multiarith", "svamp", "single_eq"): 128 | prediction = prediction.replace(",", "") 129 | prediction = re.findall(r'-?\d+(?:\.\d+)?', prediction) 130 | if self.dataset_key in ("addsub", "svamp", "single_eq"): 131 | prediction = [float(s) for s in prediction] 132 | elif self.dataset_key in ("strategy_qa", "coin_flip"): 133 | prediction = prediction.lower() 134 | prediction = re.sub("\"|\'|\n|\.|\s|\:|\,", " ", prediction) 135 | prediction = prediction.split(" ") 136 | prediction = [i for i in prediction if i in ("yes", "no")] 137 | elif self.dataset_key == "last_letter_concatenation": 138 | prediction = re.sub("\"|\'|\n|\.|\s", "", prediction) 139 | prediction = [prediction] 140 | else: 141 | raise ValueError("Invalid dataset: {}".format(self.dataset_key)) 142 | 143 | return prediction 144 | 145 | def _compare_prediction_and_answer(self, prediction, answer) -> bool: 146 | if self.dataset_key in ("addsub", "svamp", "single_eq"): 147 | return prediction is not None and abs(prediction - answer) <= 1e-6 148 | else: 149 | return prediction is not None and prediction == answer 150 | -------------------------------------------------------------------------------- /src/evaluation/summary.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pandas as pd 4 | 5 | 6 | def summarize_evaluation(evaluation: pd.DataFrame) -> dict: 7 | """ 8 | Summarize metrics from completion-wise evaluation 
dataframe. 9 | The dataframe contains the columns "sample_index", "completion_index", "correct", "contains_answer", "correct_format", and "complete". 10 | """ 11 | if evaluation is None or len(evaluation) == 0: 12 | warnings.warn("No completions to evaluate.") 13 | return None 14 | 15 | return { 16 | "accuracy": evaluation.correct.mean(), 17 | "contains_answer": evaluation.contains_answer.mean(), 18 | "correct_format": evaluation.correct_format.mean(), 19 | "complete": evaluation.complete.mean(), 20 | } 21 | 22 | -------------------------------------------------------------------------------- /src/oai/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | OpenAI experiment source code 3 | """ -------------------------------------------------------------------------------- /src/oai/finetune.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List, Dict 4 | 5 | from data.completion_dataset import CompletionDataset 6 | from data.dataset import Dataset 7 | from data.format import Formatter 8 | from evaluation.evaluator import Evaluator 9 | from oai.utils.api_wrapper import create_finetune_file, create_finetune 10 | from oai.utils.metadata import get_model_key 11 | from paths import get_finetune_data_path 12 | 13 | STOP_PHRASE = "END" 14 | 15 | 16 | def init_finetune(finetune_key: str, base_model: str, dataset_key: str, train_key: str, 17 | finetune_kwargs: Dict = None) -> str: 18 | """ 19 | Creates a `File` (containing the finetune data) and a `Finetune` (on that file) on OpenAI. 20 | 21 | Resulting `Model`s can be fetched after a Finetune is completed. Refer to `oai.utils.fetch_model_ids` to fetch 22 | models. 23 | 24 | Returns the model_key. 25 | """ 26 | create_finetune_file(finetune_key) 27 | model_key = get_model_key(base_model, dataset_key, train_key) 28 | if finetune_kwargs is None: 29 | finetune_kwargs = {} 30 | create_finetune(finetune_key, base_model, dataset_key, train_key, **finetune_kwargs) 31 | 32 | return model_key 33 | 34 | 35 | def generate_finetune_data_from_completion_dataset(completion_dataset: CompletionDataset, 36 | prediction_template: str, 37 | finetune_key: str, 38 | sample_indices: List[int] = None, 39 | completion_indices: List[int] = None, 40 | only_correct=True): 41 | """ 42 | Generate OpenAI fine-tuning data (JSONL) from the completions in a CompletionDataset, optionally keeping only correct completions. 43 | """ 44 | formatter = Formatter("decoder", prediction_template, dataset_key=completion_dataset.dataset_key, 45 | stop_phrase=STOP_PHRASE) 46 | samples = completion_dataset.select_samples(sample_indices, completion_indices) 47 | 48 | if only_correct: 49 | evaluator = Evaluator(completion_dataset.dataset_key, completion_dataset.prediction_template) 50 | 51 | finetune_data = [] 52 | for sample in samples: 53 | if only_correct: 54 | if not evaluator.evaluate_completion(sample)["correct"]: 55 | continue 56 | formatted = formatter(sample, include_label=True) 57 | finetune_data.append({ 58 | "prompt": formatted["input"], 59 | "completion": formatted["label"], 60 | }) 61 | 62 | _save_finetune_data(finetune_data, finetune_key) 63 | 64 | 65 | def generate_finetune_data_from_dataset(dataset: Dataset, 66 | prediction_template: str, 67 | finetune_key: str, 68 | sample_indices: List[int] = None): 69 | formatter = Formatter("decoder", prediction_template, dataset_key=dataset.dataset_key, 70 | stop_phrase=STOP_PHRASE) 71 | samples = dataset.select_samples(sample_indices) 72 | 73 | finetune_data = [] 74 | for sample in samples: 75 | formatted = formatter(sample, 
include_label=True) 76 | finetune_data.append({ 77 | "prompt": formatted["input"], 78 | "completion": formatted["label"], 79 | }) 80 | 81 | _save_finetune_data(finetune_data, finetune_key) 82 | 83 | 84 | def _save_finetune_data(data: List[Dict], finetune_key): 85 | path = get_finetune_data_path("openai", finetune_key) 86 | print("Saving {} fine-tuning samples to {}".format(len(data), path)) 87 | 88 | lines = [] 89 | for sample in data: 90 | lines.append(json.dumps(sample)) 91 | full_string = "\n".join(lines) 92 | 93 | if os.path.exists(path): 94 | with open(path, "r") as f: 95 | existing = f.read() 96 | if existing != full_string: 97 | raise Exception("Finetune data already exists and is different from the given data.") 98 | else: 99 | os.makedirs(os.path.dirname(path), exist_ok=True) 100 | with open(path, "w") as f: 101 | f.write(full_string) 102 | -------------------------------------------------------------------------------- /src/oai/inference.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | 3 | from tqdm import tqdm 4 | 5 | from data.completion_dataset import CompletionMetadata, CompletionDataset 6 | from data.dataset import Dataset 7 | from data.format import Formatter 8 | from oai.utils.api_wrapper import create_completion 9 | from oai.utils.metadata import get_model_key, get_model_id 10 | 11 | STOP_PHRASE = "END" 12 | 13 | 14 | def batch_infer_samples(samples: List[Dict], model_id: str, key_prefix: Optional[str], batch_size: int = 20, 15 | temperature: float = 0, max_tokens: int = 128, 16 | save_completion_dataset: CompletionDataset = None): 17 | """ 18 | Complete samples using OpenAI models in batches, in-place. 19 | 20 | - All samples should contain a prompt under the key "prompt" (or "{key_prefix}_prompt" if key_prefix is given). 21 | - Completions will be added under the key "completion" (or "{key_prefix}_completion"). 22 | - finish_reasons will be added under the key "finish_reason" (or "{key_prefix}_finish_reason"). 23 | 24 | - save_completion_dataset: if provided, will be saved every time a batch is completed. `samples` should 25 | contain references to the samples in this CompletionDataset, or else the new completions will not be saved. 
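Example (hypothetical usage, assuming each selected sample already contains a "reasoning_prompt"): batch_infer_samples(completion_dataset.select_samples([0, 1]), model_id="text-davinci-002", key_prefix="reasoning", save_completion_dataset=completion_dataset)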
26 | """ 27 | # Prepend key_prefix to keys, e.g., "reasoning" for zs_cot step 1 28 | prompt_key = f"{key_prefix}_prompt" if key_prefix else "prompt" 29 | completion_key = f"{key_prefix}_completion" if key_prefix else "completion" 30 | finish_reason_key = f"{key_prefix}_finish_reason" if key_prefix else "finish_reason" 31 | 32 | for sample in samples: 33 | if prompt_key not in sample: 34 | raise ValueError( 35 | f"Sample #{sample['sample_index']} - {sample['completion_index']} does not contain {prompt_key}") 36 | 37 | all_samples = samples # keep a reference to all samples 38 | samples = [s for s in samples if completion_key not in s] # filter out already completed samples for inference 39 | 40 | if len(samples) == 0: 41 | print("All {} samples have been completed.".format(len(all_samples))) 42 | else: 43 | print("Inferring completions for {} remaining samples (total={})".format(len(samples), len(all_samples))) 44 | 45 | pbar = tqdm(total=len(samples), desc="Inferring completions via OpenAI") 46 | for i in range(0, len(samples), batch_size): 47 | batch_samples = samples[i:i + batch_size] 48 | prompts = [s[prompt_key] for s in batch_samples] 49 | response = create_completion(model=model_id, prompt=prompts, max_tokens=max_tokens, 50 | temperature=temperature, n=1, stop=STOP_PHRASE) 51 | assert len(response["choices"]) == len(batch_samples) 52 | 53 | completions = [c["text"] for c in response["choices"]] 54 | finish_reasons = [c["finish_reason"] for c in response["choices"]] 55 | for sample in batch_samples: 56 | sample[completion_key] = completions.pop(0) 57 | sample[finish_reason_key] = finish_reasons.pop(0) 58 | 59 | if save_completion_dataset: 60 | save_completion_dataset.save() 61 | 62 | pbar.update(len(batch_samples)) 63 | 64 | return all_samples 65 | 66 | 67 | def populate_completion_dataset(completion_dataset: CompletionDataset, dataset: Dataset, 68 | formatter: Formatter, sample_indices: List[int] = None, augs: int = 1, 69 | prompt_key: str = "prompt"): 70 | if sample_indices is None: 71 | sample_indices = list(range(len(dataset.data))) 72 | 73 | dataset_samples = dataset.select_samples(sample_indices) 74 | for i, datsaet_sample in zip(sample_indices, dataset_samples): 75 | # Add sample lists 76 | if i in completion_dataset.data: 77 | completion_samples = completion_dataset.data.get(i) 78 | else: 79 | completion_samples = list() 80 | completion_dataset.data[i] = completion_samples 81 | 82 | # Add completion sample dicts 83 | remaining = augs - len(completion_samples) 84 | for _ in range(remaining): 85 | completion_samples.append(dict()) 86 | 87 | dataset_sample = dataset.data[i] 88 | 89 | # Populate completion sample dicts 90 | for j, completion_sample in enumerate(completion_samples): 91 | completion_sample["sample_index"] = i 92 | completion_sample["completion_index"] = j 93 | completion_sample["question"] = dataset_sample["question"] 94 | completion_sample["answer"] = dataset_sample["answer"] 95 | if formatter.prediction_template == "zs_cot" and formatter.zs_cot_step == 2: 96 | if "reasoning_completion" not in completion_sample: 97 | raise ValueError( 98 | "All samples must contain a 'reasoning_completion' key for zs_cot step 2. 
Make sure to run step 1 for the same sample/completion indices") 99 | completion_sample[prompt_key] = formatter(completion_sample, include_label=False)["input"] 100 | 101 | 102 | def infer_completion_data(completion_metadata: CompletionMetadata, zs_cot_step: int = None, 103 | sample_indices: List[int] = None, augs: int = 1, 104 | temperature: float = 0, max_tokens: int = 128): 105 | """ 106 | Init/load CompletionDataset, infer completions for remaining samples, and save. 107 | 108 | - sample_indices: indices of samples to infer, or None for all samples 109 | - augs: number of completion_indices per sample 110 | """ 111 | model_key = get_model_key(completion_metadata.base_model, completion_metadata.dataset_key, 112 | completion_metadata.train_key) 113 | model_id = get_model_id(model_key) 114 | if model_id is None: 115 | raise ValueError(f"OpenAI model with model_key=`{model_key}` does not exist") 116 | 117 | formatter = Formatter("decoder", completion_metadata.prediction_template, zs_cot_step, 118 | completion_metadata.dataset_key, stop_phrase=STOP_PHRASE) 119 | 120 | # If running zs_cot step 1, add "reasoning" prefix to keys 121 | if completion_metadata.prediction_template == "zs_cot" and zs_cot_step == 1: 122 | key_prefix = "reasoning" 123 | else: 124 | key_prefix = None 125 | prompt_key = f"{key_prefix}_prompt" if key_prefix else "prompt" 126 | temperature_key = f"{key_prefix}_temperature" if key_prefix else "temperature" 127 | max_tokens_key = f"{key_prefix}_max_tokens" if key_prefix else "max_tokens" 128 | 129 | if completion_metadata.prediction_template == "zs_cot" and zs_cot_step is None: 130 | raise ValueError("zs_cot_step must be specified for prediction_template='zs_cot'") 131 | 132 | # Load dataset 133 | dataset = Dataset.load(completion_metadata.dataset_key) 134 | if sample_indices is None: 135 | sample_indices = list(range(len(dataset.data))) 136 | 137 | # Load or init CompletionDataset 138 | try: 139 | completion_dataset = CompletionDataset.load(completion_metadata) 140 | print("Loaded {} samples from:".format(completion_dataset.total_samples)) 141 | print(completion_dataset.path) 142 | except FileNotFoundError: 143 | completion_dataset = CompletionDataset.init(completion_metadata, additional_metadata={ 144 | temperature_key: temperature, 145 | max_tokens_key: max_tokens, 146 | }) 147 | print("Initializing new CompletionDataset at:") 148 | print(completion_dataset.path) 149 | 150 | # Populate CompletionDataset with formatted prompts, etc. 
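# Illustrative note (added for clarity): for zs_cot step 1, prompt_key is "reasoning_prompt" and each completion sample gets a prompt of the form "Q: <question>\nA: Let's think step by step." (see src/data/format.py); augs > 1 simply creates that many completion slots per sample before prompting.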
151 | populate_completion_dataset(completion_dataset, dataset, formatter, sample_indices, augs, prompt_key=prompt_key) 152 | completion_dataset.save() 153 | 154 | # Get list of individual completion sample dicts 155 | completion_indices = list(range(augs)) 156 | completion_samples = completion_dataset.select_samples(sample_indices, completion_indices) 157 | 158 | # Infer completions 159 | batch_infer_samples(completion_samples, model_id, key_prefix, batch_size=20, temperature=temperature, 160 | max_tokens=max_tokens, save_completion_dataset=completion_dataset) 161 | 162 | return completion_dataset 163 | -------------------------------------------------------------------------------- /src/oai/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/itsnamgyu/reasoning-teacher/a2ca4d28d3bbabbd77106b76d06885d1a5eac0d9/src/oai/utils/__init__.py -------------------------------------------------------------------------------- /src/oai/utils/api_wrapper.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | import traceback 5 | from datetime import datetime 6 | 7 | import openai 8 | 9 | from oai.utils.metadata import FINETUNE_IDS_PATH, MODEL_IDS_PATH, set_model_id, get_model_key 10 | from oai.utils.metadata import get_file_id, set_file_id, get_finetune_id, set_finetune_id 11 | from paths import get_finetune_data_path, SAVED_PATH 12 | 13 | OPENAI_ERROR_LOG_PATH = os.path.join(SAVED_PATH, "openai_error_log.txt") 14 | 15 | 16 | def log_openai_error(message: str): 17 | timestamp = datetime.now().astimezone().isoformat() 18 | os.makedirs(os.path.dirname(OPENAI_ERROR_LOG_PATH), exist_ok=True) 19 | with open(OPENAI_ERROR_LOG_PATH, "a") as f: 20 | f.write(" {} ".format(timestamp).center(80, "#")) 21 | f.write("\n") 22 | f.write(message) 23 | f.write("\n") 24 | f.write("\n") 25 | 26 | 27 | def get_openai_errors(lines=50): 28 | with open(OPENAI_ERROR_LOG_PATH) as f: 29 | if lines > 0: 30 | return "".join(f.readlines()[-lines:]) 31 | else: 32 | return f.read() 33 | 34 | 35 | def create_completion(*args, verbose=True, error_while=None, **kwargs): 36 | retry_intervals = [0] * 1 + [1] * 5 + [10, 30, 60, 300] 37 | 38 | for i, t in enumerate(retry_intervals): 39 | if t: 40 | time.sleep(t) 41 | try: 42 | response = openai.Completion.create(*args, **kwargs) 43 | return response 44 | except Exception as e: 45 | if verbose: 46 | print("Error during OpenAI completion attempt #{}: [{}] {}".format(i + 1, type(e).__name__, str(e))) 47 | if error_while is not None: 48 | log_openai_error("Error during {} attempt #{}:\n{}".format(error_while, i + 1, traceback.format_exc())) 49 | else: 50 | log_openai_error(traceback.format_exc()) 51 | else: 52 | return None 53 | 54 | 55 | def create_finetune_file(finetune_key: str, overwrite=False): 56 | if get_file_id(finetune_key) is not None and not overwrite: 57 | print("Warning: OpenAI File `{}` already exists (likely already uploaded). 
Skipping.".format( 58 | finetune_key)) 59 | return 60 | 61 | path = get_finetune_data_path("openai", finetune_key) 62 | if not os.path.exists(path): 63 | raise FileNotFoundError("Finetune data file with file_key `{}` not found at: `{}`".format(finetune_key, path)) 64 | 65 | with open(path) as f: 66 | response = openai.File.create( 67 | file=f, 68 | purpose='fine-tune' 69 | ) 70 | file_id = response["id"] 71 | set_file_id(finetune_key, file_id) 72 | print("Created OpenAI File for `{}`: `{}`".format(finetune_key, response["id"])) 73 | 74 | return file_id 75 | 76 | 77 | def create_finetune(file_key: str, base_model: str, dataset_key: str, train_key: str, ignore_existing=False, 78 | **kwargs): 79 | model_key = get_model_key(base_model, dataset_key, train_key) 80 | if get_finetune_id(model_key) is not None and not ignore_existing: 81 | print("Warning: OpenAI Finetune for `{}` already exists. Skipping.".format(model_key)) 82 | return 83 | 84 | file_id = get_file_id(file_key) 85 | if file_id is None: 86 | raise KeyError("OpenAI File with file_id `{}` does not exist".format(file_key)) 87 | 88 | response = openai.FineTune.create(training_file=file_id, model=base_model, **kwargs) 89 | finetune_id = response["id"] 90 | set_finetune_id(model_key, finetune_id) 91 | print("Created OpenAI finetune `{}`: `{}`".format(model_key, finetune_id)) 92 | 93 | return finetune_id 94 | 95 | 96 | def fetch_model_ids(): 97 | """ 98 | Fetches model ids for all finetunes that have been completed 99 | """ 100 | if os.path.exists(FINETUNE_IDS_PATH): 101 | with open(FINETUNE_IDS_PATH) as f: 102 | finetune_ids = json.load(f) 103 | else: 104 | raise FileNotFoundError( 105 | "Finetune ids metadata file is missing. Create a finetune using `oai.api_wrapper.create_finetune`") 106 | 107 | if os.path.exists(MODEL_IDS_PATH): 108 | with open(MODEL_IDS_PATH) as f: 109 | model_ids = json.load(f) 110 | else: 111 | model_ids = {} 112 | 113 | model_keys_to_fetch = [] 114 | status_by_key = {} 115 | total = 0 116 | done = 0 117 | for model_key, finetune_id in finetune_ids.items(): 118 | if model_key not in model_ids: 119 | model_keys_to_fetch.append(model_key) 120 | status_by_key[model_key] = "pending" 121 | total += 1 122 | 123 | if total == 0: 124 | print("No model ids to fetch") 125 | return True 126 | 127 | print("Fetching model ids from {} finetunes".format(len(model_keys_to_fetch))) 128 | print("-" * 100) 129 | print("{:<80s}{:<20s}".format("model_key", "status")) 130 | print("-" * 100) 131 | for model_key in model_keys_to_fetch: 132 | finetune_id = finetune_ids[model_key] 133 | response = openai.FineTune.retrieve(finetune_id) 134 | model_id = response["fine_tuned_model"] 135 | if model_id is not None: 136 | set_model_id(model_key, model_id) 137 | done += 1 138 | print("{:<80s}{:<20s}".format(model_key, response["status"])) 139 | print("-" * 100) 140 | print("Fetched {} of {} model ids".format(done, total)) 141 | 142 | return done == total 143 | -------------------------------------------------------------------------------- /src/oai/utils/fetch_model_ids.py: -------------------------------------------------------------------------------- 1 | from oai.utils.api_wrapper import fetch_model_ids 2 | 3 | 4 | def main(): 5 | fetch_model_ids() 6 | 7 | 8 | if __name__ == '__main__': 9 | main() 10 | -------------------------------------------------------------------------------- /src/oai/utils/metadata.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from paths import 
SAVED_PATH 5 | 6 | OPENAI_METADATA_PATH = os.path.join(SAVED_PATH, "openai_metadata") 7 | FINETUNE_IDS_PATH = os.path.join(OPENAI_METADATA_PATH, "finetune_ids.json") 8 | MODEL_IDS_PATH = os.path.join(OPENAI_METADATA_PATH, "model_ids.json") 9 | FILE_IDS_PATH = os.path.join(OPENAI_METADATA_PATH, "file_ids.json") 10 | 11 | DEFAULT_MODEL_IDS = [ 12 | "ada", 13 | "babbage", 14 | "curie", 15 | "davinci", 16 | "text-davinci-001", 17 | "text-davinci-002", 18 | "text-davinci-003", 19 | ] 20 | 21 | 22 | def get_json_value(json_path: str, key: str): 23 | if os.path.exists(json_path): 24 | with open(json_path) as f: 25 | data = json.load(f) 26 | return data.get(key, None) 27 | else: 28 | return None 29 | 30 | 31 | def set_json_value(json_path: str, key: str, value=None, on_exist="overwrite"): 32 | if os.path.exists(json_path): 33 | with open(json_path, "r") as f: 34 | data = json.load(f) 35 | else: 36 | data = {} 37 | 38 | existing = data.get(key, None) 39 | if existing is None: 40 | if value is None: 41 | data.pop(key, None)  # key may not exist; a plain `del` would raise KeyError 42 | else: 43 | data[key] = value 44 | else: 45 | if on_exist == "overwrite": 46 | if value is None: 47 | del data[key] 48 | else: 49 | data[key] = value 50 | elif on_exist == "check_equals": 51 | if existing != value: 52 | raise ValueError("Value mismatch: {} != {}".format(existing, value)) 53 | elif on_exist == "ignore": 54 | pass 55 | else: 56 | raise ValueError("Unsupported argument for `on_exist`: {}".format(on_exist)) 57 | 58 | os.makedirs(os.path.dirname(json_path), exist_ok=True) 59 | with open(json_path, "w") as f: 60 | json.dump(data, f, indent=4) 61 | 62 | 63 | def get_model_key(base_model_key: str, dataset_key: str, train_key: str): 64 | if train_key is None: 65 | return base_model_key 66 | else: 67 | return "B_{}__D_{}__T_{}".format(base_model_key, dataset_key, train_key) 68 | 69 | 70 | def get_finetune_id(model_key): 71 | return get_json_value(FINETUNE_IDS_PATH, model_key) 72 | 73 | 74 | def set_finetune_id(model_key, finetune_id, on_exist="overwrite"): 75 | set_json_value(FINETUNE_IDS_PATH, model_key, finetune_id, on_exist) 76 | 77 | 78 | def get_model_id(model_key): 79 | if model_key in DEFAULT_MODEL_IDS: 80 | return model_key 81 | return get_json_value(MODEL_IDS_PATH, model_key) 82 | 83 | 84 | def set_model_id(model_key, model_id, on_exist="overwrite"): 85 | if model_key in DEFAULT_MODEL_IDS: 86 | raise KeyError("Cannot overwrite default model: {}".format(model_key)) 87 | set_json_value(MODEL_IDS_PATH, model_key, model_id, on_exist) 88 | 89 | 90 | def get_file_id(finetune_key): 91 | return get_json_value(FILE_IDS_PATH, finetune_key) 92 | 93 | 94 | def set_file_id(finetune_key, file_id, on_exist="overwrite"): 95 | set_json_value(FILE_IDS_PATH, finetune_key, file_id, on_exist) 96 | -------------------------------------------------------------------------------- /src/oai/utils/tokens.py: -------------------------------------------------------------------------------- 1 | tokenizer = None 2 | 3 | 4 | def get_tokenizer(): 5 | global tokenizer 6 | if tokenizer is None: 7 | from transformers import GPT2TokenizerFast 8 | tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 9 | return tokenizer 10 | 11 | 12 | def get_token_count(text: str): 13 | return len(get_tokenizer().tokenize(text)) 14 | 15 | 16 | def truncate_by_n_tokens(text: str, n: int) -> str: 17 | tokenizer = get_tokenizer() 18 | return tokenizer.decode(tokenizer(text)["input_ids"][:n]) 19 | -------------------------------------------------------------------------------- /src/paths.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | PROJECT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | 5 | DATA_PATH = os.path.join(PROJECT_PATH, "data") 6 | DATASET_PATH = os.path.join(DATA_PATH, "dataset") 7 | SPLIT_PATH = os.path.join(DATA_PATH, "split") 8 | FEW_SHOT_COT_PROMPTS_PATH = os.path.join(DATA_PATH, "few_shot_cot_prompts.json") 9 | 10 | SAVED_PATH = os.path.join(PROJECT_PATH, "saved") 11 | FINETUNE_DATA_PATH = os.path.join(SAVED_PATH, "finetune_data") 12 | COMPLETION_DATA_PATH = os.path.join(SAVED_PATH, "completion_data") 13 | 14 | 15 | def get_dataset_path(dataset_key: str) -> str: 16 | return os.path.join(DATASET_PATH, "{}.json".format(dataset_key)) 17 | 18 | 19 | def get_split_path(dataset_key: str, split_key: str) -> str: 20 | return os.path.join(SPLIT_PATH, "{}__{}.json".format(dataset_key, split_key)) 21 | 22 | 23 | def get_completion_data_path(base_model: str, completion_key: str, dataset_key: str, 24 | train_key: str = None, epoch: int = None) -> str: 25 | dirname = "B_{}__C_{}".format(base_model, completion_key) 26 | base_tags = ["D_{}".format(dataset_key)] 27 | if train_key is not None: 28 | base_tags.append("T_{}".format(train_key)) 29 | if epoch is not None: 30 | base_tags.append("E_{:03d}".format(epoch)) 31 | basename = "__".join(base_tags) 32 | 33 | return os.path.join(COMPLETION_DATA_PATH, dirname, "{}.json".format(basename)) 34 | 35 | 36 | def get_finetune_data_path(platform_key: str, finetune_key: str) -> str: 37 | if platform_key == "openai": 38 | return os.path.join(FINETUNE_DATA_PATH, "P_{}".format(platform_key), "F_{}.jsonl".format(finetune_key)) 39 | else: 40 | return os.path.join(FINETUNE_DATA_PATH, "P_{}".format(platform_key), "F_{}.json".format(finetune_key)) 41 | --------------------------------------------------------------------------------