├── .flake8 ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── better_promptability ├── __init__.py ├── check_install.py ├── common │ ├── __init__.py │ └── testing.py ├── data │ ├── __init__.py │ ├── config.py │ ├── data_module.py │ ├── data_utils.py │ ├── mixer_dataloader.py │ ├── mixer_dataset.py │ ├── prompt_data_module.py │ ├── t0_meta_learning_data_module.py │ ├── t0_mixture.py │ ├── t0_module.py │ └── t0_multitask_data_module.py ├── models │ ├── __init__.py │ ├── meta_learner.py │ ├── model.py │ ├── prefix_transformer.py │ └── t5_with_prefix.py ├── modules │ ├── __init__.py │ ├── transformer.py │ └── with_prefix_embedding.py ├── steps │ ├── __init__.py │ ├── process_dataset.py │ └── process_story_cloze.py ├── train │ ├── __init__.py │ ├── aggregate_results.py │ ├── eval.py │ ├── optim.py │ ├── train.py │ └── train_main.py └── version.py ├── configs ├── 0shot_eval.jsonnet ├── 0shot_eval_all_d4_dev.jsonnet ├── 0shot_eval_all_green.jsonnet ├── check_install.yml ├── fewshot_eval.jsonnet ├── fewshot_eval_all_d4_dev.jsonnet ├── fewshot_eval_all_green.jsonnet ├── fomaml.jsonnet ├── multi_task.jsonnet ├── reptile.jsonnet ├── t0_mixtures.jsonnet └── t0_task_info.jsonnet ├── data ├── .gitkeep ├── d4_dev_training_indices_16shot_100seed.pkl └── green_training_indices_16shot_100seed.pkl ├── mypy.ini ├── output └── .gitkeep ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── scripts ├── bootstrap.py ├── download_t0_training_set.py ├── process_green_datasets.py └── subsample_t0_training_set.py ├── setup.py ├── tango.yml ├── test_fixtures ├── configs │ ├── check.jsonnet │ ├── d4_dev.jsonnet │ ├── d4_train.jsonnet │ └── green.jsonnet └── data │ ├── cache │ ├── adversarial_qa_dbert_based_on │ │ ├── dataset_dict.json │ │ ├── train │ │ │ ├── dataset.arrow │ │ │ ├── dataset_info.json │ │ │ └── state.json │ │ └── validation │ │ │ ├── dataset.arrow │ │ │ ├── dataset_info.json │ │ │ └── state.json │ ├── hellaswag_complete_first_then_score_eval 
│ │ ├── dataset_dict.json │ │ ├── test │ │ │ ├── dataset.arrow │ │ │ ├── dataset_info.json │ │ │ └── state.json │ │ ├── train │ │ │ ├── dataset.arrow │ │ │ ├── dataset_info.json │ │ │ └── state.json │ │ └── validation │ │ │ ├── dataset.arrow │ │ │ ├── dataset_info.json │ │ │ └── state.json │ ├── hellaswag_if_begins_how_continues_score_eval │ │ ├── dataset_dict.json │ │ ├── test │ │ │ ├── dataset.arrow │ │ │ ├── dataset_info.json │ │ │ └── state.json │ │ ├── train │ │ │ ├── dataset.arrow │ │ │ ├── dataset_info.json │ │ │ └── state.json │ │ └── validation │ │ │ ├── dataset.arrow │ │ │ ├── dataset_info.json │ │ │ └── state.json │ ├── openbookqa_main_choices │ │ ├── dataset_dict.json │ │ ├── train │ │ │ ├── dataset.arrow │ │ │ ├── dataset_info.json │ │ │ └── state.json │ │ └── validation │ │ │ ├── dataset.arrow │ │ │ ├── dataset_info.json │ │ │ └── state.json │ └── story_cloze_2016_Story_Continuation_and_Options_score_eval │ │ ├── dataset_dict.json │ │ ├── test │ │ ├── dataset.arrow │ │ ├── dataset_info.json │ │ └── state.json │ │ └── validation │ │ ├── dataset.arrow │ │ ├── dataset_info.json │ │ └── state.json │ └── processed_cache │ ├── hellaswag_complete_first_then_score_eval │ ├── dataset_dict.json │ ├── test │ │ ├── dataset.arrow │ │ ├── dataset_info.json │ │ └── state.json │ ├── train │ │ ├── dataset.arrow │ │ ├── dataset_info.json │ │ └── state.json │ └── validation │ │ ├── dataset.arrow │ │ ├── dataset_info.json │ │ └── state.json │ └── story_cloze_2016_Story_Continuation_and_Options_score_eval │ ├── dataset_dict.json │ ├── train │ ├── dataset.arrow │ ├── dataset_info.json │ └── state.json │ └── validation │ ├── dataset.arrow │ ├── dataset_info.json │ └── state.json └── tests ├── __init__.py ├── configs_test.py ├── data ├── __init__.py ├── mixer_dataset_test.py └── t0_data_module_test.py ├── hello_test.py ├── models └── __init__.py ├── modules ├── __init__.py └── transformer_test.py └── steps ├── __init__.py ├── process_dataset_test.py └── 
process_story_cloze_test.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 115 3 | 4 | ignore = 5 | # these rules don't play well with black 6 | # whitespace before : 7 | E203 8 | # line break before binary operator 9 | W503 10 | 11 | exclude = 12 | .venv 13 | .git 14 | __pycache__ 15 | .mypy_cache 16 | 17 | per-file-ignores = 18 | # __init__.py files are allowed to have unused imports and lines-too-long 19 | */__init__.py:F401 20 | */**/**/__init__.py:F401,E501 21 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.ref }} 5 | cancel-in-progress: true 6 | 7 | on: 8 | pull_request: 9 | branches: 10 | - main 11 | push: 12 | branches: 13 | - main 14 | 15 | env: 16 | # Change this to invalidate existing cache. 17 | CACHE_PREFIX: v2 18 | PYTHON_PATH: ./ 19 | DEFAULT_PYTHON: 3.7 20 | 21 | jobs: 22 | checks: 23 | name: python ${{ matrix.python }} - ${{ matrix.task.name }} 24 | runs-on: [ubuntu-latest] 25 | timeout-minutes: 30 26 | strategy: 27 | fail-fast: false 28 | matrix: 29 | python: [3.7] 30 | task: 31 | - name: Style 32 | run: | 33 | black --check . 34 | 35 | - name: Lint 36 | run: | 37 | flake8 . 38 | 39 | - name: Type check 40 | run: | 41 | mypy . 42 | 43 | - name: Test 44 | run: | 45 | pytest -v --color=yes tests/ 46 | 47 | steps: 48 | - uses: actions/checkout@v2 49 | 50 | - name: Setup Python 51 | uses: actions/setup-python@v2 52 | with: 53 | python-version: ${{ matrix.python }} 54 | 55 | - name: Install prerequisites 56 | run: | 57 | pip install --upgrade pip setuptools wheel virtualenv 58 | 59 | - name: Set build variables 60 | shell: bash 61 | run: | 62 | # Get the exact Python version to use in the cache key. 
63 | echo "PYTHON_VERSION=$(python --version)" >> $GITHUB_ENV 64 | echo "RUNNER_ARCH=$(uname -m)" >> $GITHUB_ENV 65 | # Use week number in cache key so we can refresh the cache weekly. 66 | echo "WEEK_NUMBER=$(date +%V)" >> $GITHUB_ENV 67 | 68 | - uses: actions/cache@v2 69 | id: virtualenv-cache 70 | with: 71 | path: .venv 72 | key: ${{ env.CACHE_PREFIX }}-${{ env.WEEK_NUMBER }}-${{ runner.os }}-${{ env.RUNNER_ARCH }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }} 73 | 74 | - name: Setup virtual environment (no cache hit) 75 | if: steps.virtualenv-cache.outputs.cache-hit != 'true' 76 | run: | 77 | test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv 78 | . .venv/bin/activate 79 | pip install torch==1.10.1+cpu torchvision==0.11.2+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html 80 | pip install -e .[dev] 81 | 82 | - name: Setup virtual environment (cache hit) 83 | if: steps.virtualenv-cache.outputs.cache-hit == 'true' 84 | run: | 85 | . .venv/bin/activate 86 | pip install --no-deps -e .[dev] 87 | 88 | - name: Show environment info 89 | run: | 90 | . .venv/bin/activate 91 | which python 92 | python --version 93 | pip freeze 94 | 95 | - name: Check install 96 | run: | 97 | . .venv/bin/activate 98 | tango run configs/check_install.yml 99 | 100 | - name: ${{ matrix.task.name }} 101 | run: | 102 | . .venv/bin/activate 103 | ${{ matrix.task.run }} 104 | 105 | - name: Clean up 106 | if: always() 107 | run: | 108 | . 
.venv/bin/activate 109 | pip uninstall -y better_promptability 110 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # build artifacts 2 | 3 | .eggs/ 4 | .mypy_cache 5 | *.egg-info/ 6 | build/ 7 | dist/ 8 | pip-wheel-metadata/ 9 | 10 | 11 | # dev tools 12 | 13 | .envrc 14 | .python-version 15 | .idea 16 | .venv/ 17 | .vscode/ 18 | /*.iml 19 | 20 | 21 | # jupyter notebooks 22 | 23 | .ipynb_checkpoints 24 | 25 | 26 | # miscellaneous 27 | 28 | .cache/ 29 | *.datacache/* 30 | doc/_build/ 31 | *.swp 32 | .DS_Store 33 | 34 | 35 | # python 36 | 37 | *.pyc 38 | *.pyo 39 | __pycache__ 40 | 41 | 42 | # testing and continuous integration 43 | 44 | .coverage 45 | .pytest_cache/ 46 | .benchmarks 47 | 48 | # documentation build artifacts 49 | 50 | docs/build 51 | site/ 52 | 53 | # cache 54 | *.datacache 55 | 56 | output/* 57 | !output/.gitkeep 58 | runs/ 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Continued Pretraining for Better Zero- and Few-Shot Promptability 2 | 3 | The official implementation for our paper (http://arxiv.org/abs/2210.10258): 4 | 5 | ```bibtex 6 | @inproceedings{wu-etal-2022-continued, 7 | title = "Continued Pretraining for Better Zero- and Few-Shot Promptability", 8 | author = "Zhaofeng Wu and Robert L. Logan IV and Pete Walsh and Akshita Bhagia and Dirk Groeneveld and Sameer Singh and Iz Beltagy", 9 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing", 10 | month = dec, 11 | year = "2022", 12 | publisher = "Association for Computational Linguistics", 13 | } 14 | ``` 15 | 16 | We provide a somewhat cleaned version of the codebase in the `main` branch. 
If you run into any issue, you can check out the `archive` branch for the original version that we used. For historical reasons, this repository is slightly over-engineered. For example, because eventually we just performed continued pretraining for one epoch, a lot of checkpointing related logic is unused. 17 | 18 | Feel free to open an issue if you have any questions. 19 | 20 | ## Pretrained Models 21 | 22 | We release our pretrained models at https://huggingface.co/ZhaofengWu/better-promptability. 23 | 24 | ## Environment Setup 25 | 26 | 1. Create a new Python virtual environment with Python 3.7. 27 | 2. Install PyTorch 1.10.1 according to the [official instructions](https://pytorch.org/get-started/locally/). You may need to install `torchvision==0.11.2` this way too. 28 | 3. Run 29 | 30 | ``` 31 | pip install -e . 32 | ``` 33 | Sometimes you might need the flags `--trusted-host=pypi.python.org --trusted-host=pypi.org --trusted-host=files.pythonhosted.org`. 34 | 35 | You can verify that your environment is set up properly by running: 36 | 37 | ``` 38 | tango run configs/check_install.yml 39 | ``` 40 | 41 | ## Data Preparation 42 | 43 | We use the P3 datasets for training and evaluation. In our codebase, we refer to the training datasets as `d4_train`, following the naming convention in their codebase, and the evaluation datasets as `green`, because they are colored green in the T0 paper. You may also see mentions of `d4_dev`, which is a set of datasets (mutually exclusive with `d4_train` and `green`) that we used for development. 44 | 45 | Most of these datasets are publicly available, with the exception of Story Cloze, which we separately obtained from BigScience. You could try doing the same, or processing the data yourself from the original source. 
The processed Story Cloze data should be in a directory with folders `story_cloze_2016_{Answer_Given_options,Choose_Story_Ending,Movie_What_Happens_Next,Novel_Correct_Ending,Story_Continuation_and_Options}_score_eval`, each one with files 46 | 47 | ``` 48 | COMPLETED info.test.json info.validation.json stats.test.json stats.validation.json test.tfrecord-00000-of-00001 validation.tfrecord-00000-of-00001 49 | ``` 50 | 51 | You should update the `STORY_CLOZE_PATH` variable in `scripts/download_t0_training_set.py` to point to this directory. Then to download and process the rest of the datasets, you can run the following commands. Depending on your network speed, etc., they could take a few days (~2 days on our machine). 52 | 53 | ```bash 54 | mkdir t0_cache unprocessed_green_cache 55 | python scripts/download_t0_training_set.py d4_train t0_cache 56 | python scripts/download_t0_training_set.py green unprocessed_green_cache 57 | python scripts/process_green_datasets.py unprocessed_green_cache t0_cache 58 | ``` 59 | 60 | ## Training 61 | 62 | All existing configs use T5-small for illustration. You might want to replace it with other sized T5 models. 63 | 64 | ### Continued Pretraining 65 | 66 | Change the value of `"t0_data_cache"` in each config to the path to the `t0_cache` directory above. Then you can run multi-task training or meta-learning with one of the following commands. When run for the first time, these commands may take a few hours for further dataset processing. 67 | 68 | ```bash 69 | tango run -d ${continued_pretrained_model} configs/multi_task.jsonnet 70 | tango run -d ${continued_pretrained_model} configs/fomaml.jsonnet 71 | tango run -d ${continued_pretrained_model} configs/reptile.jsonnet 72 | ``` 73 | 74 | For multi-task training, you can change the flags `"train_full_model"`, `"num_prefix"`, and `"deep"`, to reproduce our various configurations in the paper. 
By default, the config file reproduces our best model that trains all components, with a deep prompt. Feel free to change the other flags too -- in particular, you probably want to change the number of GPUs used. These scripts support distributed training. Note that the tqdm estimates of these scripts are over-estimations in the beginning. Wait until at least 10% or so of the run has completed for a more accurate estimate. 75 | 76 | ### 0-shot/few-shot Evaluation 77 | 78 | For 0-shot/few-shot evaluation, you can run (remember to set the `"t0_data_cache"` path, like above): 79 | 80 | ```bash 81 | CKPT=${checkpoint_path} tango run -d ${output_dir} configs/0shot_eval_all_green.jsonnet # or configs/fewshot_eval_all_green.jsonnet 82 | ``` 83 | 84 | where `${checkpoint_path}` is the checkpoint you want to evaluate in `${continued_pretrained_model}`. It should look something like `${continued_pretrained_model}/cache/TrainStep-*/work/epoch=0-step=*-endofepoch-categorical_accuracy=*.ckpt`. Set `CKPT=null` if you want to evaluate the model without any continued pretraining. 85 | 86 | You need to set the flags `"model_name"`, `"num_prefix"`, and `"deep"` to match the values used during continued pretraining. For example, for the model `mtl_large_deep`, you want `"model_name" = "google/t5-large-lm-adapt"`, `"num_prefix" = 20`, and `"deep" = true`. 87 | 88 | `configs/0shot_eval.jsonnet` and `configs/fewshot_eval.jsonnet` evaluate a specific dataset instead of aggregating over all datasets. 89 | 90 | These configs don't directly print out the ARG. To compute that, you can print out the per-dataset accuracy using something like `for d in $(ls -d ${output_dir}/runs/*/result_* | sort); do cat $d/data.json | python -c "import sys, json; print(json.load(sys.stdin)[1][-1]['best_categorical_accuracy'], end='')"; echo -n ","; done`, and then paste the resulting string into `bootstrap.py` for the ARG. `bootstrap.py` is also used for significance testing. 
91 | 92 | ### Evaluating Official T0 Checkpoints 93 | 94 | T0 was trained without EOS (at least so it seems). To accommodate this, change `t0_module.py`'s `assemble_prompt()` to not add EOS (in addition to, of course, changing the `"model_name"` to T0 in the relevant config). 95 | -------------------------------------------------------------------------------- /better_promptability/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/better_promptability/__init__.py -------------------------------------------------------------------------------- /better_promptability/check_install.py: -------------------------------------------------------------------------------- 1 | from tango import Step 2 | 3 | 4 | @Step.register("check_install") 5 | class CheckInstall(Step): 6 | DETERMINISTIC = True 7 | CACHEABLE = False 8 | 9 | def run(self) -> None: 10 | import torch 11 | 12 | if torch.cuda.is_available(): 13 | print("All good! CUDA is available :)") 14 | else: 15 | print("All good! 
class BetterPromptabilityTestCase:
    """
    A custom testing class that

    * disables some of the more verbose logging,
    * creates and destroys a temp directory as a test fixture around every test method, and
    * restores the internal state of the `Registrable` class once all tests of the class
      have run (the snapshot is taken in `setup_class` and put back in `teardown_class`).

    """

    # Root of the git repository (two directory levels above this file's directory).
    PROJECT_ROOT = (Path(__file__).parent / ".." / "..").resolve()

    # Root of the tango module.
    MODULE_ROOT = PROJECT_ROOT / "better_promptability"

    # Root of the tests directory.
    TESTS_ROOT = PROJECT_ROOT / "tests"

    # Root of the test fixtures directory.
    FIXTURES_ROOT = PROJECT_ROOT / "test_fixtures"

    def setup_method(self):
        """Configure logging and create a fresh scratch directory in `self.TEST_DIR`."""
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.DEBUG
        )

        # Disabling some of the more verbose logging statements that typically aren't very helpful
        # in tests.
        logging.getLogger("urllib3.connectionpool").disabled = True

        # Create a temporary scratch directory. `mkdtemp` already creates the directory, so no
        # separate `os.makedirs` call is needed.
        self.TEST_DIR = Path(tempfile.mkdtemp(prefix="better_promptability_tests"))

    @classmethod
    def setup_class(cls):
        """Snapshot `Registrable._registry` so `teardown_class` can restore it."""
        # During teardown we'll restore the state of `Registrable`'s internal registry
        # to make sure any registered mock test classes are removed so they don't conflict
        # with other tests.
        cls._original_registry = deepcopy(Registrable._registry)

    def teardown_method(self):
        """Delete the scratch directory created by `setup_method`."""
        shutil.rmtree(self.TEST_DIR)

    @classmethod
    def teardown_class(cls):
        """Put back the registry snapshot taken in `setup_class`."""
        Registrable._registry = cls._original_registry

    def run(
        self,
        config: Union[PathOrStr, Dict[str, Any]],
        overrides: Optional[Union[Dict[str, Any], str]] = None,
        include_package: Optional[List[str]] = None,
    ) -> Path:
        """
        Run a tango experiment inside the scratch directory.

        :param config: Path to a config file, or a dict that is first written to
            ``TEST_DIR / "config.json"``.
        :param overrides: Overrides passed through to tango; a dict is JSON-encoded first.
        :param include_package: Passed through to tango's ``_run`` unchanged.
        :return: ``TEST_DIR / "run"``, the directory handed to tango as the run directory.
        """
        # `Params` lives in tango; this package ships no `common/params.py`, so the former
        # relative import (`from .params import Params`) could not be resolved.
        from tango.common.params import Params
        from tango.__main__ import _run, TangoGlobalSettings

        if isinstance(config, dict):
            params = Params(config)
            config = self.TEST_DIR / "config.json"
            params.to_file(cast(Path, config))

        if isinstance(overrides, dict):
            import json

            overrides = json.dumps(overrides)

        run_dir = self.TEST_DIR / "run"
        _run(
            TangoGlobalSettings(),
            str(config),
            directory=str(run_dir),
            overrides=overrides,
            include_package=include_package,
        )
        return run_dir


@contextmanager
def run_experiment(
    config: Union[PathOrStr, Dict[str, Any]], overrides: Optional[Union[Dict[str, Any], str]] = None
):
    """
    A context manager to make testing experiments easier. On ``__enter__`` it runs
    the experiment and yields the path returned by ``BetterPromptabilityTestCase.run``, which
    lives inside a temporary directory that will be cleaned up on ``__exit__``.
    """
    test_case = BetterPromptabilityTestCase()
    # Set up *before* entering the `try`: if setup itself fails there is no `TEST_DIR` yet, and
    # running `teardown_method` would raise an AttributeError that hides the real error.
    test_case.setup_method()
    try:
        yield test_case.run(config, overrides=overrides)
    finally:
        test_case.teardown_method()
19 | self.output_dir = output_dir 20 | self.auto_select_gpus = auto_select_gpus 21 | 22 | 23 | Config.register("default")(Config) 24 | -------------------------------------------------------------------------------- /better_promptability/data/data_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import logging 3 | import os 4 | from abc import abstractmethod, abstractproperty 5 | from collections.abc import ItemsView 6 | from typing import Any, Mapping, Optional, Union 7 | 8 | from allennlp.training.metrics import Metric 9 | import datasets 10 | from datasets import Dataset as HFDataset, DatasetDict as HFDatasetDict 11 | from tango.common import DatasetDict as TangoDatasetDict 12 | from tango.common.aliases import PathOrStr 13 | from tango.integrations.pytorch_lightning.data import LightningDataModule 14 | from torch.utils.data import DataLoader 15 | from transformers import PreTrainedTokenizerBase 16 | from transformers.trainer_pt_utils import LengthGroupedSampler, DistributedLengthGroupedSampler 17 | 18 | from .config import Config 19 | from .data_utils import PAD_TYPE, collate_fn as default_collate_fn, md5 20 | from .mixer_dataset import MixerDataset 21 | 22 | 23 | # Sometimes we want to change the implementation of methods, etc., which cache ignores. 24 | # We maintain our own cache so this is not very useful anyway. 25 | datasets.set_caching_enabled(False) 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | DatasetDictType = Union[TangoDatasetDict, HFDatasetDict] 32 | 33 | 34 | class DataModule(LightningDataModule): 35 | """ 36 | Abstract class representing a lightning data module using HF datasets, relevant properties, 37 | and a tokenizer. 
38 | """ 39 | 40 | def __init__( 41 | self, 42 | config: Config, 43 | data_dir: Optional[PathOrStr] = None, 44 | max_length: Optional[int] = None, 45 | preprocess_and_save: bool = True, 46 | batch_size: int = 32, 47 | eval_batch_size: int = 32, 48 | num_workers: int = 1, 49 | ): 50 | super().__init__() 51 | self.config = config 52 | self.data_dir = data_dir or "/tmp/better-promptability/data-dir" 53 | self.max_length = max_length 54 | self.preprocess_and_save = preprocess_and_save 55 | self.batch_size = batch_size 56 | self.eval_batch_size = eval_batch_size 57 | self.num_workers = num_workers 58 | self._tokenizer: Optional[PreTrainedTokenizerBase] = None 59 | 60 | def setup(self, stage: Optional[str] = None): 61 | if self.preprocess_and_save: 62 | if os.path.exists(self.cache_path): 63 | self.dataset_dict = HFDatasetDict.load_from_disk(self.cache_path) 64 | return 65 | 66 | self.dataset_dict = self.load() 67 | if self.preprocess_and_save: 68 | self.dataset_dict = self.preprocess(self.dataset_dict) 69 | logger.info(f"Saving dataset cache at {self.cache_path}") 70 | self.dataset_dict.save_to_disk(self.cache_path) 71 | 72 | def _to_params(self): 73 | return {} 74 | 75 | def __getitem__(self, key: str) -> HFDataset: 76 | return self.dataset_dict[key] 77 | 78 | @property 79 | def hash_fields(self) -> list[Any]: 80 | """For cache purpose""" 81 | return [self.config.seed, self.tokenizer.__repr__()] 82 | 83 | @property 84 | def cache_path(self) -> str: 85 | hash_fields = "".join([str(f) for f in self.hash_fields]) 86 | return os.path.join( 87 | self.data_dir, 88 | f"{self.__class__.__name__}_{md5(hash_fields)}.datacache", 89 | ) 90 | 91 | @property 92 | def train_split(self) -> str: 93 | return "train" 94 | 95 | @property 96 | def dev_splits(self) -> list[str]: 97 | return ["dev"] 98 | 99 | @property 100 | def test_splits(self) -> list[str]: 101 | return ["test"] # we don't actually use this 102 | 103 | @property 104 | @abstractproperty 105 | def sort_key(self) -> str: 106 
| raise NotImplementedError("This is an abstract property. Did you forget to implement it?") 107 | 108 | @property 109 | @abstractproperty 110 | def metric_names(self) -> list[str]: 111 | raise NotImplementedError("This is an abstract property. Did you forget to implement it?") 112 | 113 | def instantiate_metric(self, metric_name: str, split: str) -> Metric: 114 | return Metric.by_name(metric_name)() 115 | 116 | @property 117 | def metric_to_watch(self) -> str: 118 | if len(self.metric_names) == 1: 119 | return self.metric_names[0] 120 | else: 121 | raise NotImplementedError( 122 | "This is an abstract property. Did you forget to implement it?" 123 | ) 124 | 125 | @property 126 | @abstractproperty 127 | def metric_watch_mode(self) -> str: 128 | raise NotImplementedError("This is an abstract property. Did you forget to implement it?") 129 | 130 | @abstractmethod 131 | def load(self) -> DatasetDictType: 132 | raise NotImplementedError("This is an abstract method. Did you forget to implement it?") 133 | 134 | @abstractmethod 135 | def tokenize(self, examples: dict[str, list], split: str) -> dict[str, list]: 136 | raise NotImplementedError("This is an abstract method. 
Did you forget to implement it?") 137 | 138 | def preprocess(self, dataset_dict: DatasetDictType) -> DatasetDictType: 139 | logger.info("Begin preprocessing") 140 | assert isinstance(dataset_dict, HFDatasetDict) 141 | dataset_dict = HFDatasetDict( # reimplementing DatasetDict.map to provide `split` 142 | { 143 | split: dataset.map( 144 | lambda examples: self.tokenize(examples, split), 145 | batched=False, # to make tokenization/transformation easier 146 | num_proc=4, 147 | ) 148 | for split, dataset in dataset_dict.items() 149 | } 150 | ) 151 | logger.info("End preprocessing") 152 | 153 | # Rename validation -> dev 154 | if "validation" in dataset_dict and "dev" not in dataset_dict: 155 | dataset_dict["dev"] = dataset_dict["validation"] 156 | del dataset_dict["validation"] 157 | 158 | return dataset_dict 159 | 160 | @property 161 | def tokenizer(self) -> PreTrainedTokenizerBase: 162 | if self._tokenizer is None: 163 | tokenizer = self.setup_tokenizer() 164 | self._tokenizer = tokenizer 165 | return tokenizer 166 | else: 167 | return self._tokenizer 168 | 169 | @tokenizer.setter 170 | def tokenizer(self, tokenizer: PreTrainedTokenizerBase): 171 | self._tokenizer = tokenizer 172 | 173 | @abstractmethod 174 | def setup_tokenizer(self) -> PreTrainedTokenizerBase: 175 | raise NotImplementedError("This is an abstract method. 
Did you forget to implement it?") 176 | 177 | def items(self) -> ItemsView: 178 | return self.dataset_dict.items() 179 | 180 | def dataloader( 181 | self, split: str, batch_size: int, collate_fn=default_collate_fn 182 | ) -> DataLoader: 183 | dataset_split = self.dataset_dict[split] 184 | 185 | # LengthGroupedSampler sorts from longest to shortest; we want the reverse 186 | if isinstance(dataset_split, MixerDataset): 187 | # The naive processing is slow and takes too much memory 188 | lens = [-l for l in dataset_split.get_all_example_lens()] # noqa: E741 189 | else: 190 | lens = [-len(ids) for ids in dataset_split[self.sort_key]] 191 | if self.config.gpus is None or self.config.gpus <= 1: 192 | sampler = LengthGroupedSampler(batch_size, lengths=lens) 193 | else: 194 | sampler = DistributedLengthGroupedSampler(batch_size, lengths=lens) 195 | 196 | pad_token_map = self.pad_token_map(split) 197 | assert all(pad is not None for pad in pad_token_map.values()) 198 | 199 | dataloader = DataLoader( 200 | dataset_split, 201 | batch_size=batch_size, 202 | shuffle=False, 203 | sampler=sampler, 204 | num_workers=self.num_workers, 205 | collate_fn=lambda batch: collate_fn(batch, pad_token_map, self.tokenizer.padding_side), 206 | pin_memory=True, 207 | ) 208 | 209 | return dataloader 210 | 211 | @abstractmethod 212 | def pad_token_map(self, split: str) -> Mapping[str, PAD_TYPE]: 213 | """ 214 | Specifies the padding for each key. Only keys including in this map will be 215 | included in the batch. 216 | """ 217 | raise NotImplementedError("This is an abstract method. 
Did you forget to implement it?") 218 | 219 | def train_dataloader(self) -> DataLoader: 220 | return self.dataloader(self.train_split, self.batch_size) 221 | 222 | def val_dataloader(self, shuffle: bool = False): 223 | return [self.dataloader(split, self.eval_batch_size) for split in self.dev_splits] 224 | 225 | def test_dataloader(self, shuffle: bool = False): 226 | return [self.dataloader(split, self.eval_batch_size) for split in self.test_splits] 227 | -------------------------------------------------------------------------------- /better_promptability/data/data_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import hashlib 4 | from typing import Iterable, Mapping, Union 5 | 6 | import math 7 | import numpy as np 8 | import torch 9 | from torch.utils.data._utils.collate import default_collate 10 | 11 | PAD_TYPE = Union[int, float, bool] 12 | 13 | 14 | def _find_max_shapes( 15 | batch: list[dict[str, np.ndarray]], allow_keys: Iterable[str], pad_to_multiples_of_8: bool 16 | ) -> dict[str, np.ndarray]: 17 | max_shapes = {} 18 | for e in batch: 19 | for k, v in e.items(): 20 | if k not in allow_keys: 21 | continue 22 | shape = np.array(v.shape) 23 | if k not in max_shapes: 24 | max_shapes[k] = shape 25 | else: 26 | try: 27 | max_shapes[k] = np.maximum(max_shapes[k], shape) 28 | except ValueError: # more informed error message 29 | raise ValueError(f"Different shapes for {k}: {max_shapes[k]} vs. {shape}") 30 | 31 | if pad_to_multiples_of_8: 32 | for k, v in max_shapes.items(): 33 | max_shapes[k] = np.array([int(math.ceil(i / 8)) * 8 for i in v]) 34 | 35 | return max_shapes 36 | 37 | 38 | def _pad_last_dim(sequence: list[list], padding_token: PAD_TYPE, padding_side: str): 39 | """ 40 | In-place pads the last dimension of a 2d list. 
41 | """ 42 | assert padding_side in {"left", "right"} 43 | max_len = max(len(e) for e in sequence) 44 | for i, e in enumerate(sequence): 45 | pad_len = max_len - len(e) 46 | sequence[i] = ( 47 | ([padding_token] * pad_len if padding_side == "left" else []) 48 | + e 49 | + ([padding_token] * pad_len if padding_side == "right" else []) 50 | ) 51 | 52 | 53 | def _pad( 54 | sequence: np.ndarray, padding_token: PAD_TYPE, padding_shape: np.ndarray, padding_side: str 55 | ) -> np.ndarray: 56 | assert padding_side in {"left", "right"} 57 | if sequence is None: 58 | return None 59 | padding = [(p, 0) if padding_side == "left" else (0, p) for p in padding_shape] 60 | return np.pad(sequence, padding, constant_values=padding_token) 61 | 62 | 63 | def _tensorize(sequence: np.ndarray, name: str) -> torch.Tensor: 64 | dtype = torch.long 65 | if "_mask" in name or "is_correct" in name: # TODO: there should be a smarter way to do this 66 | dtype = torch.bool 67 | return torch.tensor(sequence, dtype=dtype) 68 | 69 | 70 | def collate_fn( 71 | batch: list[dict[str, list]], 72 | pad_token_map: Mapping[str, PAD_TYPE], 73 | padding_side: str, 74 | pad_to_multiples_of_8: bool = False, 75 | ) -> dict[str, torch.Tensor]: 76 | """ 77 | Input: 78 | pad_token_map: specifies the padding for each key. Only keys including in this map 79 | will be included in the batch. 
def md5(s):
    """
    Return the hexadecimal MD5 digest of the UTF-8 encoding of the string `s`.
    """
    encoded = s.encode("utf-8")
    return hashlib.md5(encoded).hexdigest()
def sample_one_batch(self):
    """
    Return the next batch of one wrapped dataloader, picked at random.

    A dataloader is chosen with probability proportional to its entry in `self._weights`,
    and that entry is decremented by one for the batch that is handed out.
    """
    candidates = range(len(self._dataloader_iters))
    (chosen,) = random.choices(candidates, self._weights)
    self._weights[chosen] -= 1
    assert all(w >= 0 for w in self._weights)
    return next(self._dataloader_iters[chosen])
class MixerDataset(Dataset):
    """
    Presents several datasets as one concatenated :class:`Dataset`.

    Every dataset that is longer than `sampling_cap` is wrapped so that it contributes exactly
    `sampling_cap` examples; shorter datasets are used as they are. This caps the weight of
    large datasets when small and large ones are mixed. When a cap is in use, call
    :meth:`resample()` after every epoch so that the capped datasets expose a different
    selection of their examples.
    """

    def __init__(
        self,
        datasets: list[HFDataset],
        sampling_cap: Optional[int] = None,
        seed: int = 3,  # this is important during distributed training
        no_resample: bool = False,  # useful for validation
    ):
        self._datasets: list[Union[Dataset, HFDataset]] = []
        self._total_size: int = 0
        self._no_resample = no_resample
        for source in Tqdm.tqdm(datasets, desc="Mixing datasets"):
            over_cap = sampling_cap is not None and len(source) > sampling_cap
            if over_cap:
                self._total_size += sampling_cap
                self._datasets.append(_UndersampledDataset(source, sampling_cap, seed=seed))
                continue
            self._total_size += len(source)
            self._datasets.append(source)

    def __getitem__(self, key: int) -> Any:  # type: ignore[override]
        # Walk the member datasets, turning the global index into a local one.
        offset = key
        for member in self._datasets:
            if offset < len(member):
                return member[offset]
            offset -= len(member)
        raise IndexError("index out of bounds")

    def __len__(self) -> int:
        return self._total_size

    def get_all_example_lens(self) -> list[int]:
        """
        Collect the `sort_key_len` values of all examples, in the order used by indexing.
        """
        collected = []
        for member in Tqdm.tqdm(self._datasets, desc="Getting lengths for sampler"):
            if isinstance(member, HFDataset):
                collected.extend(member["sort_key_len"])
            elif isinstance(member, _UndersampledDataset):
                collected.extend(member.get_active_example_lens())
            else:
                assert False
        return collected

    def resample(self):
        """
        Resample every capped member dataset, unless resampling was disabled.
        """
        if not self._no_resample:
            for member in self._datasets:
                if isinstance(member, _UndersampledDataset):
                    member.resample()
def __getitem__(self, i: int) -> Any:  # type: ignore[override]
    """
    Return the `i`-th example of the currently active sample.

    Only the first `self._sampling_cap` entries of `self._indices` are active, so valid
    positions are those below `self._sampling_cap`.

    Raises:
        IndexError: if `i` is at or beyond `self._sampling_cap`.
    """
    # `>=`, not `>`: position `_sampling_cap` itself is already outside the active sample
    # (`__len__` reports `_sampling_cap`). Accepting it would leak one extra example and keep
    # index-based iteration from stopping at the right place.
    if i >= self._sampling_cap:
        raise IndexError("index out of bounds")
    return self._dataset[self._indices[i]]
def __init__(
    self,
    config: Config,
    num_prefix: int,
    transformer_model: PathOrStr,
    deep: bool = False,
    **kwargs,
):
    """
    Args:
        config: forwarded to the base data module.
        num_prefix: the number of prefix (task) tokens.
        transformer_model: name or path of the model whose tokenizer is loaded.
        deep: if false, one task token string per prefix position is created in
            `self.task_tokens`; if true, no task tokens are created.
        **kwargs: forwarded to the base data module.
    """
    # These attributes are assigned before the base constructor runs.
    # NOTE(review): presumably the base constructor already needs them -- keep this order.
    self.num_prefix = num_prefix
    self.transformer_model = transformer_model
    self.deep = deep

    if not self.deep:
        # The template needs a `{}` placeholder: with an empty template every entry is the
        # empty string, so no distinct tokens can be added to the tokenizer.
        # NOTE(review): the `<TASK00>`, `<TASK01>`, ... spelling should be confirmed against
        # tokenizers/checkpoints that are already in use.
        self.task_tokens = ["<TASK{}>".format(str(i).zfill(2)) for i in range(self.num_prefix)]

    super().__init__(config, **kwargs)

    self.inputs_max_length = 768
    self.targets_max_length = 192
def tokenize(self, example: dict[str, Any], split: str):
    """
    Tokenize one example of the given split. Concrete data modules must override this.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # `raise`, not `return`: returning the exception class would hand the caller a value
    # that looks like a result instead of failing.
    raise NotImplementedError
def split_batch(meta_batch: list, support_batch_size: int) -> list:
    """
    Split every batch of a meta batch into a support part and a query part.

    Args:
        meta_batch: the batches; each maps a key to an indexable collection of rows, and all
            collections within one batch have the same length.
        support_batch_size: the number of rows that go into the support part.

    Returns:
        One `(support_batch, query_batch)` pair per input batch. The support rows are drawn
        at random; the query part holds all remaining rows in their original order.
    """
    pairs = []
    for batch in meta_batch:
        columns = list(batch.values())
        batch_size = len(columns[0])
        assert all(len(column) == batch_size for column in columns)

        # Each batch is internally sorted by length, so simply cutting it in two would give
        # the two parts different length distributions. Draw the support rows at random.
        support_indices = random.sample(range(batch_size), support_batch_size)
        taken = set(support_indices)
        query_indices = [row for row in range(batch_size) if row not in taken]

        support_batch, query_batch = {}, {}
        for key, values in batch.items():
            support_batch[key] = values[support_indices]
            query_batch[key] = values[query_indices]
        pairs.append((support_batch, query_batch))
    return pairs
-> DataLoader: 73 | if split != "train": 74 | return super().dataloader(split, batch_size) 75 | 76 | dataset_split = self.dataset_dict[split] 77 | pad_token_map = self.pad_token_map(split) 78 | assert all(pad is not None for pad in pad_token_map.values()) 79 | 80 | dataloaders = [] 81 | for dataset in Tqdm.tqdm(dataset_split._datasets, desc="Creating dataloaders"): 82 | # zhaofeng: I don't particularly like this design because of the redundancy with 83 | # DataModule. But this is necessary at least to accomodate _UndersampledDataset at the 84 | # moment, unless we can somehow turn it into a DataModule too. 85 | if isinstance(dataset, HFDataset): 86 | lens = dataset["sort_key_len"] 87 | elif isinstance(dataset, _UndersampledDataset): 88 | lens = dataset.get_active_example_lens() 89 | else: 90 | assert False 91 | # LengthGroupedSampler sorts from longest to shortest; we want the reverse 92 | lens = [-l for l in lens] # noqa: E741 93 | # It's important we don't used the distributed sampler here since distributed logic 94 | # is handled in MixerDataloader 95 | sampler = LengthGroupedSampler(batch_size, lengths=lens) 96 | dataloader = DataLoader( 97 | dataset, 98 | batch_size=batch_size, 99 | shuffle=False, 100 | sampler=sampler, 101 | num_workers=0, # avoid too many open files error 102 | collate_fn=lambda batch: collate_fn( 103 | batch, pad_token_map, self.tokenizer.padding_side 104 | ), 105 | pin_memory=True, 106 | drop_last=True, # division into support/query is unclear with incomplete batches 107 | ) 108 | dataloaders.append(dataloader) 109 | 110 | return MixerDataLoader( 111 | dataloaders, 112 | self.meta_batch_size, 113 | batch_postprocessor=lambda b: split_batch(b, self.support_batch_size), 114 | ) 115 | -------------------------------------------------------------------------------- /better_promptability/data/t0_mixture.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing 
class T0Mixture:
    """
    Builds one `T0Module` per task of a named mixture and keeps them in `self.data_modules`,
    keyed by task name.
    """

    def __init__(
        self,
        mixture_name: str,  # should be "d4_train", "d4_dev", or "green"
        config: Config,
        num_prefix: int,
        transformer_model: PathOrStr,
        t0_data_cache: PathOrStr,
        subsample_indices_file: Optional[str] = None,
        **data_module_kwargs,
    ):
        assert mixture_name in {"d4_train", "d4_dev", "green"}
        self.mixture_name = mixture_name
        self.data_modules: dict[str, T0Module] = {}

        # The task list of every mixture lives in this config file.
        task_names = Params.from_file("configs/t0_mixtures.jsonnet")[mixture_name]
        for task_name in task_names:
            module = T0Module(
                config=config,
                num_prefix=num_prefix,
                transformer_model=transformer_model,
                mixture_name=self.mixture_name,
                task_name=task_name,
                t0_data_cache=t0_data_cache,
                subsample_indices_file=subsample_indices_file,
                **data_module_kwargs,
            )
            self.data_modules[task_name] = module

        assert len(self.data_modules) > 0
def __init__(
    self,
    config: Config,
    num_prefix: int,
    transformer_model: PathOrStr,
    mixture_name: str,
    task_name: str,
    t0_data_cache: PathOrStr,
    subsample_indices_file: Optional[str] = None,
    **kwargs,
):
    """
    Args:
        config, num_prefix, transformer_model, **kwargs: forwarded to the base data module.
        mixture_name: the mixture this task belongs to.
        task_name: the task; it selects the dataset, subset and template names from
            `read_task_info()`.
        t0_data_cache: the directory holding the cached task data.
        subsample_indices_file: optional pickle file that maps task names to subsampling
            information; only the entry of `task_name` is kept.
    """
    super().__init__(config, num_prefix, transformer_model, **kwargs)

    self.mixture_name = mixture_name
    self.task_name = task_name
    self.dataset_name, self.subset_name, self.template_name = read_task_info()[self.task_name]
    self.t0_data_cache = Path(t0_data_cache)
    self.subsample_indices = None
    if subsample_indices_file is not None:
        # Use a context manager so the file handle is closed right after loading.
        # NOTE: unpickling can execute arbitrary code; only load index files from a
        # trusted source.
        with open(subsample_indices_file, "rb") as indices_file:
            self.subsample_indices = pickle.load(indices_file)[task_name]
def tokenize(self, example: dict[str, Any], split: str) -> dict[str, Any]:
    """
    Turn one example into model inputs, with either a single target or several candidates.

    `example["inputs"]` is truncated to `self.inputs_max_length` and combined with the
    target(s) through `assemble_prompt`. Which targets are used depends on the mixture and
    the split:

    * `d4_train`, and the training split of `d4_dev`: `example["targets"]` is the single
      target.
    * other splits of `d4_dev`: every entry of `example["answer_choices"]` is tokenized and
      becomes a candidate; a candidate is marked correct if its stripped text equals the
      stripped `example["targets_pretokenized"]`.
    * training split of `green`: the single target is the entry of `example["targets"]` at
      the arg-max position of `example["is_correct"]`.
    * everything else: all entries of `example["targets"]` are candidates and
      `example["is_correct"]` marks the correct one.

    Returns:
        A dict with `input_ids`, `input_mask`, `target_ids`, `target_mask` and
        `sort_key_len`. With a single target the first four are the values returned by
        `assemble_prompt`; with several candidates each of them is a list with one entry
        per candidate, and `is_correct` and `is_correct_mask` are added.

    Raises:
        AssertionError: if the EOS token id occurs in the truncated inputs or in a truncated
            target, or if a multi-candidate example does not have exactly one correct
            candidate.
    """
    inputs = example["inputs"][: self.inputs_max_length]

    # Make sure there are no other EOS in `inputs` and `targets`.
    # The EOS token is really the only special token we are concerned about with T5.
    # T5 has no BOS token. There might be UNK tokens in the inputs though, but that's okay.
    assert self.tokenizer.eos_token_id not in inputs

    single_target: bool = False
    is_correct: Optional[List[bool]] = None
    targets = example["targets"]

    if self.mixture_name == "d4_train":
        single_target = True
    elif self.mixture_name == "d4_dev" and split == self.train_split:
        single_target = True

    # We want to evaluate d4_dev datasets same way as the green ones.
    # Some d4_dev datasets do not have answer_choices at all
    # (eg. "web_questions_get_the_answer" simply wants a knowledge-based answer).
    # We ignore these datasets.

    elif self.mixture_name == "d4_dev" and split != self.train_split:
        single_target = False
        # The format in d4_dev is the same as train (there is no is_correct).
        # To get multiple targets, we need to use "answer_choices", and tokenize them.
        is_correct = [
            choice.strip() == example["targets_pretokenized"].strip()
            for choice in example["answer_choices"]
        ]
        targets = [self.tokenizer(choice)["input_ids"] for choice in example["answer_choices"]]
    elif self.mixture_name == "green" and split == self.train_split:
        single_target = True

        # Actually getting the single target.

        correct_idx = np.argmax(example["is_correct"])
        targets = targets[correct_idx]
    else:
        single_target = False
        is_correct = example["is_correct"]

    if single_target:
        # Drop the last token first, then truncate to the maximum target length.
        targets = targets[:-1][  # exclude EOS in example['targets'] (we add later)
            : self.targets_max_length
        ]
        assert self.tokenizer.eos_token_id not in targets
        # Task token ids are only prepended when the prefix is not "deep".
        input_ids, target_ids, input_mask, target_mask = assemble_prompt(
            inputs,
            targets,
            self.tokenizer.eos_token_id,
            self.task_token_ids if not self.deep else [],
        )
    else:
        # One assembled prompt per candidate; the same `inputs` are repeated for each.
        input_ids = []
        input_mask = []
        target_mask = []
        target_ids = []

        for target in targets:
            target = target[:-1][  # exclude EOS in example['targets'] (we add later)
                : self.targets_max_length
            ]
            assert self.tokenizer.eos_token_id not in target

            _input_ids, _target_ids, _input_mask, _target_mask = assemble_prompt(
                inputs,
                target,
                self.tokenizer.eos_token_id,
                self.task_token_ids if not self.deep else [],
            )
            input_ids.append(_input_ids)
            input_mask.append(_input_mask)
            target_ids.append(_target_ids)
            target_mask.append(_target_mask)

    return_dict = {
        "input_ids": input_ids,
        "input_mask": input_mask,
        "target_ids": target_ids,
        "target_mask": target_mask,
        # Length of the original (untruncated) sort-key field of the example.
        "sort_key_len": len(example[self.sort_key]),
    }

    if not single_target:
        # Exactly one candidate must be marked correct.
        assert is_correct is not None and sum(is_correct) == 1
        return_dict["is_correct"] = is_correct
        return_dict["is_correct_mask"] = [True] * len(is_correct)
    return return_dict
def __init__(
    self,
    mixture_name: str,  # should be 'd4_train', 'd4_dev', or 'green'.
    config: Config,
    num_prefix: int,
    transformer_model: PathOrStr,
    t0_data_cache: PathOrStr,
    sampling_cap: Optional[int] = 500000,
    dev_sampling_cap: Optional[int] = 400,
    **kwargs,
):
    """
    Args:
        mixture_name: the mixture whose tasks are combined.
        config, num_prefix, transformer_model: forwarded both to the base data module and to
            the `T0Mixture` that builds the per-task data modules.
        t0_data_cache: forwarded to the `T0Mixture`.
        sampling_cap: stored as `self.sampling_cap`.
        dev_sampling_cap: stored as `self.dev_sampling_cap`.
        **kwargs: forwarded both to the base data module and to the `T0Mixture`.
    """
    # The base module is always created with `preprocess_and_save=False`.
    super().__init__(config, num_prefix, transformer_model, preprocess_and_save=False, **kwargs)

    self.mixture_name = mixture_name
    self.sampling_cap, self.dev_sampling_cap = sampling_cap, dev_sampling_cap
    self.t0_mixture = T0Mixture(
        mixture_name=mixture_name,
        config=config,
        num_prefix=num_prefix,
        transformer_model=transformer_model,
        t0_data_cache=t0_data_cache,
        **kwargs,
    )
def pad_token_map(self, split: str) -> Mapping[str, PAD_TYPE]:  # type: ignore
    """
    Return the padding value for every key that goes into a batch of the given split.

    The ids are padded with 0 and the masks with False. For the `d4_dev` and `green`
    mixtures, every split other than the training split additionally gets `is_correct` and
    `is_correct_mask`, both padded with False.
    """
    mapping = dict(input_ids=0, input_mask=False, target_ids=0, target_mask=False)

    if self.mixture_name not in {"d4_dev", "green"}:
        return mapping
    if split == self.train_split:
        return mapping

    mapping.update(is_correct=False, is_correct_mask=False)
    return mapping
def __init__(
    self,
    config: Config,
    dataset: DataModule,
    optimizer: Optional[Lazy[Optimizer]] = None,
    epochs: int = 3,
    weight_decay: float = 0.0,
    accumulate_grad_batches: int = 1,
    warmup_steps: int = 0,
):
    """
    Args:
        config: stored as `self.config`.
        dataset: the data module; stored as `self.dataset`.
        optimizer: an optional lazily constructed optimizer. If given, it must be a `Lazy`.
        epochs: stored as `self.epochs`.
        weight_decay, accumulate_grad_batches, warmup_steps: collected in
            `self.optimizer_kwargs`.
    """
    super().__init__()

    self.config, self.dataset = config, dataset
    self._optimizer = optimizer
    assert self._optimizer is None or isinstance(self._optimizer, Lazy)

    self.epochs = epochs
    self.optimizer_kwargs = dict(
        weight_decay=weight_decay,
        accumulate_grad_batches=accumulate_grad_batches,
        warmup_steps=warmup_steps,
    )

    # Created last, once the attributes above are in place.
    self.metrics = self.setup_metrics()
def compute_loss(
    self,
    logits: torch.Tensor,
    labels: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    reduce=True,
) -> torch.Tensor:
    """
    Compute the masked token-level cross entropy between ``logits`` and ``labels``.

    ``logits`` carries one extra trailing (vocabulary) dimension compared to ``labels``;
    ``mask`` is multiplied element-wise onto the per-token loss, so it must be
    broadcastable to ``labels``. ``mask`` is mandatory despite its ``None`` default.

    Returns the per-token masked loss (shaped like ``labels``) when ``reduce`` is false,
    otherwise the sum of the masked loss divided by the number of unmasked tokens.
    """
    assert mask is not None

    vocab_size = logits.shape[-1]
    flat_loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), reduction="none")
    masked_loss = flat_loss.view_as(labels) * mask

    if not reduce:
        return masked_loss

    # Every row along the last dimension must keep at least one unmasked token.
    assert mask.any(dim=-1).all()
    return masked_loss.sum() / mask.sum()  # type: ignore
def eval_step(
    self,
    batch: dict[str, torch.Tensor],
    batch_idx: int,
    dataloader_idx=0,
    compute_loss=True,
) -> dict[str, Any]:
    """
    Run the model on one evaluation batch, update the split's metrics and optionally
    return the loss.

    Args:
        batch: must contain ``target_ids``, ``is_correct`` and ``is_correct_mask``;
            ``target_mask`` is additionally needed when ``compute_loss`` is true.
        batch_idx: unused here; part of the step signature.
        dataloader_idx: index into ``self.dataset.dev_splits`` selecting which split's
            metrics are updated.
        compute_loss: when true, also compute the loss over the target tokens.

    Returns:
        ``{"loss": <detached CPU tensor>}`` when ``compute_loss`` is true, else ``{}``.

    Raises:
        KeyError: if ``is_correct`` or ``is_correct_mask`` is missing from ``batch``.
    """
    logits = self(batch)["logits"]
    # Padded answer choices get the lowest representable score so they can never
    # outrank a real choice.
    preds = self.get_predictions(logits, batch).masked_fill(
        ~batch["is_correct_mask"], torch.finfo(logits.dtype).min
    )

    if "is_correct" not in batch:
        # Previously this fell through and failed later with an unbound `labels`.
        raise KeyError("eval_step requires `is_correct` in the batch to derive gold labels")
    labels = (batch["is_correct"] & batch["is_correct_mask"]).byte().argmax(dim=-1)

    split = self.dataset.dev_splits[dataloader_idx]
    for metric in self.metrics[split].values():
        metric(*metric.detach_tensors(preds, labels))

    if not compute_loss:
        return {}

    # The mask is stored under `target_mask` (the key `training_step` also reads). The
    # old lookup of `targets_mask` always returned None and tripped compute_loss's assert.
    loss = self.compute_loss(logits, batch["target_ids"], batch.get("target_mask"))
    return {"loss": loss.detach().cpu()}
def get_metrics(self, split: str, reset=False) -> dict[str, Any]:
    """
    Collect the current value of every metric registered for ``split``.

    All values are read first; only afterwards, and only if ``reset`` is true, is every
    metric of that split reset.
    """
    values: dict[str, Any] = {}
    for name, metric in self.metrics[split].items():
        values[name] = metric.get_metric()

    if not reset:
        return values

    for metric in self.metrics[split].values():
        metric.reset()
    return values
tango.common.lazy import Lazy 7 | from tango.integrations.torch.optim import Optimizer 8 | from transformers import T5ForConditionalGeneration 9 | 10 | from ..data.config import Config 11 | from ..data.prompt_data_module import PromptDataModule 12 | from ..data.t0_multitask_data_module import T0MultiTaskDataModule 13 | from ..modules.transformer import Transformer 14 | from ..modules.with_prefix_embedding import WithPrefixEmbedding 15 | from .model import Model 16 | from .t5_with_prefix import T5WithPrefixConfig, T5ForConditionalGenerationWithPrefix 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | @Model.register("prefix_transformer") 22 | @Model.register("prefix_transformer_from_checkpoint", constructor="load_from_checkpoint") 23 | class PrefixTransformer(Model): 24 | def __init__( 25 | self, 26 | config: Config, 27 | dataset: PromptDataModule, 28 | transformer_model: str, 29 | optimizer: Optional[Lazy[Optimizer]] = None, 30 | epochs: int = 3, 31 | weight_decay: float = 0.0, 32 | accumulate_grad_batches: int = 1, 33 | warmup_steps: int = 0, 34 | train_full_model: bool = False, 35 | **transformer_kwargs, 36 | ): 37 | self.transformer_name = transformer_model 38 | self.train_full_model = train_full_model 39 | self.deep = dataset.deep 40 | 41 | super().__init__( 42 | config, 43 | dataset, 44 | optimizer=optimizer, 45 | epochs=epochs, 46 | weight_decay=weight_decay, 47 | accumulate_grad_batches=accumulate_grad_batches, 48 | warmup_steps=warmup_steps, 49 | ) 50 | 51 | if not self.deep: 52 | self.transformer = Transformer(transformer_model, "seq2seq-lm", **transformer_kwargs) 53 | else: 54 | self.transformer = Transformer( 55 | transformer_model, 56 | "seq2seq-lm", 57 | config_cls=T5WithPrefixConfig, 58 | model_cls=T5ForConditionalGenerationWithPrefix, 59 | num_prefix=dataset.num_prefix, 60 | **transformer_kwargs, 61 | ) 62 | transformer_model: T5ForConditionalGeneration = self.transformer.model 63 | assert isinstance(transformer_model, 
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Run the wrapped transformer on ``batch`` and return ``{"logits": ...}``.

    ``batch`` must provide ``input_ids``, ``input_mask``, ``target_ids`` and
    ``target_mask``. Outside of training mode the inputs carry an extra per-class
    dimension (bs x num_classes x seq_len); it is folded into the batch dimension before
    the transformer call and restored on the logits afterwards.
    """
    input_ids, input_mask = batch["input_ids"], batch["input_mask"]
    target_ids, target_mask = batch["target_ids"], batch["target_mask"]

    assert input_ids.shape == input_mask.shape and input_ids.dim() in (2, 3)

    if not self.training:  # for inference we have an additional dimension for classes
        encoder_len = input_ids.shape[-1]
        input_ids = input_ids.reshape(-1, encoder_len)
        input_mask = input_mask.reshape(-1, encoder_len)

    # Targets are always flattened to 2D; a no-op when they are already 2D.
    decoder_shape = target_ids.shape
    decoder_len = decoder_shape[-1]

    outputs = self.transformer(
        input_ids=input_ids,
        attention_mask=input_mask,
        labels=target_ids.reshape(-1, decoder_len),
        decoder_attention_mask=target_mask.reshape(-1, decoder_len),
    )
    logits = outputs.logits

    if not self.training:
        # Restore the leading target dimensions; the vocabulary dimension is inferred.
        logits = logits.reshape(*decoder_shape, -1)

    return {"logits": logits}
def eval_step(
    self,
    batch: dict[str, torch.Tensor],
    batch_idx: int,
    dataloader_idx=0,
    compute_loss=True,
) -> dict[str, Any]:
    """
    Evaluate one batch.

    For any dataset other than a ``T0MultiTaskDataModule`` this defers to the parent
    implementation, always with ``compute_loss=False`` (the ``compute_loss`` argument of
    this override is not consulted). For the multitask data module the raw logits are fed
    to the split's metrics together with ``target_ids`` and ``target_mask``, and an empty
    dict is returned.
    """
    if not isinstance(self.dataset, T0MultiTaskDataModule):
        return super().eval_step(
            batch, batch_idx, dataloader_idx=dataloader_idx, compute_loss=False
        )

    logits = self(batch)["logits"]
    split_name = self.dataset.dev_splits[dataloader_idx]
    for metric in self.metrics[split_name].values():
        detached = metric.detach_tensors(logits, batch["target_ids"], batch["target_mask"])
        metric(*detached)
    return {}
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """
    Normalise the checkpoint's ``state_dict`` keys before the parent hook runs.

    If any key starts with ``"model."`` the checkpoint is treated as coming from the
    meta-learning wrapper: every key is then required to carry that prefix, and the
    prefix is stripped. ``checkpoint`` is modified in place.
    """
    wrapper_prefix = "model."
    state_dict = checkpoint["state_dict"]

    if any(key.startswith(wrapper_prefix) for key in state_dict.keys()):
        # Unwrap the meta-learning model
        unwrapped = {}
        for key, value in state_dict.items():
            assert key.startswith(wrapper_prefix)
            unwrapped[key[len(wrapper_prefix) :]] = value
        checkpoint["state_dict"] = unwrapped

    # TODO: optimizer states
    return super().on_load_checkpoint(checkpoint)
checkpoint_path: Union[str, IO], 208 | map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, 209 | hparams_file: Optional[str] = None, 210 | strict: bool = True, 211 | optimizer: Optional[Lazy[Optimizer]] = None, 212 | **kwargs, 213 | ): 214 | # We need to tell tango the type of optimizer, or otherwise it will only give us a Params 215 | # object 216 | return super().load_from_checkpoint( 217 | checkpoint_path, 218 | map_location=map_location, 219 | hparams_file=hparams_file, 220 | strict=strict, 221 | optimizer=optimizer, 222 | **kwargs, 223 | ) 224 | -------------------------------------------------------------------------------- /better_promptability/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/better_promptability/modules/__init__.py -------------------------------------------------------------------------------- /better_promptability/modules/transformer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from transformers import ( 5 | AutoConfig, 6 | AutoModel, 7 | AutoModelForPreTraining, 8 | AutoModelForQuestionAnswering, 9 | AutoModelForSeq2SeqLM, 10 | AutoModelForSequenceClassification, 11 | AutoModelForTokenClassification, 12 | AutoModelForCausalLM, 13 | ) 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | TASKS = { 20 | "base": AutoModel, 21 | "sequence-classification": AutoModelForSequenceClassification, 22 | "question-answering": AutoModelForQuestionAnswering, 23 | "pretraining": AutoModelForPreTraining, 24 | "token-classification": AutoModelForTokenClassification, 25 | "causal-lm": AutoModelForCausalLM, 26 | "summarization": AutoModelForSeq2SeqLM, 27 | "translation": AutoModelForSeq2SeqLM, 28 | "seq2seq-lm": AutoModelForSeq2SeqLM, 29 | } 30 | 31 | 32 | class 
def forward(self, *args, **kwargs):
    """
    Call the wrapped model, converting attention masks to float first.

    ``attention_mask`` and ``decoder_attention_mask`` keyword arguments, when present,
    are cast with ``.float()``. Gradients are enabled for the call only if they are
    currently enabled and this module is trainable.
    """
    # `transformers` doesn't take bool masks which is crazy
    for mask_name in ("attention_mask", "decoder_attention_mask"):
        if mask_name in kwargs:
            kwargs[mask_name] = kwargs[mask_name].float()

    # If grad was previous disabled (e.g., in eval), don't change it
    grad_enabled = torch.is_grad_enabled() and self.trainable
    with torch.set_grad_enabled(grad_enabled):
        return self.model(*args, **kwargs)
def forward(self, input):
    """
    Embed ``input`` with the original vocabulary rows followed by the prefix rows.

    The lookup table is the first ``expected_vocab_size`` rows of the original embedding
    concatenated with the new prefix embedding, so ids at or above
    ``expected_vocab_size`` index into the prefix embedding. All other embedding options
    are taken from the original embedding module.
    """
    base = self.embed
    vocab_rows = base.weight[: self.expected_vocab_size]
    table = torch.cat([vocab_rows, self.new_embed.weight], 0)
    return F.embedding(
        input,
        table,
        padding_idx=base.padding_idx,
        max_norm=base.max_norm,
        norm_type=base.norm_type,
        scale_grad_by_freq=base.scale_grad_by_freq,
        sparse=base.sparse,
    )
def run(
    self, old_data_path: str, new_data_path: str, process_if_exists: bool = False
) -> DatasetDict:  # type: ignore[override]
    """
    Regroup a dataset that has one row per answer choice into one row per example.

    Consecutive rows are merged into a single instance whose ``targets``,
    ``targets_pretokenized`` and ``is_correct`` columns are lists with one entry per
    merged row; ``inputs``/``inputs_pretokenized`` are taken from the last merged row.
    A row whose ``idx[1]`` is ``0`` starts a new instance (presumably ``idx[1]`` is the
    answer-choice index -- confirm against the dataset schema).

    Args:
        old_data_path: directory of the dataset to read (``DatasetDict.load_from_disk``).
        new_data_path: directory the regrouped dataset is saved to.
        process_if_exists: when false and ``new_data_path`` already exists, the existing
            dataset is loaded and returned without reprocessing.

    Returns:
        The regrouped ``DatasetDict`` (also saved to ``new_data_path``).
    """
    if not process_if_exists and os.path.exists(new_data_path):
        logger.info(
            f"The processed dataset already exists at {new_data_path}. "
            "Set `process_if_exists` to `True` if you want to process again. "
            "Returning existing dataset."
        )
        return DatasetDict.load_from_disk(new_data_path)

    columns = ("inputs", "inputs_pretokenized", "targets", "targets_pretokenized", "is_correct")

    def blank_instance() -> Dict:
        """Return a fresh accumulator for the rows of one example."""
        return {
            "inputs": None,
            "inputs_pretokenized": None,
            "targets": [],
            "targets_pretokenized": [],
            "is_correct": [],
        }

    def append_instance(new_instances: Dict, instance: Dict) -> None:
        """Append one finished example to the column-wise output."""
        for column in columns:
            new_instances[column].append(instance[column])

    dataset_dict = DatasetDict.load_from_disk(old_data_path)
    new_splits = {}

    for split_name in dataset_dict:
        split = dataset_dict[split_name]
        new_instances: Dict = {column: [] for column in columns}
        instance = blank_instance()

        # TODO: assert for presence of the right keys in the dataset.
        for row in split:
            if row["idx"][1] == 0 and instance["inputs"] is not None:
                # First row of the next example: emit the one accumulated so far.
                append_instance(new_instances, instance)
                instance = blank_instance()

            instance["inputs"] = row["inputs"]
            instance["inputs_pretokenized"] = row["inputs_pretokenized"]
            instance["targets"].append(row["targets"])
            instance["targets_pretokenized"].append(row["targets_pretokenized"])
            instance["is_correct"].append(row["is_correct"])

        # Emit the trailing example, but only if at least one row was accumulated.
        # An empty split used to produce a single bogus row with `None` inputs.
        if instance["inputs"] is not None:
            append_instance(new_instances, instance)

        new_splits[split_name] = Dataset.from_dict(new_instances)

    new_dataset_dict: DatasetDict = DatasetDict(new_splits)
    logger.info(f"Saving processed dataset at {new_data_path}.")
    new_dataset_dict.save_to_disk(new_data_path)
    return new_dataset_dict
@Step.register("process_story_cloze") 13 | class ProcessStoryCloze(Step): 14 | 15 | DETERMINISTIC: bool = True 16 | CACHEABLE = False # use datasets caching. 17 | 18 | def run( 19 | self, 20 | old_data_path: str, 21 | new_data_path: str, 22 | process_if_exists: bool = False, 23 | tokenizer_model: str = "google/t5-small-lm-adapt", 24 | ) -> DatasetDict: # type: ignore[override] 25 | 26 | if not process_if_exists and os.path.exists(new_data_path): 27 | logger.info( 28 | f"The processed dataset already exists at {new_data_path}. " 29 | "Set `process_if_exists` to `True` if you want to process again. " 30 | "Returning existing dataset." 31 | ) 32 | return DatasetDict.load_from_disk(new_data_path) 33 | 34 | tokenizer = cached_transformers.get_tokenizer(tokenizer_model) 35 | 36 | dataset_dict = DatasetDict.load_from_disk(old_data_path) 37 | new_splits = {} 38 | 39 | for split_name in dataset_dict: 40 | split = dataset_dict[split_name] 41 | 42 | new_instances: Dict = { 43 | "inputs": [], 44 | "inputs_pretokenized": [], 45 | "targets": [], 46 | "targets_pretokenized": [], 47 | "is_correct": [], 48 | } 49 | 50 | for instance in split: 51 | actual_targets_pretokenized = instance["targets_pretokenized"] 52 | 53 | is_correct = [ 54 | choice.strip() == actual_targets_pretokenized.strip() 55 | for choice in (instance["answer_choices"]) 56 | ] 57 | 58 | targets = [ 59 | tokenizer(choice, add_special_tokens=False)["input_ids"] 60 | for choice in instance["answer_choices"] 61 | ] 62 | 63 | targets_pretokenized = instance["answer_choices"] 64 | 65 | new_instances["inputs"].append(instance["inputs"]) 66 | new_instances["inputs_pretokenized"].append(instance["inputs_pretokenized"]) 67 | new_instances["targets"].append(targets) 68 | new_instances["targets_pretokenized"].append(targets_pretokenized) 69 | new_instances["is_correct"].append(is_correct) 70 | 71 | # Story cloze doesn't have a training set, so we use validation for training and test 72 | # for validation. 
def run(self, results: Dict[str, Tuple[str, List[Dict[str, Any]]]]) -> Dict[str, Any]:
    """
    Aggregate the results of a bunch of `TrainStep`s. `results` is a mapping from
    `task_name` to the output of the corresponding `TrainStep`.

    For each task, the ``best_categorical_accuracy`` of the last entry of its metrics
    list is used. Mean and standard deviation (``None`` when there is at most one task)
    are reported over all tasks, per dataset, and per dataset/subset, with the grouping
    read from ``configs/t0_task_info.jsonnet``. ``flattened`` is a comma-separated string
    of the per-subset mean/std values multiplied by 100, with ``"0"`` for a missing std.
    """
    task_info = Params.from_file("configs/t0_task_info.jsonnet")["tasks"].as_dict(quiet=True)

    def accuracy_for_task(task_name: str) -> float:
        acc = results[task_name][1][-1]["best_categorical_accuracy"]
        if isinstance(acc, torch.Tensor):
            return acc.item()
        if isinstance(acc, (float, int)):
            return float(acc)
        raise TypeError(acc)

    def stats_for_tasks(tasks: Set[str]) -> Dict[str, Optional[float]]:
        accuracies = [accuracy_for_task(task_name) for task_name in tasks]
        std = np.std(accuracies) if len(accuracies) > 1 else None
        return {"mean": np.mean(accuracies), "std": std}

    tasks_by_dataset: Dict[str, Set[str]] = defaultdict(set)
    tasks_by_dataset_and_subset: Dict[str, Dict[str, Set[str]]] = defaultdict(
        lambda: defaultdict(set)
    )
    for task_name in results:
        info = task_info[task_name]
        dataset_name, subset_name = info["dataset_name"], info["subset_name"]
        tasks_by_dataset[dataset_name].add(task_name)
        tasks_by_dataset_and_subset[dataset_name][subset_name].add(task_name)

    # Per-subset statistics are computed once and reused for both outputs below.
    stats_by_dataset_and_subset = {
        dataset_name: {
            subset_name: stats_for_tasks(tasks) for subset_name, tasks in subsets.items()
        }
        for dataset_name, subsets in tasks_by_dataset_and_subset.items()
    }

    # For direct copying into a spreadsheet
    flattened_results = []
    for subset_stats in stats_by_dataset_and_subset.values():
        for stats in subset_stats.values():
            flattened_results.extend([stats["mean"], stats["std"]])

    return {
        "categorical_accuracy_all": stats_for_tasks(set(results.keys())),
        "categorical_accuracy_by_dataset": {
            dataset_name: stats_for_tasks(tasks)
            for dataset_name, tasks in tasks_by_dataset.items()
        },
        "categorical_accuracy_by_dataset_and_subset": stats_by_dataset_and_subset,
        "flattened": ",".join(
            [str(n * 100) if n is not None else "0" for n in flattened_results]
        ),
    }
@Optimizer.register("adafactor")
class Adafactor(HFAdafactor):
    """
    Adafactor variant that never factors parameters whose dimensions are all small.

    See https://github.com/huggingface/transformers/issues/14830

    Kept only for backward compatibility; ``transformers::adafactor`` can probably be
    used directly in a config instead.
    """

    @staticmethod
    def _get_options(param_group, param_shape, min_dim_size_to_factor=128):
        """
        Return ``(factored, use_first_moment)`` as the parent class computes them, except
        that ``factored`` is forced to ``False`` when no dimension of ``param_shape``
        reaches ``min_dim_size_to_factor``.
        """
        factored, use_first_moment = HFAdafactor._get_options(param_group, param_shape)
        has_large_dim = any(dim >= min_dim_size_to_factor for dim in param_shape)
        if not has_large_dim:
            factored = False
        return factored, use_first_moment
@LightningCallback.register("my_logger")
class LoggingCallback(LightningCallback):
    """
    Log all callback metrics after every validation run and remember the best epoch
    according to the data module's ``metric_to_watch`` / ``metric_watch_mode``.
    """

    def __init__(self):
        # Epoch at which the watched metric was best so far (None until first seen).
        self.best_epoch = None
        # Best value of the watched metric so far.
        self.best_dev_metric = None
        # Snapshot of the callback metrics taken at the best epoch.
        self.best_dev_metrics = None
        # One dict of logged values per validation run, in order.
        self.metrics_history = []

    @rank_zero_only
    def on_validation_end(self, trainer: LightningTrainer, pl_module: LightningModule):
        """
        Log the current validation metrics, update the best-so-far bookkeeping and append
        an entry to ``metrics_history``.

        During the trainer's sanity check the metrics are logged and recorded, but the
        best-so-far state is neither updated nor reported.
        """
        logger.info("")
        logger.info(f"***** Validation results at epoch {trainer.current_epoch} *****")

        assert pl_module.dataset.metric_watch_mode in {"max", "min"}
        self.metrics_history.append({})

        metrics = trainer.callback_metrics
        # Log results
        for key in sorted(metrics):
            if key not in ["log", "progress_bar"]:
                logger.info("{} = {}".format(key, str(metrics[key])))
                self.metrics_history[-1][key] = metrics[key]

            if key == pl_module.dataset.metric_to_watch and not trainer.sanity_checking:
                curr_metric = metrics[key]
                if (
                    self.best_dev_metric is None
                    or (
                        pl_module.dataset.metric_watch_mode == "max"
                        and curr_metric > self.best_dev_metric
                    )
                    or (
                        pl_module.dataset.metric_watch_mode == "min"
                        and curr_metric < self.best_dev_metric
                    )
                ):
                    self.best_epoch = trainer.current_epoch
                    self.best_dev_metric = curr_metric
                    self.best_dev_metrics = {
                        k: v
                        for k, v in metrics.items()
                        if k not in {"log", "progress_bar", "loss", "val_loss", "lr", "epoch"}
                    }

        if not trainer.sanity_checking:
            logger.info(f"best_epoch = {self.best_epoch}")
            self.metrics_history[-1]["best_epoch"] = self.best_epoch
            if self.best_dev_metrics is None:
                # The watched metric has never been reported, so there is no snapshot
                # yet. This used to crash with AttributeError on `None.items()`.
                logger.warning(
                    "Watched metric %r not found in callback metrics; no best metrics yet.",
                    pl_module.dataset.metric_to_watch,
                )
            else:
                for key, value in sorted(self.best_dev_metrics.items()):
                    logger.info(f"best_{key} = {value}")
                    self.metrics_history[-1][f"best_{key}"] = value
91 | """ 92 | 93 | def on_epoch_end(self, trainer: LightningTrainer, pl_module: LightningModule): 94 | assert isinstance(pl_module.dataset, T0MultiTaskDataModule) 95 | for dataset in pl_module.dataset.dataset_dict.values(): 96 | dataset.resample() 97 | 98 | 99 | # Since both FairScale and DeepSpeed are insane and will restart your whole process to make workers, we have 100 | # to be able to do this when train.py is called as a standalone script. 101 | def _train_step( 102 | work_dir: Path, 103 | config: Config, 104 | trainer: Lazy[LightningTrainer], 105 | strategy: Optional[str], 106 | model: Lazy[Model], 107 | datamodule: Lazy[PromptDataModule], 108 | ) -> Tuple[str, List[Dict]]: 109 | pl.seed_everything(config.seed) 110 | 111 | datamodule = datamodule.construct(config=config) 112 | 113 | datamodule.prepare_data() 114 | datamodule.setup() 115 | 116 | logger.info("Constructing trainer ...") 117 | trainer: LightningTrainer = trainer.construct( 118 | work_dir=work_dir, 119 | gpus=config.gpus, 120 | precision=config.precision, 121 | strategy=strategy, 122 | auto_select_gpus=config.auto_select_gpus, 123 | # Need to reload the dataloaders each epoch when using the T0MultiTaskDataModule. 
124 | reload_dataloaders_every_n_epochs=1 if isinstance(datamodule, T0MultiTaskDataModule) else 0, 125 | ) 126 | logger.info("Done constructing trainer ...") 127 | 128 | # Make sure we're using the `T0MultiTaskCallback` if using the `T0MultiTaskDataModule` 129 | if isinstance(datamodule, T0MultiTaskDataModule): 130 | for callback in trainer.callbacks: 131 | if isinstance(callback, T0MultiTaskCallback): 132 | break 133 | else: 134 | raise RuntimeError("T0MultiTaskCallback required when using T0MultiTaskDataModule") 135 | 136 | epochs = trainer.max_epochs 137 | 138 | logger.info("Constructing model ...") 139 | model = model.construct( 140 | config=config, 141 | dataset=datamodule, 142 | epochs=epochs, 143 | accumulate_grad_batches=trainer.accumulate_grad_batches, 144 | ) 145 | logger.info("Done constructing model ...") 146 | 147 | assert model.epochs == epochs 148 | 149 | # Find the checkpoint callback and make sure it uses the right directory. 150 | # Also find the logging callback. 151 | checkpoint_callback: pl.callbacks.model_checkpoint.ModelCheckpoint = None 152 | logging_callback: LoggingCallback 153 | for callback in trainer.callbacks: 154 | if isinstance(callback, pl.callbacks.model_checkpoint.ModelCheckpoint): 155 | callback.dirpath = work_dir 156 | checkpoint_callback = callback 157 | if isinstance(callback, LoggingCallback): 158 | logging_callback = callback 159 | 160 | resume_from_checkpoint = None 161 | if "last.ckpt" in os.listdir(work_dir): 162 | resume_from_checkpoint = os.path.join(work_dir, "last.ckpt") 163 | trainer.fit(model, datamodule=datamodule, ckpt_path=resume_from_checkpoint) 164 | 165 | if not trainer.state.finished: 166 | raise ValueError(f"Trainer did not succeed! 
Final trainer state was {trainer.state}.") 167 | 168 | return ( 169 | checkpoint_callback.best_model_path if checkpoint_callback is not None else None, 170 | logging_callback.metrics_history, 171 | ) 172 | 173 | 174 | @Step.register("train_step") 175 | class TrainStep(Step): 176 | VERSION = "004" 177 | 178 | DETERMINISTIC: bool = True 179 | CACHEABLE = True 180 | FORMAT = JsonFormat() 181 | 182 | def run( # type: ignore[override] 183 | self, 184 | config: Config, 185 | trainer: Lazy[LightningTrainer], 186 | model: Lazy[Model], 187 | datamodule: Lazy[PromptDataModule], 188 | ) -> Tuple[str, List[Dict]]: 189 | if config.gpus == 1: 190 | strategy = None 191 | elif config.gpus > 1: 192 | # strategy = "deepspeed_stage_3_offload" 193 | # strategy = "deepspeed_stage_3" 194 | # strategy = "deepspeed_stage_2" 195 | # strategy = "ddp_sharded" 196 | # We never change trainability of parameters, so this is unnecessary. And actually 197 | # we rely on this flag being False for the current meta learning implementation. 
198 | strategy = DDPShardedPlugin(auto_refresh_trainable=False) 199 | # strategy = "ddp" 200 | else: 201 | strategy = None 202 | 203 | kwargs_file = self.work_dir / "train_kwargs.dill" 204 | with kwargs_file.open("wb") as f: 205 | dill.dump( 206 | { 207 | "work_dir": self.work_dir, 208 | "extra_modules": get_extra_imported_modules(), 209 | "config": config, 210 | "trainer": trainer, 211 | "strategy": strategy, 212 | "model": model, 213 | "datamodule": datamodule, 214 | }, 215 | f, 216 | ) 217 | results_file = self.work_dir / "train_results.dill" 218 | 219 | import subprocess 220 | 221 | subprocess.check_call( 222 | [ 223 | sys.executable, 224 | "-m", 225 | "better_promptability.train.train_main", 226 | str(kwargs_file), 227 | str(results_file), 228 | ] 229 | ) 230 | with open(results_file, "rb") as f: 231 | results = dill.load(f) 232 | return results 233 | -------------------------------------------------------------------------------- /better_promptability/train/train_main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import dill 4 | from tango.common.logging import initialize_logging 5 | from tango.common.util import import_extra_module 6 | 7 | from better_promptability.train.train import _train_step 8 | 9 | 10 | def main(): 11 | initialize_logging() 12 | 13 | _, kwargs_file, results_file = sys.argv 14 | with open(kwargs_file, "rb") as f: 15 | training_kwargs = dill.load(f) 16 | for module in training_kwargs["extra_modules"]: 17 | import_extra_module(module) 18 | results = _train_step( 19 | training_kwargs["work_dir"], 20 | training_kwargs["config"], 21 | training_kwargs["trainer"], 22 | training_kwargs["strategy"], 23 | training_kwargs["model"], 24 | datamodule=training_kwargs["datamodule"], 25 | ) 26 | with open(results_file, "wb") as f: 27 | dill.dump(results, f) 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | 
-------------------------------------------------------------------------------- /better_promptability/version.py: -------------------------------------------------------------------------------- 1 | _MAJOR = "0" 2 | _MINOR = "1" 3 | # On main and in a nightly release the patch should be one ahead of the last 4 | # released build. 5 | _PATCH = "0" 6 | # This is mainly for nightly builds which have the suffix ".dev$DATE". See 7 | # https://semver.org/#is-v123-a-semantic-version for the semantics. 8 | _SUFFIX = "" 9 | 10 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) 11 | VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX) 12 | -------------------------------------------------------------------------------- /configs/0shot_eval.jsonnet: -------------------------------------------------------------------------------- 1 | local config = { 2 | "type": "default", 3 | "seed": 100, 4 | "gpus": 1, 5 | "precision": 32, 6 | }; 7 | local model_name = "google/t5-small-lm-adapt"; 8 | local mixture_name = "green"; 9 | local task_name = "hellaswag_Randomized_prompts_template_score_eval"; 10 | // local mixture_name = "d4_dev"; 11 | // local task_name = "race_high_Read_the_article_and_answer_the_question_no_option_"; 12 | local num_prefix = 0; 13 | 14 | // Set to null if you don't want to load a checkpoint. 
15 | local checkpoint = null; 16 | 17 | local model = if checkpoint == "null" then { 18 | "type": "prefix_transformer", 19 | "transformer_model": model_name, 20 | } else { 21 | "type": "prefix_transformer_from_checkpoint", 22 | "transformer_model": model_name, 23 | "checkpoint_path": checkpoint, 24 | "strict": false, 25 | }; 26 | 27 | { 28 | "steps": { 29 | "output_model": { 30 | "type": "eval_step", 31 | "config": config, 32 | "trainer": { 33 | "type": "default", 34 | }, 35 | "datamodule": { 36 | "type": "t0", 37 | "mixture_name": mixture_name, 38 | "task_name": task_name, 39 | "data_dir": "data", 40 | "t0_data_cache": "/data/cl/user/zfw/better-promptability/t0_cache", 41 | "transformer_model": model_name, 42 | "num_prefix": num_prefix, 43 | "num_workers": 0, 44 | }, 45 | "model": model, 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /configs/0shot_eval_all_d4_dev.jsonnet: -------------------------------------------------------------------------------- 1 | local t0_mixtures = import 't0_mixtures.jsonnet'; 2 | local t0_task_info = import 't0_task_info.jsonnet'; 3 | 4 | // ------------------------------------ // 5 | // --- Mixture, datasets, and tasks --- // 6 | // ------------------------------------ // 7 | 8 | local mixture_name = "d4_dev"; 9 | 10 | local datasets = std.set([ 11 | t0_task_info["tasks"][task_name]["dataset_name"] for task_name in t0_mixtures[mixture_name] 12 | ]); 13 | local tasks = t0_mixtures[mixture_name]; 14 | 15 | // ----------------------- // 16 | // --- Hyperparameters --- // 17 | // ----------------------- // 18 | 19 | local config = { 20 | "type": "default", 21 | "seed": 100, 22 | "gpus": 1, 23 | "precision": 32, 24 | }; 25 | local model_name = "google/t5-small-lm-adapt"; 26 | 27 | // Set to "null" if you don't want to load a checkpoint.
28 | local checkpoint = std.extVar("CKPT"); 29 | 30 | local model = if checkpoint == "null" then { 31 | "type": "prefix_transformer", 32 | "transformer_model": model_name, 33 | } else { 34 | "type": "prefix_transformer_from_checkpoint", 35 | "transformer_model": model_name, 36 | "checkpoint_path": checkpoint, 37 | "strict": false, 38 | }; 39 | 40 | // Function that returns the eval step for a given task. 41 | local EvalStep(task_name) = { 42 | "type": "eval_step", 43 | "config": config, 44 | "trainer": { 45 | "type": "default", 46 | }, 47 | "datamodule": { 48 | "type": "t0", 49 | "mixture_name": mixture_name, 50 | "task_name": task_name, 51 | "data_dir": "data", 52 | "t0_data_cache": "/data/cl/user/zfw/better-promptability/t0_cache", 53 | "transformer_model": model_name, 54 | "num_prefix": 0, 55 | "num_workers": 0, 56 | }, 57 | "model": model, 58 | }; 59 | 60 | // Function that returns the name of the eval step for a task. 61 | local EvalStepName(task_name) = "result_" + task_name; 62 | 63 | { 64 | steps: { 65 | [EvalStepName(task_name)]: EvalStep(task_name) for task_name in tasks 66 | } + { 67 | "aggregated_results": { 68 | type: "aggregate_results", 69 | results: { 70 | [task_name]: {type: "ref", ref: EvalStepName(task_name)} 71 | for task_name in tasks 72 | } 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /configs/0shot_eval_all_green.jsonnet: -------------------------------------------------------------------------------- 1 | local t0_mixtures = import 't0_mixtures.jsonnet'; 2 | local t0_task_info = import 't0_task_info.jsonnet'; 3 | 4 | // ------------------------------------ // 5 | // --- Mixture, datasets, and tasks --- // 6 | // ------------------------------------ // 7 | 8 | local mixture_name = "green"; 9 | 10 | local datasets = std.set([ 11 | t0_task_info["tasks"][task_name]["dataset_name"] for task_name in t0_mixtures[mixture_name] 12 | ]); 13 | local tasks = t0_mixtures[mixture_name]; 14 | 15 
| // ----------------------- // 16 | // --- Hyperparameters --- // 17 | // ----------------------- // 18 | 19 | local config = { 20 | "type": "default", 21 | "seed": 100, 22 | "gpus": 1, 23 | "precision": 32, 24 | }; 25 | local model_name = "google/t5-small-lm-adapt"; 26 | 27 | // Set to "null" if you don't want to load a checkpoint. 28 | local checkpoint = std.extVar("CKPT"); 29 | 30 | local model = if checkpoint == "null" then { 31 | "type": "prefix_transformer", 32 | "transformer_model": model_name, 33 | } else { 34 | "type": "prefix_transformer_from_checkpoint", 35 | "transformer_model": model_name, 36 | "checkpoint_path": checkpoint, 37 | "strict": false, 38 | }; 39 | 40 | // Function that returns the eval step for a given task. 41 | local EvalStep(task_name) = { 42 | "type": "eval_step", 43 | "config": config, 44 | "trainer": { 45 | "type": "default", 46 | }, 47 | "datamodule": { 48 | "type": "t0", 49 | "mixture_name": mixture_name, 50 | "task_name": task_name, 51 | "data_dir": "data", 52 | "t0_data_cache": "/data/cl/user/zfw/better-promptability/t0_cache/", 53 | "transformer_model": model_name, 54 | "num_prefix": 20, 55 | "num_workers": 0, 56 | "deep": true, 57 | }, 58 | "model": model, 59 | }; 60 | 61 | // Function that returns the name of the eval step for a task.
62 | local EvalStepName(task_name) = "result_" + task_name; 63 | 64 | { 65 | steps: { 66 | [EvalStepName(task_name)]: EvalStep(task_name) for task_name in tasks 67 | } + { 68 | "aggregated_results": { 69 | type: "aggregate_results", 70 | results: { 71 | [task_name]: {type: "ref", ref: EvalStepName(task_name)} 72 | for task_name in tasks 73 | } 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /configs/check_install.yml: -------------------------------------------------------------------------------- 1 | steps: 2 | check_install: 3 | type: check_install 4 | -------------------------------------------------------------------------------- /configs/fewshot_eval.jsonnet: -------------------------------------------------------------------------------- 1 | local config = { 2 | "type": "default", 3 | "seed": 100, 4 | "gpus": 1, 5 | "precision": 32, 6 | }; 7 | local model_name = "google/t5-small-lm-adapt"; 8 | local mixture_name = "green"; 9 | local task_name = "hellaswag_Randomized_prompts_template_score_eval"; 10 | // local mixture_name = "d4_dev"; 11 | // local task_name = "race_high_Read_the_article_and_answer_the_question_no_option_"; 12 | local subsample_indices_file = "data/" + mixture_name + "_training_indices_16shot_100seed.pkl"; 13 | 14 | // Set to "null" if you don't want to load a checkpoint. 
15 | local checkpoint = std.extVar("CKPT"); 16 | 17 | local optimizer = { 18 | "type": "adafactor", 19 | "lr": 0.001, 20 | "scale_parameter": false, 21 | "relative_step": false, 22 | }; 23 | 24 | local model = if checkpoint == "null" then { 25 | "type": "prefix_transformer", 26 | "transformer_model": model_name, 27 | "optimizer": optimizer, 28 | } else { 29 | "type": "prefix_transformer_from_checkpoint", 30 | "transformer_model": model_name, 31 | "optimizer": optimizer, 32 | "checkpoint_path": checkpoint, 33 | "strict": false, 34 | }; 35 | 36 | { 37 | "steps": { 38 | "output_model": { 39 | "type": "train_step", 40 | "config": config, 41 | "trainer": { 42 | "type": "default", 43 | "max_epochs": 100, 44 | "gradient_clip_val": 1.0, 45 | "accumulate_grad_batches": 1.0, 46 | "log_every_n_steps": 50, 47 | "logger": [ 48 | {"type": "pytorch_lightning::TensorBoardLogger"}, 49 | ], 50 | "callbacks": [ 51 | "my_logger", 52 | ], 53 | "enable_checkpointing": false, 54 | "replace_sampler_ddp": false, 55 | }, 56 | "datamodule": { 57 | "type": "t0", 58 | "mixture_name": mixture_name, 59 | "task_name": task_name, 60 | "data_dir": "data", 61 | "t0_data_cache": "/data/cl/user/zfw/better-promptability/t0_cache", 62 | "transformer_model": model_name, 63 | "num_prefix": 20, 64 | "subsample_indices_file": subsample_indices_file, 65 | "num_workers": 4, 66 | }, 67 | "model": model, 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /configs/fewshot_eval_all_d4_dev.jsonnet: -------------------------------------------------------------------------------- 1 | // This config is for running evaluations on a mixture of tasks and aggregating the results 2 | // by dataset and subset. 3 | // The aggregated results for each dataset will be in a step called "aggregated_results_{dataset_name}", 4 | // and the aggregated results for each subset will be in a step called "aggregated_results_{dataset_name}_{subset_name}". 
5 | // 6 | // Testing: 7 | // -------- 8 | // 9 | // To do a test run, first modify this config like so: 10 | // 1. Set `model_name` to a small model like "google/t5-small-lm-adapt", 11 | // 2. Set `epochs` to a small number like 5, 12 | // 3. Override `datasets` to only list a couple of datasets in the mixture. 13 | // 4. Override `tasks` to only list one or two tasks for each dataset. 14 | // 15 | // Then run: 16 | // 17 | // $ tango run configs/fewshot_eval_all_d4_dev.jsonnet -d /tmp/test-run 18 | 19 | local t0_mixtures = import 't0_mixtures.jsonnet'; 20 | local t0_task_info = import 't0_task_info.jsonnet'; 21 | 22 | // ------------------------------------ // 23 | // --- Mixture, datasets, and tasks --- // 24 | // ------------------------------------ // 25 | 26 | local mixture_name = "d4_dev"; 27 | 28 | local datasets = std.set([ 29 | t0_task_info["tasks"][task_name]["dataset_name"] for task_name in t0_mixtures[mixture_name] 30 | ]); 31 | local tasks = t0_mixtures[mixture_name]; 32 | 33 | // For debugging: 34 | // local datasets = ["anli", "hellaswag"]; 35 | // local tasks = [ 36 | // "anli_GPT_3_style_r1_score_eval", 37 | // "anli_GPT_3_style_r2_score_eval", 38 | // "hellaswag_Predict_ending_with_hint_score_eval", 39 | // "hellaswag_Randomized_prompts_template_score_eval", 40 | // ]; 41 | 42 | // ----------------------- // 43 | // --- Hyperparameters --- // 44 | // ----------------------- // 45 | 46 | local config = { 47 | type: "default", 48 | seed: 100, 49 | gpus: 1, 50 | precision: 32, 51 | }; 52 | 53 | local epochs = 100; 54 | 55 | local model_name = "google/t5-small-lm-adapt"; 56 | 57 | local checkpoint = std.extVar("CKPT"); 58 | 59 | local optimizer = { 60 | type: "adafactor", 61 | lr: 0.001, 62 | scale_parameter: false, 63 | relative_step: false, 64 | }; 65 | 66 | // Set to "true" to enable validation after every training epoch, otherwise we only validate 67 | // after the final epoch.
68 | local validate_every_epoch = false; 69 | 70 | // ------------------------------------------------------------ // 71 | // --- Data cache - edit according to the machine you're on --- // 72 | // ------------------------------------------------------------ // 73 | 74 | local t0_data_cache = "/data/cl/user/zfw/better-promptability/t0_cache"; 75 | 76 | // ----------------------------------------------------------- // 77 | // --- ! You probably don't need to edit below this line ! --- // 78 | // ----------------------------------------------------------- // 79 | 80 | local model = { 81 | "type": if checkpoint == "null" then "prefix_transformer" else "prefix_transformer_from_checkpoint", 82 | [if checkpoint == "null" then null else "checkpoint_path"]: checkpoint, 83 | transformer_model: model_name, 84 | optimizer: optimizer, 85 | }; 86 | 87 | // Function that returns the train + eval step for a given task. 88 | local TrainStep(task_name) = { 89 | type: "train_step", 90 | config: config, 91 | trainer: { 92 | type: "default", 93 | max_epochs: epochs, 94 | gradient_clip_val: 1.0, 95 | accumulate_grad_batches: 1.0, 96 | log_every_n_steps: 50, 97 | logger: [ 98 | {type: "pytorch_lightning::TensorBoardLogger"}, 99 | ], 100 | callbacks: [ 101 | "my_logger", 102 | ], 103 | enable_checkpointing: false, 104 | replace_sampler_ddp: false, 105 | check_val_every_n_epoch: if validate_every_epoch then 1 else epochs, 106 | }, 107 | datamodule: { 108 | type: "t0", 109 | mixture_name: mixture_name, 110 | task_name: task_name, 111 | data_dir: "data", 112 | t0_data_cache: t0_data_cache, 113 | transformer_model: model_name, 114 | num_prefix: 20, 115 | subsample_indices_file: "data/" + mixture_name + "_training_indices_16shot_100seed.pkl", 116 | num_workers: 4, 117 | }, 118 | model: model, 119 | }; 120 | 121 | // Function that returns the name of the train+eval step for a task. 
122 | local TrainStepName(task_name) = "result_" + task_name; 123 | 124 | { 125 | steps: { 126 | [TrainStepName(task_name)]: TrainStep(task_name) for task_name in tasks 127 | } + { 128 | "aggregated_results": { 129 | type: "aggregate_results", 130 | results: { 131 | [task_name]: {type: "ref", ref: TrainStepName(task_name)} 132 | for task_name in tasks 133 | } 134 | } 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /configs/fewshot_eval_all_green.jsonnet: -------------------------------------------------------------------------------- 1 | // This config is for running evaluations on a mixture of tasks and aggregating the results 2 | // by dataset and subset. 3 | // The aggregated results for each dataset will be in a step called "aggregated_results_{dataset_name}", 4 | // and the aggregated results for each subset will be in a step called "aggregated_results_{dataset_name}_{subset_name}". 5 | // 6 | // Testing: 7 | // -------- 8 | // 9 | // To do a test run, first modify this config like so: 10 | // 1. Set `model_name` to a small model like "google/t5-small-lm-adapt", 11 | // 2. Set `epochs` to a small number like 5, 12 | // 3. Override `datasets` to only list a couple of datasets in the mixture. 13 | // 4. Override `tasks` to only list one or two tasks for each dataset. 
14 | // 15 | // Then run: 16 | // 17 | // $ tango run configs/fewshot_eval_all_green.jsonnet -d /tmp/test-run 18 | 19 | local t0_mixtures = import 't0_mixtures.jsonnet'; 20 | local t0_task_info = import 't0_task_info.jsonnet'; 21 | 22 | // ------------------------------------ // 23 | // --- Mixture, datasets, and tasks --- // 24 | // ------------------------------------ // 25 | 26 | local mixture_name = "green"; 27 | 28 | local datasets = std.set([ 29 | t0_task_info["tasks"][task_name]["dataset_name"] for task_name in t0_mixtures[mixture_name] 30 | ]); 31 | local tasks = t0_mixtures[mixture_name]; 32 | 33 | // For debugging: 34 | // local datasets = ["anli", "hellaswag"]; 35 | // local tasks = [ 36 | // "anli_GPT_3_style_r1_score_eval", 37 | // "anli_GPT_3_style_r2_score_eval", 38 | // "hellaswag_Predict_ending_with_hint_score_eval", 39 | // "hellaswag_Randomized_prompts_template_score_eval", 40 | // ]; 41 | 42 | // ----------------------- // 43 | // --- Hyperparameters --- // 44 | // ----------------------- // 45 | 46 | local config = { 47 | type: "default", 48 | seed: 100, 49 | gpus: 1, 50 | precision: 32, 51 | }; 52 | 53 | local epochs = 100; 54 | 55 | local model_name = "google/t5-small-lm-adapt"; 56 | 57 | local checkpoint = std.extVar("CKPT"); 58 | 59 | local optimizer = { 60 | type: "adafactor", 61 | lr: 0.001, 62 | scale_parameter: false, 63 | relative_step: false, 64 | }; 65 | 66 | // Set to "true" to enable validation after every training epoch, otherwise we only validate 67 | // after the final epoch. 68 | local validate_every_epoch = false; 69 | 70 | // ------------------------------------------------------------ // 71 | // --- Data cache - edit according to the machine you're on --- // 72 | // ------------------------------------------------------------ // 73 | 74 | local t0_data_cache = "/data/cl/user/zfw/better-promptability/t0_cache/"; 75 | 76 | // ----------------------------------------------------------- // 77 | // --- ! 
You probably don't need to edit below this line ! --- // 78 | // ----------------------------------------------------------- // 79 | 80 | local model = { 81 | "type": if checkpoint == "null" then "prefix_transformer" else "prefix_transformer_from_checkpoint", 82 | [if checkpoint == "null" then null else "checkpoint_path"]: checkpoint, 83 | transformer_model: model_name, 84 | optimizer: optimizer, 85 | }; 86 | 87 | // Function that returns the train + eval step for a given task. 88 | local TrainStep(task_name) = { 89 | type: "train_step", 90 | config: config, 91 | trainer: { 92 | type: "default", 93 | max_epochs: epochs, 94 | gradient_clip_val: 1.0, 95 | accumulate_grad_batches: 1.0, 96 | log_every_n_steps: 50, 97 | logger: [ 98 | {type: "pytorch_lightning::TensorBoardLogger"}, 99 | ], 100 | callbacks: [ 101 | "my_logger", 102 | ], 103 | enable_checkpointing: false, 104 | replace_sampler_ddp: false, 105 | check_val_every_n_epoch: if validate_every_epoch then 1 else epochs, 106 | }, 107 | datamodule: { 108 | type: "t0", 109 | mixture_name: mixture_name, 110 | task_name: task_name, 111 | data_dir: "data", 112 | t0_data_cache: t0_data_cache, 113 | transformer_model: model_name, 114 | num_prefix: 20, 115 | subsample_indices_file: "data/" + mixture_name + "_training_indices_16shot_100seed.pkl", 116 | num_workers: 4, 117 | deep: true, 118 | }, 119 | model: model, 120 | }; 121 | 122 | // Function that returns the name of the train+eval step for a task. 
123 | local TrainStepName(task_name) = "result_" + task_name; 124 | 125 | { 126 | steps: { 127 | [TrainStepName(task_name)]: TrainStep(task_name) for task_name in tasks 128 | } + { 129 | "aggregated_results": { 130 | type: "aggregate_results", 131 | results: { 132 | [task_name]: {type: "ref", ref: TrainStepName(task_name)} 133 | for task_name in tasks 134 | } 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /configs/fomaml.jsonnet: -------------------------------------------------------------------------------- 1 | local config = { 2 | "type": "default", 3 | "seed": 100, 4 | "gpus": 2, 5 | "precision": 32, 6 | }; 7 | local model = "google/t5-small-lm-adapt"; 8 | 9 | local meta_batch_size = 128; 10 | local ckpt_interval = 65536 / meta_batch_size; 11 | 12 | { 13 | "steps": { 14 | "output_model": { 15 | "type": "train_step", 16 | "config": config, 17 | "trainer": { 18 | "type": "default", 19 | "max_epochs": 100, 20 | "gradient_clip_val": 1.0, 21 | "accumulate_grad_batches": 1.0, 22 | "num_sanity_val_steps": 0, 23 | "log_every_n_steps": 6, 24 | "val_check_interval": ckpt_interval / config.gpus, 25 | "logger": [ 26 | {"type": "pytorch_lightning::TensorBoardLogger"}, 27 | ], 28 | "callbacks": [ 29 | # We need separate ModelCheckpoints for per-step and per-epoch checkpointing. 
30 | # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11645 31 | # and https://github.com/PyTorchLightning/pytorch-lightning/issues/11930 32 | { 33 | "type": "pytorch_lightning::ModelCheckpoint", 34 | "save_last": true, 35 | "save_top_k": -1, 36 | "filename": "{epoch}-{step}-{categorical_accuracy:.4f}", 37 | "save_on_train_epoch_end": false, 38 | }, 39 | { 40 | "type": "pytorch_lightning::ModelCheckpoint", 41 | "save_last": true, 42 | "save_top_k": -1, 43 | "filename": "{epoch}-{step}-endofepoch-{categorical_accuracy:.4f}", 44 | "save_on_train_epoch_end": true, 45 | }, 46 | "my_logger", 47 | "t0_multitask", 48 | ], 49 | "replace_sampler_ddp": false, 50 | }, 51 | "datamodule": { 52 | "type": "t0_meta_learning", 53 | "meta_batch_size": meta_batch_size, 54 | "mixture_name": "d4_train", 55 | "data_dir": "data", 56 | "t0_data_cache": "/data/cl/user/zfw/better-promptability/t0_cache/", 57 | "transformer_model": model, 58 | "batch_size": 32, # this is the effective batch size; ONLY change meta_accumulate_grad_batches when adjusting for GPU sizes 59 | "support_batch_size": 16, # ditto 60 | "eval_batch_size": 64, 61 | "num_prefix": 20, 62 | "num_workers": 4, 63 | "deep": true, 64 | }, 65 | "model": { 66 | "type": "meta_learner", 67 | "model": { 68 | "transformer_model": model, 69 | "optimizer": { 70 | "type": "transformers::adafactor", 71 | "lr": 0.001, 72 | "scale_parameter": false, 73 | "relative_step": false, 74 | }, 75 | }, 76 | "adaptation_steps": 7, # though in few-shot learning we have only one batch/epoch, but we train for many epochs 77 | "algorithm": "fomaml", 78 | "meta_optimizer": { 79 | "type": "transformers::adafactor", 80 | "lr": 0.001, 81 | "scale_parameter": false, 82 | "relative_step": false, 83 | }, 84 | "meta_accumulate_grad_batches": 16, 85 | } // "model" (meta_learner) 86 | } // "output_model" 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /configs/multi_task.jsonnet: 
-------------------------------------------------------------------------------- 1 | local config = { 2 | "type": "default", 3 | "seed": 100, 4 | "gpus": 2, 5 | "precision": 32, 6 | }; 7 | local model = "google/t5-small-lm-adapt"; 8 | local train_full_model = true; 9 | local effective_batch_size = 4096; 10 | local batch_size = 2; 11 | local ckpt_interval = 512; 12 | 13 | { 14 | "steps": { 15 | "output_model": { 16 | "type": "train_step", 17 | "config": config, 18 | "trainer": { 19 | "type": "default", 20 | "max_epochs": 1, 21 | "gradient_clip_val": 1.0, 22 | "accumulate_grad_batches": effective_batch_size / batch_size / config.gpus, 23 | "num_sanity_val_steps": 0, 24 | "log_every_n_steps": 50, 25 | "val_check_interval": ckpt_interval * effective_batch_size / batch_size / config.gpus, 26 | "logger": [ 27 | {"type": "pytorch_lightning::TensorBoardLogger"}, 28 | ], 29 | "callbacks": [ 30 | # We need separate ModelCheckpoints for per-step and per-epoch checkpointing. 31 | # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11645 32 | # and https://github.com/PyTorchLightning/pytorch-lightning/issues/11930 33 | { 34 | "type": "pytorch_lightning::ModelCheckpoint", 35 | "save_last": true, 36 | "save_top_k": -1, 37 | "filename": "{epoch}-{step}-{categorical_accuracy:.4f}", 38 | "save_on_train_epoch_end": false, 39 | }, 40 | { 41 | "type": "pytorch_lightning::ModelCheckpoint", 42 | "save_last": true, 43 | "save_top_k": -1, 44 | "filename": "{epoch}-{step}-endofepoch-{categorical_accuracy:.4f}", 45 | "save_on_train_epoch_end": true, 46 | }, 47 | "my_logger", 48 | "t0_multitask", 49 | ], 50 | "replace_sampler_ddp": false, 51 | }, 52 | "model": { 53 | "type": "prefix_transformer", 54 | "transformer_model": model, 55 | "optimizer": { 56 | "type": "transformers::adafactor", 57 | "lr": 0.001, 58 | "scale_parameter": false, 59 | "relative_step": false, 60 | }, 61 | "train_full_model": train_full_model, 62 | }, 63 | "datamodule": { 64 | "type": "t0_multitask", 65 | 
"mixture_name": "d4_train", 66 | "data_dir": "data", 67 | "t0_data_cache": "/data/cl/user/zfw/better-promptability/t0_cache/", 68 | "transformer_model": model, 69 | "batch_size": batch_size, 70 | "eval_batch_size": 64, 71 | "num_prefix": 20, 72 | "num_workers": 4, 73 | "deep": true, 74 | }, 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /configs/reptile.jsonnet: -------------------------------------------------------------------------------- 1 | local config = { 2 | "type": "default", 3 | "seed": 100, 4 | "gpus": 2, 5 | "precision": 32, 6 | }; 7 | local model = "google/t5-small-lm-adapt"; 8 | 9 | local meta_batch_size = 128; 10 | local adaptation_steps = 7; 11 | local ckpt_interval = 65536 / meta_batch_size; 12 | 13 | { 14 | "steps": { 15 | "output_model": { 16 | "type": "train_step", 17 | "config": config, 18 | "trainer": { 19 | "type": "default", 20 | "max_epochs": 100, 21 | "gradient_clip_val": 1.0, 22 | "accumulate_grad_batches": 1.0, 23 | "num_sanity_val_steps": 0, 24 | "log_every_n_steps": 6, 25 | "val_check_interval": ckpt_interval / config.gpus, 26 | "logger": [ 27 | {"type": "pytorch_lightning::TensorBoardLogger"}, 28 | ], 29 | "callbacks": [ 30 | # We need separate ModelCheckpoints for per-step and per-epoch checkpointing. 
31 | # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11645 32 | # and https://github.com/PyTorchLightning/pytorch-lightning/issues/11930 33 | { 34 | "type": "pytorch_lightning::ModelCheckpoint", 35 | "save_last": true, 36 | "save_top_k": -1, 37 | "filename": "{epoch}-{step}-{categorical_accuracy:.4f}", 38 | "save_on_train_epoch_end": false, 39 | }, 40 | { 41 | "type": "pytorch_lightning::ModelCheckpoint", 42 | "save_last": true, 43 | "save_top_k": -1, 44 | "filename": "{epoch}-{step}-endofepoch-{categorical_accuracy:.4f}", 45 | "save_on_train_epoch_end": true, 46 | }, 47 | "my_logger", 48 | "t0_multitask", 49 | ], 50 | "replace_sampler_ddp": false, 51 | }, 52 | "datamodule": { 53 | "type": "t0_meta_learning", 54 | "meta_batch_size": meta_batch_size, 55 | "mixture_name": "d4_train", 56 | "data_dir": "data", 57 | "t0_data_cache": "/data/cl/user/zfw/better-promptability/t0_cache/", 58 | "transformer_model": model, 59 | "batch_size": 16 * (adaptation_steps + 1), # this is the effective batch size; ONLY change meta_accumulate_grad_batches when adjusting for GPU sizes 60 | "support_batch_size": 16 * adaptation_steps, # ditto 61 | "eval_batch_size": 64, 62 | "num_prefix": 20, 63 | "num_workers": 4, 64 | "deep": true, 65 | }, 66 | "model": { 67 | "type": "meta_learner", 68 | "model": { 69 | "transformer_model": model, 70 | "optimizer": { 71 | "type": "transformers::adafactor", 72 | "lr": 0.001, 73 | "scale_parameter": false, 74 | "relative_step": false, 75 | }, 76 | }, 77 | "adaptation_steps": adaptation_steps, 78 | "algorithm": "reptile", 79 | "different_inner_loop_batches": true, 80 | "meta_optimizer": { 81 | "type": "transformers::adafactor", 82 | "lr": 0.001, 83 | "scale_parameter": false, 84 | "relative_step": false, 85 | }, 86 | "meta_accumulate_grad_batches": 16, 87 | } // "model" (meta_learner) 88 | } // "output_model" 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /data/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/data/.gitkeep -------------------------------------------------------------------------------- /data/d4_dev_training_indices_16shot_100seed.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/data/d4_dev_training_indices_16shot_100seed.pkl -------------------------------------------------------------------------------- /data/green_training_indices_16shot_100seed.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/data/green_training_indices_16shot_100seed.pkl -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = true 3 | no_site_packages = true 4 | allow_redefinition = true 5 | 6 | [mypy-tests.*] 7 | strict_optional = false 8 | -------------------------------------------------------------------------------- /output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/output/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | 4 | include = '\.pyi?$' 5 | 6 | exclude = ''' 7 | ( 8 | __pycache__ 9 | | \.git 10 | | \.mypy_cache 11 | | \.pytest_cache 12 | | \.vscode 13 | | \.venv 14 | | \bdist\b 15 | | \bdoc\b 16 | ) 17 | ''' 18 | 19 | 
[build-system] 20 | requires = ["setuptools", "wheel"] 21 | build-backend = "setuptools.build_meta" 22 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests/ 3 | python_classes = Test* *Test 4 | log_format = %(asctime)s - %(levelname)s - %(name)s - %(message)s 5 | log_level = DEBUG 6 | markers = 7 | filterwarnings = 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.1 2 | ai2-tango[all]==0.5.0 3 | allennlp==2.9.0 4 | pytorch-lightning==1.5.5 5 | torchmetrics==0.6.2 6 | transformers==4.12.5 7 | deepspeed==0.5.10 8 | -------------------------------------------------------------------------------- /scripts/bootstrap.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | 4 | from tqdm import tqdm 5 | 6 | 7 | random.seed(100) 8 | 9 | 10 | TASKS_METADATA = [ # task name, num templates, random performance 11 | ("ANLI", 45, 1/3), 12 | ("Hellaswag", 4, 1/4), 13 | ("StoryCloze", 5, 1/2), 14 | ("CB", 15, 1/3), 15 | ("COPA", 12, 1/2), 16 | ("RTE", 10, 1/2), 17 | ("WIC", 10, 1/2), 18 | ("WSC", 10, 1/2), 19 | ("Winogrande", 5, 1/2), 20 | ] 21 | NUM_INSTANCES = [1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,1000,1000,1200,10042,10042,10042,10042,1871,1871,1871,1871,1871,56,56,56,56,56,56,56,56,56,56,56,56,56,56,56,48,100,100,100,100,100,100,100,100,48,52,52,277,277,277,277,277,277,277,277,277,277,638,638,638,638,638,638,638,638,638,638,104,104,104,104,104,104,104,104,104,104,1267,1267,1267,1267,1267] 22 | RANDOM_SCORES = [e for metadata in 
TASKS_METADATA for e in [metadata[2]] * metadata[1]] 23 | 24 | XL_T0REPRO_DEEP = [0.382999986410141,0.3400000035762787,0.32499998807907104,0.3630000054836273,0.3160000145435333,0.398333340883255,0.39500001072883606,0.3569999933242798,0.40583333373069763,0.34299999475479126,0.3619999885559082,0.4008333384990692,0.3720000088214874,0.34299999475479126,0.3166666626930237,0.37599998712539673,0.3630000054836273,0.3958333432674408,0.34599998593330383,0.35499998927116394,0.40416666865348816,0.3709999918937683,0.3479999899864197,0.36916667222976685,0.3400000035762787,0.35899999737739563,0.3408333361148834,0.4009999930858612,0.34200000762939453,0.36916667222976685,0.3659999966621399,0.37599998712539673,0.41083332896232605,0.375,0.3440000116825104,0.398333340883255,0.3930000066757202,0.3310000002384186,0.3824999928474426,0.38499999046325684,0.36500000953674316,0.3916666805744171,0.35600000619888306,0.34700000286102295,0.4258333444595337,0.28719377517700195,0.2502489686012268,0.28779128193855286,0.2903803884983063,0.9182255268096924,0.911811888217926,0.9262426495552063,0.9160876274108887,0.9203634262084961,0.8214285969734192,0.8392857313156128,0.6785714030265808,0.7857142686843872,0.8214285969734192,0.8035714030265808,0.7678571343421936,0.7857142686843872,0.6785714030265808,0.8392857313156128,0.8035714030265808,0.8214285969734192,0.8571428656578064,0.8035714030265808,0.5714285969734192,0.875,0.7900000214576721,0.7200000286102295,0.7900000214576721,0.8500000238418579,0.8100000023841858,0.7799999713897705,0.8299999833106995,0.8299999833106995,0.9166666865348816,0.7692307829856873,0.7692307829856873,0.8122743964195251,0.7653429508209229,0.8050541281700134,0.833935022354126,0.8050541281700134,0.7870036363601685,0.7689530849456787,0.7833935022354126,0.7761732935905457,0.7797833681106567,0.5470219254493713,0.5250783562660217,0.5297805666923523,0.554858922958374,0.5423197746276855,0.5313479900360107,0.5595611333847046,0.5250783562660217,0.5329153537750244,0.5423197746276855,0.49038460
850715637,0.5288461446762085,0.4423076808452606,0.5192307829856873,0.6538461446762085,0.682692289352417,0.5480769276618958,0.5480769276618958,0.567307710647583,0.5769230723381042,0.5406472086906433,0.5548539757728577,0.5390686392784119,0.518547773361206,0.5351223349571228] 25 | XL_MTL_DEEP = [0.36000001430511475,0.36000001430511475,0.335833340883255,0.4129999876022339,0.33399999141693115,0.4099999964237213,0.41600000858306885,0.3499999940395355,0.3916666805744171,0.33000001311302185,0.36500000953674316,0.3916666805744171,0.3540000021457672,0.34299999475479126,0.3241666555404663,0.39899998903274536,0.3930000066757202,0.39500001072883606,0.40299999713897705,0.375,0.4099999964237213,0.37700000405311584,0.34200000762939453,0.3708333373069763,0.35899999737739563,0.36500000953674316,0.3774999976158142,0.41600000858306885,0.35600000619888306,0.36666667461395264,0.36500000953674316,0.3619999885559082,0.4050000011920929,0.3799999952316284,0.36399999260902405,0.3841666579246521,0.4020000100135803,0.32499998807907104,0.3933333456516266,0.40799999237060547,0.3799999952316284,0.37833333015441895,0.3709999918937683,0.36899998784065247,0.4074999988079071,0.30860385298728943,0.26986655592918396,0.29247161746025085,0.29784902930259705,0.920897901058197,0.9326563477516174,0.9321218729019165,0.921432375907898,0.9203634262084961,0.8035714030265808,0.8214285969734192,0.8214285969734192,0.8035714030265808,0.7678571343421936,0.7857142686843872,0.8571428656578064,0.8392857313156128,0.7857142686843872,0.8392857313156128,0.75,0.8035714030265808,0.8214285969734192,0.8571428656578064,0.75,0.7291666865348816,0.8100000023841858,0.7099999785423279,0.7699999809265137,0.800000011920929,0.800000011920929,0.7599999904632568,0.8299999833106995,0.8100000023841858,0.7708333134651184,0.75,0.7307692170143127,0.7797833681106567,0.7292418479919434,0.7039711475372314,0.7942238450050354,0.7328519821166992,0.6931408047676086,0.7039711475372314,0.750902533531189,0.6859205961227417,0.7220216393470764,0.576802492
1417236,0.5219435691833496,0.5454545617103577,0.5877742767333984,0.5877742767333984,0.5611284971237183,0.5203761458396912,0.5094043612480164,0.568965494632721,0.5626959204673767,0.4711538553237915,0.5769230723381042,0.5192307829856873,0.6346153616905212,0.6346153616905212,0.6538461446762085,0.4711538553237915,0.4711538553237915,0.5,0.5769230723381042,0.5217047929763794,0.5611681342124939,0.5730071067810059,0.5493291020393372,0.5603788495063782] 26 | 27 | assert all(len(RANDOM_SCORES) == len(NUM_INSTANCES) == sum(m[1] for m in TASKS_METADATA) == len(l) for k, l in globals().items() if any(k.startswith(p) for p in ("LARGE", "XL", "T0"))) 28 | 29 | 30 | def avg(l): 31 | return sum(l) / len(l) 32 | 33 | 34 | def macro_avg(l): 35 | per_task = [] 36 | for _, num_prompts, _ in TASKS_METADATA: 37 | per_task.append(avg(l[:num_prompts])) 38 | l = l[num_prompts:] 39 | assert len(l) == 0 40 | return avg(per_task) 41 | 42 | 43 | def arg(results): 44 | assert len(RANDOM_SCORES) == len(results) 45 | scores = [sum(r) / num for r, num in zip(results, NUM_INSTANCES)] 46 | rgs = [(score - baseline) / baseline for baseline, score in zip(RANDOM_SCORES, scores)] 47 | return macro_avg(rgs) 48 | 49 | 50 | def pairwise_test(worse_scores, better_scores): 51 | worse_n_correct = [round(score * num) for score, num in zip(worse_scores, NUM_INSTANCES)] 52 | better_n_correct = [round(score * num) for score, num in zip(better_scores, NUM_INSTANCES)] 53 | worse_results = [[1] * n_correct + [0] * (num - n_correct) for n_correct, num in zip(worse_n_correct, NUM_INSTANCES)] 54 | better_results = [[1] * n_correct + [0] * (num - n_correct) for n_correct, num in zip(better_n_correct, NUM_INSTANCES)] 55 | 56 | print(f"Original ARG: worse {arg(worse_results)}, better {arg(better_results)}") 57 | 58 | arg_diffs = [] 59 | for _ in tqdm(range(1000)): 60 | new_worse_results = [] 61 | new_better_results = [] 62 | for worse, better in zip(worse_results, better_results): 63 | new_worse, new_better = 
zip(*random.choices(list(zip(worse, better)), k=len(worse))) 64 | new_worse_results.append(new_worse) 65 | new_better_results.append(new_better) 66 | worse_arg = arg(new_worse_results) 67 | better_arg = arg(new_better_results) 68 | arg_diffs.append(better_arg - worse_arg) 69 | 70 | print(f"arg p: {avg([d < 0 for d in arg_diffs])}") 71 | 72 | 73 | def main(): 74 | pairwise_test(XL_T0REPRO_DEEP, XL_MTL_DEEP) 75 | 76 | 77 | if __name__ == "__main__": 78 | main(*sys.argv[1:]) # pylint: disable=no-value-for-parameter 79 | -------------------------------------------------------------------------------- /scripts/download_t0_training_set.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download all of the data from the [bigscience/P3](https://huggingface.co/datasets/bigscience/P3) 3 | corresponding to a particular mixture. This script should only be run from the root of this repository. 4 | """ 5 | 6 | import importlib 7 | import json 8 | import os 9 | import sys 10 | from pathlib import Path 11 | 12 | import datasets 13 | from tango.common import Params 14 | from tango.common.file_lock import FileLock 15 | from tqdm import tqdm 16 | 17 | STORY_CLOZE_PATH = Path("/data/cl/user/zfw/story_cloze_dir") 18 | 19 | 20 | def main(mixture_name: str, cache_dir: str): 21 | cache_dir = Path(cache_dir) 22 | 23 | def download_task_dataset(task_name: str): 24 | local_path = cache_dir / task_name # type: ignore 25 | if not os.path.isdir(local_path) or not os.listdir(local_path): 26 | if task_name.startswith("story_cloze_"): 27 | data_dir = STORY_CLOZE_PATH / task_name 28 | # Hack to add story cloze to the config in the P3 dataset builder -- import it first 29 | # and change relevant data structures 30 | dataset_module = datasets.load.dataset_module_factory( 31 | "bigscience/P3", 32 | revision=None, 33 | download_config=None, 34 | download_mode=None, 35 | data_files=None, 36 | ) 37 | p3_module = 
importlib.import_module(dataset_module.module_path) 38 | 39 | # Mostly following https://huggingface.co/datasets/bigscience/P3/blob/main/P3.py 40 | task_splits_and_features = p3_module._TASK_SPLITS_AND_FEATURES_DICT # type: ignore 41 | assert task_name not in task_splits_and_features 42 | for split_name in ("validation", "test"): # story cloze has no training set 43 | split_info = json.load(open(data_dir / f"info.{split_name}.json")) 44 | features_dict = split_info["features"] 45 | assert split_info["num_shards"] == 1 46 | 47 | if task_name not in task_splits_and_features: 48 | task_splits_and_features[task_name] = { 49 | "splits": [], 50 | "features_dict": features_dict, 51 | } 52 | task_splits_and_features[task_name]["splits"].append(split_name) 53 | assert features_dict == task_splits_and_features[task_name]["features_dict"] 54 | splits_and_features_dict = task_splits_and_features[task_name] 55 | 56 | assert task_name not in p3_module._URLs # type: ignore 57 | p3_module._URLs[task_name] = { # type: ignore 58 | split_name: {"tfrecord": data_dir / f"{split_name}.tfrecord-00000-of-00001"} 59 | for split_name in splits_and_features_dict["splits"] 60 | } 61 | 62 | p3_module.P3.BUILDER_CONFIGS.append( # type: ignore 63 | p3_module.P3Config( # type: ignore 64 | name=task_name, 65 | splits=splits_and_features_dict["splits"], 66 | features_dict=splits_and_features_dict["features_dict"], 67 | score_eval=task_name.endswith("score_eval"), 68 | ) 69 | ) 70 | p3_module.P3.builder_configs = { # type: ignore 71 | config.name: config for config in p3_module.P3.BUILDER_CONFIGS # type: ignore 72 | } 73 | 74 | retries = 0 75 | while True: 76 | try: 77 | dataset = datasets.load_dataset("bigscience/P3", task_name) 78 | break 79 | except ConnectionError: 80 | retries += 1 81 | if retries > 3: 82 | raise 83 | 84 | with FileLock(str(local_path) + ".lock"): 85 | dataset.save_to_disk(local_path) 86 | 87 | tasks = Params.from_file("configs/t0_mixtures.jsonnet")[mixture_name] 88 | 89 | for 
task in tqdm(tasks): 90 | download_task_dataset(task) 91 | 92 | 93 | if __name__ == "__main__": 94 | main(*sys.argv[1:]) 95 | -------------------------------------------------------------------------------- /scripts/process_green_datasets.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | from tango.common import Params 6 | from tqdm import tqdm 7 | 8 | from better_promptability.steps.process_dataset import ProcessDataset 9 | from better_promptability.steps.process_story_cloze import ProcessStoryCloze 10 | 11 | 12 | logging.basicConfig(level=logging.INFO) 13 | 14 | 15 | def process_green_datasets(old_base_path, new_base_path): 16 | datasets = Params.from_file("configs/t0_mixtures.jsonnet")["green"] 17 | for dataset in tqdm(datasets): 18 | dataset = dataset.strip() 19 | if "story_cloze" not in dataset: 20 | step = ProcessDataset() 21 | else: 22 | step = ProcessStoryCloze() 23 | try: 24 | step.run( 25 | old_data_path=os.path.join(old_base_path, dataset), 26 | new_data_path=os.path.join(new_base_path, dataset), 27 | ) 28 | except KeyError: 29 | print(f"error in {dataset}") 30 | 31 | 32 | if __name__ == "__main__": 33 | process_green_datasets(sys.argv[1], sys.argv[2]) 34 | -------------------------------------------------------------------------------- /scripts/subsample_t0_training_set.py: -------------------------------------------------------------------------------- 1 | """ 2 | Subsamples the training set for each dataset (i.e., for all templates). 3 | Ideally we want to sample the same examples across templates for a given dataset, but unfortunately 4 | this is impossible since the P3 dataset cache does not guarantee the same example order across 5 | templates. Check out, for example, hellaswag_complete_first_then_score_eval[29372] and 6 | hellaswag_Predict_ending_with_hint_score_eval[29372]. 
7 | """ 8 | 9 | from pathlib import Path 10 | import pickle 11 | import sys 12 | import random 13 | 14 | from tqdm import tqdm 15 | 16 | sys.path.append(str(Path(__file__).parent.parent.absolute())) 17 | 18 | from better_promptability.data.config import Config # noqa: E402 19 | from better_promptability.data.data_utils import md5 # noqa: E402 20 | from better_promptability.data.t0_mixture import T0Mixture # noqa: E402 21 | 22 | 23 | def main(mixture_name, n_shot, seed, output_file): 24 | n_shot = int(n_shot) 25 | seed = int(seed) 26 | random.seed(seed) 27 | 28 | # All arguments apart from the first two are dummy 29 | mixture = T0Mixture( 30 | mixture_name=mixture_name, 31 | t0_data_cache="/data/cl/user/zfw/better-promptability/t0_cache/", 32 | config=Config(), 33 | data_dir="tmp", 34 | num_prefix=20, 35 | transformer_model="t5-base", 36 | ) 37 | taskname_to_indices = {} 38 | for data_module in tqdm(mixture.data_modules.values()): 39 | task_name = data_module.task_name 40 | dataset_dict = data_module.load() 41 | train_split = dataset_dict[data_module.train_split] 42 | total_len = len(train_split) 43 | print(f"Sampling {n_shot} examples from {total_len} for {task_name} with seed {seed}") 44 | indices = random.sample(range(total_len), n_shot) 45 | checksum = md5( 46 | "".join(str(train_split[i]["inputs"] + train_split[i]["targets"]) for i in indices) 47 | ) 48 | taskname_to_indices[task_name] = (indices, checksum) 49 | 50 | pickle.dump(taskname_to_indices, open(output_file, "wb")) 51 | 52 | 53 | if __name__ == "__main__": 54 | main(*sys.argv[1:]) # pylint: disable=no-value-for-parameter 55 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | def read_requirements(filename: str): 5 | with open(filename) as requirements_file: 6 | import re 7 | 8 | def fix_url_dependencies(req: str) -> str: 9 
| """Pip and setuptools disagree about how URL dependencies should be handled.""" 10 | m = re.match( 11 | r"^(git\+)?(https|ssh)://(git@)?github\.com/([\w-]+)/(?P<name>[\w-]+)\.git", req 12 | ) 13 | if m is None: 14 | return req 15 | else: 16 | return f"{m.group('name')} @ {req}" 17 | 18 | requirements = [] 19 | for line in requirements_file: 20 | line = line.strip() 21 | if line.startswith("#") or len(line) <= 0: 22 | continue 23 | requirements.append(fix_url_dependencies(line)) 24 | return requirements 25 | 26 | 27 | # version.py defines the VERSION and VERSION_SHORT variables. 28 | # We use exec here so we don't import cached_path whilst setting up. 29 | VERSION = {} # type: ignore 30 | with open("better_promptability/version.py", "r") as version_file: 31 | exec(version_file.read(), VERSION) 32 | 33 | setup( 34 | name="better_promptability", 35 | version=VERSION["VERSION"], 36 | description="", 37 | long_description=open("README.md").read(), 38 | long_description_content_type="text/markdown", 39 | classifiers=[ 40 | "Intended Audience :: Science/Research", 41 | "Development Status :: 3 - Alpha", 42 | "License :: OSI Approved :: Apache Software License", 43 | "Programming Language :: Python :: 3", 44 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 45 | ], 46 | keywords="", 47 | url="https://github.com/allenai/better-promptability", 48 | author="Allen Institute for Artificial Intelligence", 49 | author_email="contact@allenai.org", 50 | license="Apache", 51 | packages=find_packages( 52 | exclude=["*.tests", "*.tests.*", "tests.*", "tests"], 53 | ), 54 | install_requires=read_requirements("requirements.txt"), 55 | python_requires=">=3.7, <3.8", # restriction by promptsource 56 | ) 57 | -------------------------------------------------------------------------------- /tango.yml: -------------------------------------------------------------------------------- 1 | include_package: 2 | - better_promptability 3 | log_level: info 4 | 
-------------------------------------------------------------------------------- /test_fixtures/configs/check.jsonnet: -------------------------------------------------------------------------------- 1 | local config = { 2 | "type": "default", 3 | "seed": 100, 4 | "gpus": 1, 5 | "fp16": false, 6 | }; 7 | local model = "google/t5-small-lm-adapt"; 8 | local dataset_name = "story_cloze"; 9 | local subset_name = "2016"; 10 | local template_name = "Answer_Given_options_score_eval"; 11 | 12 | { 13 | "steps": { 14 | "output_model": { 15 | "type": "train_step", 16 | "config": config, 17 | "trainer": { 18 | "type": "default", 19 | "max_epochs": 1, 20 | "gradient_clip_val": 1.0, 21 | "accumulate_grad_batches": 1.0, 22 | "log_every_n_steps": 3, 23 | "logger": [ 24 | {"type": "pytorch_lightning::TensorBoardLogger"}, 25 | ], 26 | "callbacks": [ 27 | "pytorch_lightning::ModelCheckpoint", 28 | "my_logger", 29 | ], 30 | "replace_sampler_ddp": false, 31 | }, 32 | "datamodule": { 33 | "type": "t0", 34 | "dataset_name": dataset_name, 35 | "subset_name": subset_name, 36 | "template_name": template_name, 37 | "subsample_indices_file": "data/green_training_indices_16shot_100seed.pkl", 38 | "data_dir": "data/" + dataset_name + "_" + subset_name + "_" + template_name, 39 | "transformer_model": model, 40 | "num_prefix": 20, 41 | }, 42 | "model": { 43 | "transformer_model": model, 44 | "optimizer": { 45 | "type": "adafactor", 46 | "lr": 0.001, 47 | "scale_parameter": false, 48 | "relative_step": false, 49 | }, 50 | "weight_decay": 1e-5, 51 | } 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /test_fixtures/configs/d4_dev.jsonnet: -------------------------------------------------------------------------------- 1 | local config = { 2 | "type": "default", 3 | "seed": 100, 4 | "gpus": null, 5 | "fp16": false, 6 | }; 7 | local model = "google/t5-small-lm-adapt"; 8 | local task_name = "openbookqa_main_choices"; 9 | 10 | { 11 | "steps": { 
12 | "output_model": { 13 | "type": "train_step", 14 | "config": config, 15 | "trainer": { 16 | "type": "default", 17 | "max_epochs": 1, 18 | "gradient_clip_val": 1.0, 19 | "accumulate_grad_batches": 1.0, 20 | "log_every_n_steps": 3, 21 | "logger": [ 22 | {"type": "pytorch_lightning::TensorBoardLogger"}, 23 | ], 24 | "callbacks": [ 25 | "pytorch_lightning::ModelCheckpoint", 26 | "my_logger", 27 | ], 28 | "replace_sampler_ddp": false, 29 | }, 30 | "datamodule": { 31 | "type": "t0", 32 | "mixture_name": "d4_dev", 33 | "task_name": task_name, 34 | "data_dir": "test_fixtures/data", 35 | "t0_data_cache": "test_fixtures/data/cache", 36 | "transformer_model": model, 37 | "num_prefix": 1, 38 | }, 39 | "model": { 40 | "transformer_model": model, 41 | "optimizer": { 42 | "type": "adafactor", 43 | "lr": 0.001, 44 | "scale_parameter": false, 45 | "relative_step": false, 46 | }, 47 | "weight_decay": 1e-5, 48 | } 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /test_fixtures/configs/d4_train.jsonnet: -------------------------------------------------------------------------------- 1 | local config = { 2 | "type": "default", 3 | "seed": 100, 4 | "gpus": null, 5 | "fp16": false, 6 | }; 7 | local model = "google/t5-small-lm-adapt"; 8 | local task_name = "adversarial_qa_dbert_based_on"; 9 | 10 | { 11 | "steps": { 12 | "output_model": { 13 | "type": "train_step", 14 | "config": config, 15 | "trainer": { 16 | "type": "default", 17 | "max_epochs": 1, 18 | "gradient_clip_val": 1.0, 19 | "accumulate_grad_batches": 1.0, 20 | "log_every_n_steps": 3, 21 | "logger": [ 22 | {"type": "pytorch_lightning::TensorBoardLogger"}, 23 | ], 24 | "callbacks": [ 25 | "pytorch_lightning::ModelCheckpoint", 26 | "my_logger", 27 | ], 28 | "replace_sampler_ddp": false, 29 | }, 30 | "datamodule": { 31 | "type": "t0", 32 | "mixture_name": "d4_train", 33 | "task_name": task_name, 34 | "data_dir": "test_fixtures/data", 35 | "t0_data_cache": 
"test_fixtures/data/cache", 36 | "transformer_model": model, 37 | "num_prefix": 1, 38 | }, 39 | "model": { 40 | "transformer_model": model, 41 | "optimizer": { 42 | "type": "adafactor", 43 | "lr": 0.001, 44 | "scale_parameter": false, 45 | "relative_step": false, 46 | }, 47 | "weight_decay": 1e-5, 48 | } 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /test_fixtures/configs/green.jsonnet: -------------------------------------------------------------------------------- 1 | local config = { 2 | "type": "default", 3 | "seed": 100, 4 | "gpus": null, 5 | "fp16": false, 6 | }; 7 | local model = "google/t5-small-lm-adapt"; 8 | local task_name = "hellaswag_complete_first_then_score_eval"; 9 | 10 | { 11 | "steps": { 12 | "output_model": { 13 | "type": "train_step", 14 | "config": config, 15 | "trainer": { 16 | "type": "default", 17 | "max_epochs": 1, 18 | "gradient_clip_val": 1.0, 19 | "accumulate_grad_batches": 1.0, 20 | "log_every_n_steps": 3, 21 | "logger": [ 22 | {"type": "pytorch_lightning::TensorBoardLogger"}, 23 | ], 24 | "callbacks": [ 25 | "pytorch_lightning::ModelCheckpoint", 26 | "my_logger", 27 | ], 28 | "replace_sampler_ddp": false, 29 | }, 30 | "datamodule": { 31 | "type": "t0", 32 | "mixture_name": "green", 33 | "task_name": task_name, 34 | "data_dir": "test_fixtures/data", 35 | "t0_data_cache": "test_fixtures/data/processed_cache", 36 | "transformer_model": model, 37 | "num_prefix": 1, 38 | }, 39 | "model": { 40 | "transformer_model": model, 41 | "optimizer": { 42 | "type": "adafactor", 43 | "lr": 0.001, 44 | "scale_parameter": false, 45 | "relative_step": false, 46 | }, 47 | "weight_decay": 1e-5, 48 | } 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /test_fixtures/data/cache/adversarial_qa_dbert_based_on/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["train", 
"validation"]} -------------------------------------------------------------------------------- /test_fixtures/data/cache/adversarial_qa_dbert_based_on/train/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/adversarial_qa_dbert_based_on/train/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/adversarial_qa_dbert_based_on/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "inputs": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs_pretokenized": { 21 | "dtype": "string", 22 | "id": null, 23 | "_type": "Value" 24 | }, 25 | "targets": { 26 | "feature": { 27 | "dtype": "int64", 28 | "id": null, 29 | "_type": "Value" 30 | }, 31 | "length": -1, 32 | "id": null, 33 | "_type": "Sequence" 34 | }, 35 | "targets_pretokenized": { 36 | "dtype": "string", 37 | "id": null, 38 | "_type": "Value" 39 | } 40 | }, 41 | "homepage": "", 42 | "license": "", 43 | "post_processed": null, 44 | "post_processing_size": null, 45 | "size_in_bytes": null, 46 | "splits": null, 47 | "supervised_keys": null, 48 | "task_templates": null, 49 | "version": null 50 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/adversarial_qa_dbert_based_on/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | 
"_fingerprint": "3345d4059c4df582", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/adversarial_qa_dbert_based_on/validation/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/adversarial_qa_dbert_based_on/validation/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/adversarial_qa_dbert_based_on/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "inputs": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs_pretokenized": { 21 | "dtype": "string", 22 | "id": null, 23 | "_type": "Value" 24 | }, 25 | "targets": { 26 | "feature": { 27 | "dtype": "int64", 28 | "id": null, 29 | "_type": "Value" 30 | }, 31 | "length": -1, 32 | "id": null, 33 | "_type": "Sequence" 34 | }, 35 | "targets_pretokenized": { 36 | "dtype": "string", 37 | "id": null, 38 | "_type": "Value" 39 | } 40 | }, 41 | "homepage": "", 42 | "license": "", 43 | "post_processed": null, 44 | "post_processing_size": null, 45 | "size_in_bytes": null, 46 | "splits": null, 47 | "supervised_keys": null, 48 | "task_templates": null, 49 | "version": null 50 | } -------------------------------------------------------------------------------- 
/test_fixtures/data/cache/adversarial_qa_dbert_based_on/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "3345d4059c4df582", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["train", "validation", "test"]} -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/test/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/test/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/test/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "idx": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs": { 21 | "feature": { 22 | "dtype": "int64", 23 | "id": null, 24 | "_type": "Value" 25 | }, 26 | "length": -1, 27 | "id": null, 28 | "_type": "Sequence" 29 | }, 30 | 
"inputs_pretokenized": { 31 | "dtype": "string", 32 | "id": null, 33 | "_type": "Value" 34 | }, 35 | "is_correct": { 36 | "dtype": "bool", 37 | "id": null, 38 | "_type": "Value" 39 | }, 40 | "targets": { 41 | "feature": { 42 | "dtype": "int64", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "targets_pretokenized": { 51 | "dtype": "string", 52 | "id": null, 53 | "_type": "Value" 54 | }, 55 | "weight": { 56 | "dtype": "float64", 57 | "id": null, 58 | "_type": "Value" 59 | } 60 | }, 61 | "homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | "post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/test/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "b64b3db845944006", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/train/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/train/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/train/dataset_info.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "idx": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs": { 21 | "feature": { 22 | "dtype": "int64", 23 | "id": null, 24 | "_type": "Value" 25 | }, 26 | "length": -1, 27 | "id": null, 28 | "_type": "Sequence" 29 | }, 30 | "inputs_pretokenized": { 31 | "dtype": "string", 32 | "id": null, 33 | "_type": "Value" 34 | }, 35 | "is_correct": { 36 | "dtype": "bool", 37 | "id": null, 38 | "_type": "Value" 39 | }, 40 | "targets": { 41 | "feature": { 42 | "dtype": "int64", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "targets_pretokenized": { 51 | "dtype": "string", 52 | "id": null, 53 | "_type": "Value" 54 | }, 55 | "weight": { 56 | "dtype": "float64", 57 | "id": null, 58 | "_type": "Value" 59 | } 60 | }, 61 | "homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | "post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "b64b3db845944006", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } 
-------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/validation/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/validation/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "idx": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs": { 21 | "feature": { 22 | "dtype": "int64", 23 | "id": null, 24 | "_type": "Value" 25 | }, 26 | "length": -1, 27 | "id": null, 28 | "_type": "Sequence" 29 | }, 30 | "inputs_pretokenized": { 31 | "dtype": "string", 32 | "id": null, 33 | "_type": "Value" 34 | }, 35 | "is_correct": { 36 | "dtype": "bool", 37 | "id": null, 38 | "_type": "Value" 39 | }, 40 | "targets": { 41 | "feature": { 42 | "dtype": "int64", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "targets_pretokenized": { 51 | "dtype": "string", 52 | "id": null, 53 | "_type": "Value" 54 | }, 55 | "weight": { 56 | "dtype": "float64", 57 | "id": null, 58 | "_type": "Value" 59 | } 60 | }, 61 | "homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | "post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": 
null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_complete_first_then_score_eval/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "b64b3db845944006", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["train", "validation", "test"]} -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/test/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/test/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/test/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "idx": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs": { 21 | 
"feature": { 22 | "dtype": "int64", 23 | "id": null, 24 | "_type": "Value" 25 | }, 26 | "length": -1, 27 | "id": null, 28 | "_type": "Sequence" 29 | }, 30 | "inputs_pretokenized": { 31 | "dtype": "string", 32 | "id": null, 33 | "_type": "Value" 34 | }, 35 | "is_correct": { 36 | "dtype": "bool", 37 | "id": null, 38 | "_type": "Value" 39 | }, 40 | "targets": { 41 | "feature": { 42 | "dtype": "int64", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "targets_pretokenized": { 51 | "dtype": "string", 52 | "id": null, 53 | "_type": "Value" 54 | }, 55 | "weight": { 56 | "dtype": "float64", 57 | "id": null, 58 | "_type": "Value" 59 | } 60 | }, 61 | "homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | "post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/test/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "3464b53cdc7bb99d", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/train/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/train/dataset.arrow 
-------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "idx": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs": { 21 | "feature": { 22 | "dtype": "int64", 23 | "id": null, 24 | "_type": "Value" 25 | }, 26 | "length": -1, 27 | "id": null, 28 | "_type": "Sequence" 29 | }, 30 | "inputs_pretokenized": { 31 | "dtype": "string", 32 | "id": null, 33 | "_type": "Value" 34 | }, 35 | "is_correct": { 36 | "dtype": "bool", 37 | "id": null, 38 | "_type": "Value" 39 | }, 40 | "targets": { 41 | "feature": { 42 | "dtype": "int64", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "targets_pretokenized": { 51 | "dtype": "string", 52 | "id": null, 53 | "_type": "Value" 54 | }, 55 | "weight": { 56 | "dtype": "float64", 57 | "id": null, 58 | "_type": "Value" 59 | } 60 | }, 61 | "homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | "post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "3464b53cdc7bb99d", 8 | 
"_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/validation/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/validation/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "idx": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs": { 21 | "feature": { 22 | "dtype": "int64", 23 | "id": null, 24 | "_type": "Value" 25 | }, 26 | "length": -1, 27 | "id": null, 28 | "_type": "Sequence" 29 | }, 30 | "inputs_pretokenized": { 31 | "dtype": "string", 32 | "id": null, 33 | "_type": "Value" 34 | }, 35 | "is_correct": { 36 | "dtype": "bool", 37 | "id": null, 38 | "_type": "Value" 39 | }, 40 | "targets": { 41 | "feature": { 42 | "dtype": "int64", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "targets_pretokenized": { 51 | "dtype": "string", 52 | "id": null, 53 | "_type": "Value" 54 | }, 55 | "weight": { 56 | "dtype": "float64", 57 | "id": null, 58 | "_type": "Value" 59 | } 60 | }, 61 | 
"homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | "post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/hellaswag_if_begins_how_continues_score_eval/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "3464b53cdc7bb99d", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/openbookqa_main_choices/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["train", "validation"]} -------------------------------------------------------------------------------- /test_fixtures/data/cache/openbookqa_main_choices/train/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/openbookqa_main_choices/train/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/openbookqa_main_choices/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "answer_choices": { 11 | "feature": { 12 | "dtype": "string", 13 | "id": null, 14 | "_type": "Value" 
15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs": { 21 | "feature": { 22 | "dtype": "int64", 23 | "id": null, 24 | "_type": "Value" 25 | }, 26 | "length": -1, 27 | "id": null, 28 | "_type": "Sequence" 29 | }, 30 | "inputs_pretokenized": { 31 | "dtype": "string", 32 | "id": null, 33 | "_type": "Value" 34 | }, 35 | "targets": { 36 | "feature": { 37 | "dtype": "int64", 38 | "id": null, 39 | "_type": "Value" 40 | }, 41 | "length": -1, 42 | "id": null, 43 | "_type": "Sequence" 44 | }, 45 | "targets_pretokenized": { 46 | "dtype": "string", 47 | "id": null, 48 | "_type": "Value" 49 | } 50 | }, 51 | "homepage": "", 52 | "license": "", 53 | "post_processed": null, 54 | "post_processing_size": null, 55 | "size_in_bytes": null, 56 | "splits": null, 57 | "supervised_keys": null, 58 | "task_templates": null, 59 | "version": null 60 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/openbookqa_main_choices/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "d11f1f3b47d2fc54", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/openbookqa_main_choices/validation/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/openbookqa_main_choices/validation/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/openbookqa_main_choices/validation/dataset_info.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "answer_choices": { 11 | "feature": { 12 | "dtype": "string", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs": { 21 | "feature": { 22 | "dtype": "int64", 23 | "id": null, 24 | "_type": "Value" 25 | }, 26 | "length": -1, 27 | "id": null, 28 | "_type": "Sequence" 29 | }, 30 | "inputs_pretokenized": { 31 | "dtype": "string", 32 | "id": null, 33 | "_type": "Value" 34 | }, 35 | "targets": { 36 | "feature": { 37 | "dtype": "int64", 38 | "id": null, 39 | "_type": "Value" 40 | }, 41 | "length": -1, 42 | "id": null, 43 | "_type": "Sequence" 44 | }, 45 | "targets_pretokenized": { 46 | "dtype": "string", 47 | "id": null, 48 | "_type": "Value" 49 | } 50 | }, 51 | "homepage": "", 52 | "license": "", 53 | "post_processed": null, 54 | "post_processing_size": null, 55 | "size_in_bytes": null, 56 | "splits": null, 57 | "supervised_keys": null, 58 | "task_templates": null, 59 | "version": null 60 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/openbookqa_main_choices/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "d11f1f3b47d2fc54", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/dataset_dict.json: 
-------------------------------------------------------------------------------- 1 | {"splits": ["test", "validation"]} -------------------------------------------------------------------------------- /test_fixtures/data/cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/test/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/test/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/test/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "answer_choices": { 11 | "feature": { 12 | "dtype": "string", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs": { 21 | "feature": { 22 | "dtype": "int64", 23 | "id": null, 24 | "_type": "Value" 25 | }, 26 | "length": -1, 27 | "id": null, 28 | "_type": "Sequence" 29 | }, 30 | "inputs_pretokenized": { 31 | "dtype": "string", 32 | "id": null, 33 | "_type": "Value" 34 | }, 35 | "targets": { 36 | "feature": { 37 | "dtype": "int64", 38 | "id": null, 39 | "_type": "Value" 40 | }, 41 | "length": -1, 42 | "id": null, 43 | "_type": "Sequence" 44 | }, 45 | "targets_pretokenized": { 46 | "dtype": "string", 47 | "id": null, 48 | "_type": "Value" 49 | } 50 | }, 51 | "homepage": "", 52 | "license": "", 53 | "post_processed": null, 54 | "post_processing_size": null, 55 | "size_in_bytes": null, 56 | "splits": null, 57 | "supervised_keys": null, 58 | 
"task_templates": null, 59 | "version": null 60 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/test/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "a8af6ce55b94bab6", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/validation/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/validation/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "answer_choices": { 11 | "feature": { 12 | "dtype": "string", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs": { 21 | "feature": { 22 | "dtype": "int64", 23 | "id": null, 24 | "_type": "Value" 25 | }, 26 | "length": -1, 27 | "id": null, 28 | "_type": "Sequence" 29 | }, 30 | "inputs_pretokenized": { 31 | "dtype": "string", 32 | "id": null, 33 | 
"_type": "Value" 34 | }, 35 | "targets": { 36 | "feature": { 37 | "dtype": "int64", 38 | "id": null, 39 | "_type": "Value" 40 | }, 41 | "length": -1, 42 | "id": null, 43 | "_type": "Sequence" 44 | }, 45 | "targets_pretokenized": { 46 | "dtype": "string", 47 | "id": null, 48 | "_type": "Value" 49 | } 50 | }, 51 | "homepage": "", 52 | "license": "", 53 | "post_processed": null, 54 | "post_processing_size": null, 55 | "size_in_bytes": null, 56 | "splits": null, 57 | "supervised_keys": null, 58 | "task_templates": null, 59 | "version": null 60 | } -------------------------------------------------------------------------------- /test_fixtures/data/cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "a8af6ce55b94bab6", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["train", "validation", "test"]} -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/test/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/test/dataset.arrow -------------------------------------------------------------------------------- 
/test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/test/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "inputs": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs_pretokenized": { 21 | "dtype": "string", 22 | "id": null, 23 | "_type": "Value" 24 | }, 25 | "targets": { 26 | "feature": { 27 | "feature": { 28 | "dtype": "int64", 29 | "id": null, 30 | "_type": "Value" 31 | }, 32 | "length": -1, 33 | "id": null, 34 | "_type": "Sequence" 35 | }, 36 | "length": -1, 37 | "id": null, 38 | "_type": "Sequence" 39 | }, 40 | "targets_pretokenized": { 41 | "feature": { 42 | "dtype": "string", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "is_correct": { 51 | "feature": { 52 | "dtype": "bool", 53 | "id": null, 54 | "_type": "Value" 55 | }, 56 | "length": -1, 57 | "id": null, 58 | "_type": "Sequence" 59 | } 60 | }, 61 | "homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | "post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/test/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "17d0a1c46ce6e475", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | 
"_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/train/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/train/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "inputs": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs_pretokenized": { 21 | "dtype": "string", 22 | "id": null, 23 | "_type": "Value" 24 | }, 25 | "targets": { 26 | "feature": { 27 | "feature": { 28 | "dtype": "int64", 29 | "id": null, 30 | "_type": "Value" 31 | }, 32 | "length": -1, 33 | "id": null, 34 | "_type": "Sequence" 35 | }, 36 | "length": -1, 37 | "id": null, 38 | "_type": "Sequence" 39 | }, 40 | "targets_pretokenized": { 41 | "feature": { 42 | "dtype": "string", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "is_correct": { 51 | "feature": { 52 | "dtype": "bool", 53 | "id": null, 54 | "_type": "Value" 55 | }, 56 | "length": -1, 57 | "id": null, 58 | "_type": "Sequence" 59 | } 60 | }, 61 | "homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | 
"post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "17d0a1c46ce6e475", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/validation/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/validation/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "inputs": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs_pretokenized": { 21 | "dtype": "string", 22 | "id": null, 23 | "_type": "Value" 24 | }, 25 | "targets": { 26 | "feature": { 27 | "feature": { 28 | "dtype": "int64", 
29 | "id": null, 30 | "_type": "Value" 31 | }, 32 | "length": -1, 33 | "id": null, 34 | "_type": "Sequence" 35 | }, 36 | "length": -1, 37 | "id": null, 38 | "_type": "Sequence" 39 | }, 40 | "targets_pretokenized": { 41 | "feature": { 42 | "dtype": "string", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "is_correct": { 51 | "feature": { 52 | "dtype": "bool", 53 | "id": null, 54 | "_type": "Value" 55 | }, 56 | "length": -1, 57 | "id": null, 58 | "_type": "Sequence" 59 | } 60 | }, 61 | "homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | "post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/hellaswag_complete_first_then_score_eval/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "17d0a1c46ce6e475", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["validation", "train"]} -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/train/dataset.arrow: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/processed_cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/train/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "inputs": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs_pretokenized": { 21 | "dtype": "string", 22 | "id": null, 23 | "_type": "Value" 24 | }, 25 | "targets": { 26 | "feature": { 27 | "feature": { 28 | "dtype": "int64", 29 | "id": null, 30 | "_type": "Value" 31 | }, 32 | "length": -1, 33 | "id": null, 34 | "_type": "Sequence" 35 | }, 36 | "length": -1, 37 | "id": null, 38 | "_type": "Sequence" 39 | }, 40 | "targets_pretokenized": { 41 | "feature": { 42 | "dtype": "string", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "is_correct": { 51 | "feature": { 52 | "dtype": "bool", 53 | "id": null, 54 | "_type": "Value" 55 | }, 56 | "length": -1, 57 | "id": null, 58 | "_type": "Sequence" 59 | } 60 | }, 61 | "homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | "post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- 
/test_fixtures/data/processed_cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "fc0e9aab7979b992", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/validation/dataset.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/test_fixtures/data/processed_cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/validation/dataset.arrow -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/validation/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": null, 3 | "citation": "", 4 | "config_name": null, 5 | "dataset_size": null, 6 | "description": "", 7 | "download_checksums": null, 8 | "download_size": null, 9 | "features": { 10 | "inputs": { 11 | "feature": { 12 | "dtype": "int64", 13 | "id": null, 14 | "_type": "Value" 15 | }, 16 | "length": -1, 17 | "id": null, 18 | "_type": "Sequence" 19 | }, 20 | "inputs_pretokenized": { 21 | "dtype": "string", 22 | "id": null, 23 | "_type": "Value" 24 | }, 25 | "targets": { 26 | "feature": { 27 | "feature": { 28 | "dtype": "int64", 29 | "id": null, 30 | "_type": "Value" 31 | }, 32 | "length": -1, 33 | "id": null, 34 | "_type": "Sequence" 35 | }, 36 | "length": -1, 37 | "id": null, 38 | "_type": 
"Sequence" 39 | }, 40 | "targets_pretokenized": { 41 | "feature": { 42 | "dtype": "string", 43 | "id": null, 44 | "_type": "Value" 45 | }, 46 | "length": -1, 47 | "id": null, 48 | "_type": "Sequence" 49 | }, 50 | "is_correct": { 51 | "feature": { 52 | "dtype": "bool", 53 | "id": null, 54 | "_type": "Value" 55 | }, 56 | "length": -1, 57 | "id": null, 58 | "_type": "Sequence" 59 | } 60 | }, 61 | "homepage": "", 62 | "license": "", 63 | "post_processed": null, 64 | "post_processing_size": null, 65 | "size_in_bytes": null, 66 | "splits": null, 67 | "supervised_keys": null, 68 | "task_templates": null, 69 | "version": null 70 | } -------------------------------------------------------------------------------- /test_fixtures/data/processed_cache/story_cloze_2016_Story_Continuation_and_Options_score_eval/validation/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "dataset.arrow" 5 | } 6 | ], 7 | "_fingerprint": "fc0e9aab7979b992", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_indexes": {}, 12 | "_output_all_columns": false, 13 | "_split": null 14 | } -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/tests/__init__.py -------------------------------------------------------------------------------- /tests/configs_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tango.common import Params 4 | 5 | 6 | def test_few_shot_baseline_all(): 7 | os.environ["CKPT"] = "null" 8 | d = Params.from_file("configs/fewshot_eval_all_green.jsonnet").as_dict() 9 | del os.environ["CKPT"] 10 | assert "result_anli_GPT_3_style_r1_score_eval" in d["steps"] 11 | assert 
"aggregated_results" in d["steps"] 12 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/tests/data/__init__.py -------------------------------------------------------------------------------- /tests/data/mixer_dataset_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from better_promptability.data.mixer_dataset import MixerDataset 4 | 5 | 6 | @pytest.fixture 7 | def datasets(): 8 | return [["a1", "a2", "a3"], ["b1", "b2", "b3", "b4", "b5", "b6", "b7"]] 9 | 10 | 11 | def test_mixer_dataset(datasets): 12 | mixer = MixerDataset(datasets) 13 | assert len(mixer) == 10 14 | assert [x for x in mixer] == [x for dataset in datasets for x in dataset] 15 | 16 | 17 | def test_mixer_dataset_with_size_limit(datasets): 18 | mixer = MixerDataset(datasets, sampling_cap=3) 19 | assert len(mixer) == 6 20 | assert [x for x in mixer][:3] == ["a1", "a2", "a3"] 21 | for x in [x for x in mixer][3:]: 22 | assert x in datasets[1] 23 | 24 | # Make sure we get to all instances in all datasets if we call `resample` enough times. 
25 | seen = set(iter(mixer)) 26 | for _ in range(2): 27 | mixer.resample() 28 | for x in mixer: 29 | seen.add(x) 30 | 31 | assert seen == set((x for dataset in datasets for x in dataset)) 32 | -------------------------------------------------------------------------------- /tests/data/t0_data_module_test.py: -------------------------------------------------------------------------------- 1 | from better_promptability.data.config import Config 2 | from better_promptability.data import T0Module 3 | from better_promptability.common.testing import BetterPromptabilityTestCase 4 | 5 | 6 | class T0ModuleTest(BetterPromptabilityTestCase): 7 | def test_t0_module_green(self): 8 | t0 = T0Module( 9 | config=Config(), 10 | data_dir=str(self.FIXTURES_ROOT / "data"), 11 | num_prefix=1, 12 | transformer_model="google/t5-small-lm-adapt", 13 | mixture_name="green", 14 | task_name="hellaswag_complete_first_then_score_eval", 15 | t0_data_cache=str(self.FIXTURES_ROOT / "data" / "processed_cache"), 16 | ) 17 | 18 | t0.setup() 19 | data = t0.load() 20 | assert "train" in data 21 | 22 | train_batch = list(t0.train_dataloader())[0] 23 | assert train_batch["target_ids"].dim() == 2 24 | 25 | val_batch = list(t0.val_dataloader()[0])[0] 26 | assert val_batch["target_ids"].dim() == 3 27 | 28 | def test_t0_module_green_story_cloze(self): 29 | 30 | # Story_cloze special case. 
31 | 32 | t0 = T0Module( 33 | config=Config(), 34 | data_dir=str(self.FIXTURES_ROOT / "data"), 35 | num_prefix=1, 36 | transformer_model="google/t5-small-lm-adapt", 37 | mixture_name="green", 38 | task_name="story_cloze_2016_Story_Continuation_and_Options_score_eval", 39 | t0_data_cache=str(self.FIXTURES_ROOT / "data" / "processed_cache"), 40 | ) 41 | 42 | t0.setup() 43 | data = t0.load() 44 | assert "train" in data 45 | 46 | train_batch = list(t0.train_dataloader())[0] 47 | assert train_batch["target_ids"].dim() == 2 48 | 49 | val_batch = list(t0.val_dataloader()[0])[0] 50 | assert val_batch["target_ids"].dim() == 3 51 | 52 | def test_t0_module_d4_train(self): 53 | t0 = T0Module( 54 | config=Config(), 55 | data_dir=str(self.FIXTURES_ROOT / "data"), 56 | num_prefix=1, 57 | transformer_model="google/t5-small-lm-adapt", 58 | mixture_name="d4_train", 59 | task_name="adversarial_qa_dbert_based_on", 60 | t0_data_cache=str(self.FIXTURES_ROOT / "data" / "cache"), 61 | ) 62 | 63 | t0.setup() 64 | data = t0.load() 65 | assert "train" in data 66 | 67 | train_batch = list(t0.train_dataloader())[0] 68 | assert train_batch["target_ids"].dim() == 2 69 | 70 | val_batch = list(t0.val_dataloader()[0])[0] 71 | assert val_batch["target_ids"].dim() == 2 72 | 73 | def test_t0_module_d4_dev(self): 74 | t0 = T0Module( 75 | config=Config(), 76 | data_dir=str(self.FIXTURES_ROOT / "data"), 77 | num_prefix=1, 78 | transformer_model="google/t5-small-lm-adapt", 79 | mixture_name="d4_dev", 80 | task_name="openbookqa_main_choices", 81 | t0_data_cache=str(self.FIXTURES_ROOT / "data" / "cache"), 82 | ) 83 | 84 | t0.setup() 85 | data = t0.load() 86 | assert "train" in data 87 | 88 | train_batch = list(t0.train_dataloader())[0] 89 | assert train_batch["target_ids"].dim() == 2 90 | 91 | val_batch = list(t0.val_dataloader()[0])[0] 92 | assert val_batch["target_ids"].dim() == 3 93 | -------------------------------------------------------------------------------- /tests/hello_test.py: 
-------------------------------------------------------------------------------- 1 | def test_hello(): 2 | print("Hello, World!") 3 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/tests/modules/__init__.py -------------------------------------------------------------------------------- /tests/modules/transformer_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers.models import t5 as hf_t5 3 | from better_promptability.modules.transformer import Transformer 4 | 5 | 6 | @pytest.fixture(scope="module") 7 | def model_name(): 8 | return "google/t5-small-lm-adapt" 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def tokenizer(model_name): 13 | return hf_t5.T5Tokenizer.from_pretrained(model_name) 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "task", 18 | [ 19 | "seq2seq-lm", 20 | ], 21 | ) 22 | def test_transformer(task: str, model_name: str, tokenizer: hf_t5.T5Tokenizer): 23 | 24 | model = Transformer(model_name, task=task) 25 | 26 | input_ids = tokenizer( 27 | ["The <extra_id_0> walks in <extra_id_1> park", "The <extra_id_0> barked"], 28 | return_tensors="pt", 29 | padding=True, 30 | ).input_ids 31 | assert input_ids.tolist() == [ 32 | [37, 32099, 10681, 16, 32098, 2447, 1], 33 | [37, 32099, 1207, 5100, 1, 0, 0], 34 | ] 35 | 36 | attention_mask = ~(input_ids == 0) 37 | 38 | labels = tokenizer( 39 | ["<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", "<extra_id_0> dog"], 40 | return_tensors="pt", 41 | padding=True, 42 | 
).input_ids 43 | assert labels.tolist() == [ 44 | [32099, 5295, 1782, 32098, 8, 32097, 1], 45 | [32099, 1782, 1, 0, 0, 0, 0], 46 | ] 47 | 48 | decoder_attention_mask = ~(labels == 0) 49 | 50 | output = model.forward( 51 | input_ids, 52 | attention_mask=attention_mask, 53 | labels=labels, 54 | decoder_attention_mask=decoder_attention_mask, 55 | ) 56 | 57 | assert output.logits is not None 58 | -------------------------------------------------------------------------------- /tests/steps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/better-promptability/5cb1e33c9988f6f973e92a1c78da8291fe55df64/tests/steps/__init__.py -------------------------------------------------------------------------------- /tests/steps/process_dataset_test.py: -------------------------------------------------------------------------------- 1 | from better_promptability.steps.process_dataset import ProcessDataset 2 | from better_promptability.common.testing import BetterPromptabilityTestCase 3 | 4 | 5 | class ProcessDatasetTest(BetterPromptabilityTestCase): 6 | def test_process_dataset(self): 7 | step = ProcessDataset() 8 | result = step.run( 9 | old_data_path=str( 10 | self.FIXTURES_ROOT / "data" / "cache" / "hellaswag_complete_first_then_score_eval" 11 | ), 12 | new_data_path=str( 13 | self.FIXTURES_ROOT 14 | / "data" 15 | / "processed_cache" 16 | / "hellaswag_complete_first_then_score_eval" 17 | ), 18 | process_if_exists=True, 19 | ) 20 | 21 | assert len(result["train"]) == 7 22 | assert len(result["train"][0]["targets"]) == 4 23 | assert len(result["train"][0]["targets_pretokenized"]) == 4 24 | assert len(result["train"][0]["is_correct"]) == 4 25 | -------------------------------------------------------------------------------- /tests/steps/process_story_cloze_test.py: -------------------------------------------------------------------------------- 1 | from better_promptability.steps.process_story_cloze 
import ProcessStoryCloze 2 | from better_promptability.common.testing import BetterPromptabilityTestCase 3 | 4 | 5 | class ProcessStoryClozeTest(BetterPromptabilityTestCase): 6 | def test_process_story_cloze(self): 7 | step = ProcessStoryCloze() 8 | result = step.run( 9 | old_data_path=str( 10 | self.FIXTURES_ROOT 11 | / "data" 12 | / "cache" 13 | / "story_cloze_2016_Story_Continuation_and_Options_score_eval" 14 | ), 15 | new_data_path=str( 16 | self.FIXTURES_ROOT 17 | / "data" 18 | / "processed_cache" 19 | / "story_cloze_2016_Story_Continuation_and_Options_score_eval" 20 | ), 21 | process_if_exists=True, 22 | ) 23 | 24 | assert len(result["train"]) == 28 25 | 26 | assert len(result["train"][0]["targets"]) == 2 27 | assert len(result["train"][0]["targets_pretokenized"]) == 2 28 | assert len(result["train"][0]["is_correct"]) == 2 29 | 30 | assert "validation" in result 31 | assert "test" not in result 32 | --------------------------------------------------------------------------------