├── .flake8 ├── .github ├── pull_request_template.md └── workflows │ ├── base_ci.yml │ ├── ci.yml │ └── release.yml ├── .gitignore ├── .gitmodules ├── .markdownlint.yaml ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── examples ├── create_synthetic_data_example.py ├── create_transform_data_example.py └── mistral_qlora_finetune_example.py ├── mypy.ini ├── prompt2model ├── __init__.py ├── dataset_generator │ ├── __init__.py │ ├── base.py │ ├── mock.py │ ├── prompt_based.py │ ├── prompt_template.py │ └── readme.md ├── dataset_processor │ ├── __init__.py │ ├── base.py │ ├── mock.py │ ├── readme.md │ └── textualize.py ├── dataset_retriever │ ├── __init__.py │ ├── base.py │ ├── column_selection_prompt.py │ ├── description_dataset_retriever.py │ ├── mock.py │ ├── readme.md │ ├── reranking_prompt.py │ ├── run_dataset_retriever.py │ └── task_expansion_prompt.py ├── dataset_transformer │ ├── __init__.py │ ├── base.py │ ├── prompt_based.py │ └── prompt_template.py ├── demo_creator │ ├── __init__.py │ ├── create.py │ ├── mock.py │ └── readme.md ├── model_evaluator │ ├── __init__.py │ ├── base.py │ ├── mock.py │ ├── readme.md │ └── seq2seq.py ├── model_executor │ ├── __init__.py │ ├── base.py │ ├── generate.py │ ├── mock.py │ └── readme.md ├── model_retriever │ ├── __init__.py │ ├── base.py │ ├── description_based_retriever.py │ ├── generate_hypothetical_document.py │ ├── mock.py │ ├── readme.md │ └── run_model_retriever.py ├── model_trainer │ ├── __init__.py │ ├── base.py │ ├── callback.py │ ├── generate.py │ ├── mock.py │ ├── qlora_trainer.py │ └── readme.md ├── param_selector │ ├── __init__.py │ ├── base.py │ ├── mock.py │ └── search_with_optuna.py ├── prompt_parser │ ├── __init__.py │ ├── base.py │ ├── instr_parser.py │ ├── instr_parser_prompt.py │ ├── mock.py │ └── readme.md ├── run_locally.py ├── utils │ ├── __init__.py │ ├── api_tools.py │ ├── config.py │ ├── dataset_utils.py │ ├── dataset_utils_test.py │ ├── logging_utils.py │ ├── parse_responses.py │ ├── retrieve_model_info.py │ ├── rng.py │ └── tevatron_utils │ │ ├── __init__.py │ │ ├── encode.py │ │ └── retrieve.py └── version.py ├── prompt2model_demo.ipynb ├── prompt2model_demo.py ├── prompt_examples.md ├── pyproject.toml ├── scripts └── dataset_index │ ├── preprocessing.py │ └── retrieve_dataset_info.py ├── test_helpers ├── __init__.py ├── dataset_index_tiny.json ├── mock_api.py ├── mock_retrieval.py ├── model_and_tokenizer.py ├── model_info_tiny │ ├── gpt2.json │ ├── t5-base.json │ └── xlnet-base-cased.json ├── reranking_dataset_index_tiny.json └── test_utils.py └── tests ├── __init__.py ├── dataset_generator_test.py ├── dataset_processor_test.py ├── dataset_retriever_test.py ├── dataset_transformer_test.py ├── demo_creator_test.py ├── model_evaluator_test.py ├── model_executor_for_gpt_test.py ├── model_executor_for_t5_test.py ├── model_retriever_test.py ├── model_trainer_for_gpt_test.py ├── model_trainer_for_t5_test.py ├── param_selector_test.py ├── prompt_parser_test.py ├── run_locally_test.py └── tevatron_utils_test.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18,BLK100,W503 4 | per-file-ignores = prompt2model/dataset_transformer/prompt_template.py:E501 -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 2 | 3 
| # Description 4 | 5 | 9 | 10 | # References 11 | 12 | 13 | 14 | - Foo 15 | - Bar 16 | - Baz 17 | 18 | # Blocked by 19 | 20 | 21 | 22 | - NA 23 | - (or link to PRs) 24 | -------------------------------------------------------------------------------- /.github/workflows/base_ci.yml: -------------------------------------------------------------------------------- 1 | name: Reusable Continuous Integration Workflow 2 | on: 3 | workflow_call: 4 | 5 | jobs: 6 | unit-tests: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: [ '3.9', '3.10', '3.11' ] 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Install Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | cache: 'pip' 18 | cache-dependency-path: 'pyproject.toml' 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install . 23 | - name: test 24 | run: pytest 25 | format: 26 | runs-on: ubuntu-latest 27 | steps: 28 | - uses: actions/checkout@v3 29 | - name: Install Python 3 30 | uses: actions/setup-python@v3 31 | with: 32 | python-version: 3.9 33 | - name: format 34 | uses: pre-commit/action@v3.0.0 35 | with: 36 | extra_args: black --all-files 37 | isort: 38 | runs-on: ubuntu-latest 39 | steps: 40 | - uses: actions/checkout@v3 41 | - name: Install Python 3 42 | uses: actions/setup-python@v3 43 | with: 44 | python-version: 3.9 45 | - name: isort 46 | uses: pre-commit/action@v3.0.0 47 | with: 48 | extra_args: isort --all-files 49 | lint: 50 | runs-on: ubuntu-latest 51 | steps: 52 | - uses: actions/checkout@v3 53 | - name: Install Python 3 54 | uses: actions/setup-python@v3 55 | with: 56 | python-version: 3.9 57 | - name: lint 58 | uses: pre-commit/action@v3.0.0 59 | with: 60 | extra_args: flake8 --all-files 61 | - name: lint-pep585-compliant 62 | uses: pre-commit/action@v3.0.0 63 | with: 64 | extra_args: upgrade-type-hints --all-files 65 | typecheck: 66 | runs-on: ubuntu-latest 67 | steps: 68 | - uses: actions/checkout@v3 69 | - name: Install Python 3 70 | uses: actions/setup-python@v3 71 | with: 72 | python-version: 3.9 73 | - name: mypy 74 | uses: pre-commit/action@v3.0.0 75 | with: 76 | extra_args: mypy --all-files 77 | markdownlint: 78 | runs-on: ubuntu-latest 79 | steps: 80 | - uses: actions/checkout@v3 81 | - name: Install Node 82 | uses: actions/setup-node@v3 83 | with: 84 | node-version: 18.x 85 | - name: lint 86 | uses: pre-commit/action@v3.0.0 87 | with: 88 | extra_args: markdownlint-cli2 --all-files 89 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Continuous Integration 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: ["**"] 8 | 9 | jobs: 10 | unit-tests: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: [ '3.9', '3.10', '3.11' ] 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Install Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | cache: 'pip' 22 | cache-dependency-path: 'pyproject.toml' 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install . 
27 | - name: test 28 | run: pytest 29 | format: 30 | runs-on: ubuntu-latest 31 | steps: 32 | - uses: actions/checkout@v3 33 | - name: Install Python 3 34 | uses: actions/setup-python@v3 35 | with: 36 | python-version: 3.9 37 | - name: format 38 | uses: pre-commit/action@v3.0.0 39 | with: 40 | extra_args: black --all-files 41 | isort: 42 | runs-on: ubuntu-latest 43 | steps: 44 | - uses: actions/checkout@v3 45 | - name: Install Python 3 46 | uses: actions/setup-python@v3 47 | with: 48 | python-version: 3.9 49 | - name: isort 50 | uses: pre-commit/action@v3.0.0 51 | with: 52 | extra_args: isort --all-files 53 | lint: 54 | runs-on: ubuntu-latest 55 | steps: 56 | - uses: actions/checkout@v3 57 | - name: Install Python 3 58 | uses: actions/setup-python@v3 59 | with: 60 | python-version: 3.9 61 | - name: lint 62 | uses: pre-commit/action@v3.0.0 63 | with: 64 | extra_args: flake8 --all-files 65 | - name: lint-pep585-compliant 66 | uses: pre-commit/action@v3.0.0 67 | with: 68 | extra_args: upgrade-type-hints --all-files 69 | typecheck: 70 | runs-on: ubuntu-latest 71 | steps: 72 | - uses: actions/checkout@v3 73 | - name: Install Python 3 74 | uses: actions/setup-python@v3 75 | with: 76 | python-version: 3.9 77 | - name: mypy 78 | uses: pre-commit/action@v3.0.0 79 | with: 80 | extra_args: mypy --all-files 81 | markdownlint: 82 | runs-on: ubuntu-latest 83 | steps: 84 | - uses: actions/checkout@v3 85 | - name: Install Node 86 | uses: actions/setup-node@v3 87 | with: 88 | node-version: 18.x 89 | - name: lint 90 | uses: pre-commit/action@v3.0.0 91 | with: 92 | extra_args: markdownlint-cli2 --all-files 93 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release workflow 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v[0123456789].*" 7 | 8 | jobs: 9 | release: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: checkout 13 | uses: actions/checkout@v3 14 | - name: setup python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: "3.x" 18 | - name: build 19 | run: | 20 | python -m pip install --upgrade build hatch 21 | python -m hatch version "${GITHUB_REF_NAME}" 22 | python -m build 23 | - name: publish 24 | uses: pypa/gh-action-pypi-publish@release/v1 25 | with: 26 | password: ${{ secrets.PYPI_API_TOKEN }} 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | build 3 | dist 4 | prompt2model.egg-info 5 | .env 6 | .vscode 7 | .mypy_cache 8 | .pytest_cache 9 | *.pyc 10 | flagged 11 | wandb 12 | logs 13 | temp 14 | tests/logs 15 | tests/wandb 16 | 17 | # Outputs generated by the cli demo 18 | cached_generated_dataset/ 19 | generated_dataset/ 20 | huggingface_data/huggingface_datasets/dataset_index.json 21 | huggingface_data/huggingface_datasets/huggingface_datasets_datafinder_index 22 | huggingface_data/huggingface_datasets/reranking_dataset_index.json 23 | huggingface_data/huggingface_models/ 24 | retrieved_dataset_dict/ 25 | result/ 26 | checkpoint/ 27 | status.yaml 28 | # Outputs generated by the colab demo 29 | trained_model/ 30 | trained_tokenizer/ 31 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neulab/prompt2model/d4cd23d50c8f5edc5755ab487bcee18121d42a4f/.gitmodules -------------------------------------------------------------------------------- /.markdownlint.yaml: -------------------------------------------------------------------------------- 1 | MD013: 2 | line_length: 120 3 | code_blocks: false 4 | MD033: false 5 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/python/black.git 3 | rev: 22.3.0 4 | hooks: 5 | - id: black 6 | files: '\.py$' 7 | - repo: https://github.com/PyCQA/flake8 8 | rev: 5.0.4 9 | hooks: 10 | - id: flake8 11 | name: flake8 12 | additional_dependencies: 13 | - flake8-absolute-import 14 | - flake8-black>=0.1.1 15 | - flake8-pep585>=0.1.6 16 | entry: flake8 17 | files: '\.py$' 18 | - id: flake8 19 | name: docstring 20 | additional_dependencies: 21 | - flake8-docstrings>=1.6 22 | args: 23 | - --docstring-convention=google 24 | - --select=D 25 | entry: flake8 26 | files: '\.py$' 27 | - id: flake8 28 | name: future-import 29 | additional_dependencies: 30 | - flake8-future-import 31 | args: 32 | - --select= 33 | - --ignore FI58 34 | entry: flake8 35 | files: '\.py$' 36 | - repo: https://github.com/pycqa/isort.git 37 | rev: 5.12.0 38 | hooks: 39 | - id: isort 40 | args: ["--profile", "black"] 41 | 42 | files: '\.py$' 43 | - repo: https://github.com/sondrelg/pep585-upgrade 44 | rev: v1.0.1 45 | hooks: 46 | - id: upgrade-type-hints 47 | files: '\.py$' 48 | - repo: https://github.com/pre-commit/mirrors-mypy 49 | rev: 'v0.981' 50 | hooks: 51 | - id: mypy 52 | additional_dependencies: 53 | - types-requests 54 | files: '\.py$' 55 | - repo: https://github.com/DavidAnson/markdownlint-cli2 56 | rev: v0.5.1 57 | hooks: 58 | - id: markdownlint-cli2 59 | - id: markdownlint-cli2-fix 60 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to prompt2model 2 | 3 | Thanks for your interest in contributing to prompt2model! 4 | We appreciate your support and welcome your 5 | contributions. Here's a guide to help you get started: 6 | 7 | ## Developer Installation 8 | 9 | Installing dependencies: 10 | Run the following command to install all the dependencies after you have forked the repository. 11 | 12 | ```bash 13 | pip install .[dev] 14 | ``` 15 | 16 | If you're a developer, it's recommended to install 17 | pre-commit hooks before starting your development 18 | work. These hooks ensure code formatting, linting, 19 | and type-checking. To install the pre-commit hooks, 20 | run the following command: 21 | 22 | ```bash 23 | pre-commit install 24 | ``` 25 | 26 | Additionally, it's essential to run tests to verify the 27 | functionality of your code. Execute the following 28 | command to run the tests: 29 | 30 | ```bash 31 | pytest 32 | ``` 33 | 34 | ## Contribution Guide 35 | 36 | To contribute to prompt2model, or if you have any questions, 37 | please reach out to us! 38 | 39 | - open an [issue](https://github.com/neulab/prompt2model/issues) or submit a PR 40 | - join us on [discord](https://discord.gg/UCy9csEmFc) 41 | - or reach out to [@vijaytarian](https://twitter.com/vijaytarian) 42 | and [@Chenan3_Zhao](https://twitter.com/Chenan3_Zhao) on Twitter. 
43 | 44 | ## Making a Release 45 | 46 | If you have admin privileges for the repository, 47 | you can create a new release of the prompt2model 48 | library. We utilize the 49 | [hatchling](https://github.com/pypa/hatch) build 50 | system, which simplifies the process of making 51 | new releases. 52 | 53 | To create a new release, follow these steps: 54 | 55 | 1. Create a new version tag on GitHub, adhering to 56 | the [semantic versioning](https://semver.org/) guidelines. 57 | 2. Once the tag is created, the continuous integration 58 | (CI) system will automatically build and publish the 59 | new version to PyPI. 60 | 61 | By following these steps, you can effectively make 62 | new releases of the library and contribute to its 63 | ongoing development. 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt2Model - Generate Deployable Models from Instructions 2 | 3 | [![PyPI version](https://badge.fury.io/py/prompt2model.svg)](https://badge.fury.io/py/prompt2model) 4 | ![Github Actions CI tests](https://github.com/neulab/prompt2model/actions/workflows/ci.yml/badge.svg) 5 | [![MIT license](https://img.shields.io/badge/License-MIT-blue.svg)](https://lbesson.mit-license.org/) 6 | [![Discord](https://img.shields.io/discord/1144245269001678959)](https://discord.gg/UCy9csEmFc) 7 | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neulab/prompt2model/blob/main/prompt2model_demo.ipynb) 8 | 9 | `Prompt2Model` is a system that takes a natural 10 | language task description (like the prompts used for 11 | LLMs such as ChatGPT) to train a small 12 | special-purpose model that is conducive for deployment. 13 | 14 | prompt2model_teaser 15 | 16 | ## Quick Start 17 | 18 | ### Notebook 19 | 20 | You can run our demo of `Prompt2Model` through a notebook: 21 | 22 | - [Open Locally](./prompt2model_demo.ipynb) 23 | - [Open in Colab](https://colab.research.google.com/github/neulab/prompt2model/blob/main/prompt2model_demo.ipynb) 24 | 25 | ### Command Line 26 | 27 | You can also run through the command line. 28 | 29 | ```bash 30 | pip install prompt2model 31 | ``` 32 | 33 | `Prompt2Model` supports various platforms such as OpenAI, Anthropic, Huggingface, etc. using [LiteLLM](https://github.com/BerriAI/litellm). 34 | 35 | If you are using OpenAI models (such as the default `gpt-3.5-turbo`), please obtain an 36 | OpenAI API key on their [website](https://platform.openai.com/) then set 37 | the environment variable `OPENAI_API_KEY` to your API key by running 38 | the following command in your terminal: 39 | 40 | ```bash 41 | export OPENAI_API_KEY= 42 | ``` 43 | 44 | [List of all supported providers](https://docs.litellm.ai/docs/providers) 45 | 46 | You can then run 47 | 48 | ```bash 49 | python prompt2model_demo.py 50 | ``` 51 | 52 | to create a small model from a prompt, as shown in 53 | the demo video below. This script must be run on a 54 | device with an internet connection to access the OpenAI 55 | API. For best results, run 56 | this script on a device with a GPU for training 57 | your model. 58 | 59 | ## Demo 60 | 61 | 62 | 63 | ## Tips and Examples to Write a Good Prompt 64 | 65 | You can see the tips and examples to write 66 | a good prompt in [prompt_examples](./prompt_examples.md). 
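## Programmatic Usage

Beyond the command-line demo, each stage of the pipeline can be driven
directly from Python. The sketch below is condensed from
`examples/create_synthetic_data_example.py`; the model name, temperatures,
and sample count are illustrative defaults rather than required values.

```python
import prompt2model.utils.api_tools as api_tools
from prompt2model.dataset_generator.base import DatasetSplit
from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator
from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
from prompt2model.utils.api_tools import APIAgent

# Configure the default LLM agent used by all components.
api_tools.default_api_agent = APIAgent(model_name="gpt-3.5-turbo-16k", max_tokens=8000)

# Parse a natural-language prompt into an instruction plus demonstrations.
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
prompt_spec.parse_from_prompt("<your task description and a few input/output examples>")

# Generate a small synthetic training split from the parsed prompt.
generator = PromptBasedDatasetGenerator(initial_temperature=0.4, max_temperature=1.4)
dataset = generator.generate_dataset_split(prompt_spec, 20, split=DatasetSplit.TRAIN)
dataset.save_to_disk("generated_dataset")
```

The scripts under `examples/` show end-to-end variants that also retrieve,
transform, and fine-tune on the resulting data.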
67 | 68 | ## Components 69 | 70 | The `prompt2model` package is composed 71 | of several components, each designed 72 | to fulfill a specific purpose. To gain 73 | a comprehensive understanding of how to 74 | utilize each component effectively, 75 | please consult the `readme.md` file 76 | situated in the directory of the respective 77 | component. These files can be found at 78 | `./prompt2model//readme.md`. 79 | They provide detailed information and 80 | instructions on customizing and maximizing 81 | the functionality of each 82 | component within the package. 83 | 84 | ## Contribution 85 | 86 | If you're interested in contributing to the `prompt2model` project, please 87 | 88 | - refer to [CONTRIBUTING.md](CONTRIBUTING.md) 89 | - open an [issue](https://github.com/neulab/prompt2model/issues) or submit a PR 90 | - join us on [discord](https://discord.gg/UCy9csEmFc) 91 | - or reach out to [@vijaytarian](https://twitter.com/vijaytarian) 92 | and [@Chenan3_Zhao](https://twitter.com/Chenan3_Zhao) on Twitter 93 | 94 | ## Cite 95 | 96 | We have [written a paper describing Prompt2Model in detail](https://arxiv.org/abs/2308.12261). 97 | 98 | If you use Prompt2Model in your research, please cite us! 99 | 100 | If you discuss or use the overall prompt2model framework, please reference 101 | 102 | ```bibtex 103 | @misc{prompt2model, 104 | title={Prompt2Model: Generating Deployable Models from Natural Language Instructions}, 105 | author={Vijay Viswanathan and Chenyang Zhao and Amanda Bertsch and Tongshuang Wu and Graham Neubig}, 106 | year={2023}, 107 | eprint={2308.12261}, 108 | archivePrefix={arXiv}, 109 | primaryClass={cs.CL} 110 | } 111 | ``` 112 | 113 | If you discuss or use our dataset retrieval and transformation tools, please reference 114 | 115 | ```bibtex 116 | @misc{prompt2modeldatatune, 117 | title={Better Synthetic Data by Retrieving and Transforming Existing Datasets}, 118 | author={Saumya Gandhi and Ritu Gala and Vijay Viswanathan and Tongshuang Wu and Graham Neubig}, 119 | year={2024}, 120 | eprint={2404.14361}, 121 | archivePrefix={arXiv}, 122 | primaryClass={cs.CL} 123 | } 124 | ``` 125 | -------------------------------------------------------------------------------- /examples/create_synthetic_data_example.py: -------------------------------------------------------------------------------- 1 | """Example to demonstrate how to create synthetic data based on prompt.""" 2 | 3 | import prompt2model.utils.api_tools as api_tools 4 | from prompt2model.dataset_generator.base import DatasetSplit 5 | from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator 6 | from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType 7 | from prompt2model.utils.api_tools import APIAgent 8 | 9 | if __name__ == "__main__": 10 | # set API keys and create default API agent. 11 | api_tools.default_api_agent = APIAgent( 12 | model_name="gpt-3.5-turbo-16k", max_tokens=8000 13 | ) 14 | 15 | # create prompt based on which transform data will be created 16 | prompt = """ 17 | Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on. 
18 | 19 | Here are examples with input questions and context passages, along with their expected outputs: 20 | 21 | input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50." 22 | output="Santa Clara" 23 | 24 | input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)." 25 | output="Vistula River" 26 | 27 | input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries." 
28 | output="Europe" 29 | """ # noqa: E501 30 | # parse the prompt to get the instruction and examples 31 | prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION) 32 | prompt_spec.parse_from_prompt(prompt) 33 | print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}") 34 | 35 | # set hyperparams 36 | initial_temperature = 0.4 37 | max_temperature = 1.4 38 | num_samples_total = 20 39 | 40 | # run this pipeline to generate data synthetically based on prompt 41 | unlimited_dataset_generator = PromptBasedDatasetGenerator( 42 | initial_temperature=initial_temperature, 43 | max_temperature=max_temperature, 44 | responses_per_request=3, 45 | ) 46 | generated_dataset = unlimited_dataset_generator.generate_dataset_split( 47 | prompt_spec, num_samples_total, split=DatasetSplit.TRAIN 48 | ) 49 | 50 | # save the final generated dataset to disk 51 | generated_dataset.save_to_disk("demo_generated_dataset") 52 | -------------------------------------------------------------------------------- /examples/create_transform_data_example.py: -------------------------------------------------------------------------------- 1 | """Example of how to create transform data based on a prompt.""" 2 | 3 | import prompt2model.utils.api_tools as api_tools 4 | from prompt2model.dataset_retriever import DescriptionDatasetRetriever 5 | from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType 6 | from prompt2model.utils.api_tools import APIAgent 7 | 8 | if __name__ == "__main__": 9 | # set API keys and create default API agent. 10 | api_tools.default_api_agent = APIAgent( 11 | model_name="gpt-3.5-turbo-16k", max_tokens=8000 12 | ) 13 | 14 | # create prompt based on which transform data will be created 15 | prompt = """ 16 | Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on. 17 | 18 | Here are examples with input questions and context passages, along with their expected outputs: 19 | 20 | input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50." 21 | output="Santa Clara" 22 | 23 | input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. 
Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)." 24 | output="Vistula River" 25 | 26 | input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries." 27 | output="Europe" 28 | """ # noqa: E501 29 | # parse the prompt to get the instruction and examples 30 | prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION) 31 | prompt_spec.parse_from_prompt(prompt) 32 | print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}") 33 | 34 | # run this pipeline to retrieve relevant datasets, rerank them, 35 | # and transform them based on the prompt 36 | total_num_points_to_transform = 20 37 | retriever = DescriptionDatasetRetriever( 38 | auto_transform_data=True, 39 | total_num_points_to_transform=total_num_points_to_transform, 40 | ) 41 | retrieved_dataset_dict = retriever.retrieve_dataset_dict( 42 | prompt_spec, 43 | ) 44 | 45 | # save the final dataset to disk 46 | if retrieved_dataset_dict is not None: 47 | retrieved_dataset_dict.save_to_disk("demo_retrieved_dataset_dict") 48 | -------------------------------------------------------------------------------- /examples/mistral_qlora_finetune_example.py: -------------------------------------------------------------------------------- 1 | """Example of how to fine-tune a model using the QLoRATrainer class.""" 2 | 3 | import os 4 | 5 | from datasets import load_from_disk 6 | 7 | from prompt2model.model_trainer.qlora_trainer import QLoRATrainer 8 | from prompt2model.utils.dataset_utils import format_train_data, make_combined_datasets 9 | 10 | if __name__ == "__main__": 11 | # First, we load in the datasets we want to fine-tune on. 12 | retrieved_dataset_dict = load_from_disk("demo_retrieved_dataset_dict") 13 | retrieved_dataset = retrieved_dataset_dict["train"] 14 | generated_dataset = load_from_disk("demo_generated_dataset") 15 | dataset_list = [retrieved_dataset, generated_dataset] 16 | 17 | # Next, we combine datasets and create train and eval splits. 18 | train_dataset = make_combined_datasets(dataset_list) 19 | splits = train_dataset.train_test_split(test_size=0.1) 20 | train_dataset = splits["train"] 21 | eval_dataset = splits["test"] 22 | 23 | # At this point, both train_dataset and eval_dataset are datasets with two 24 | # columns: "input_col" and "output_col". 25 | # We need to format them into a single column, "text", for the QLoRATrainer to use. 26 | formatted_train_dataset = format_train_data(train_dataset) 27 | formatted_eval_dataset = format_train_data(eval_dataset) 28 | 29 | # Next, we define the hyperparameters for the QLoRATrainer. 
30 | num_epochs = 1 31 | qlora_alpha = 8 32 | qlora_r = 16 33 | qlora_lr = 1e-5 34 | save_folder_path = "qlora_finetuned_model" 35 | load_best_model_at_end = False 36 | 37 | # Next, we create a QLoRATrainer object and call the train_model method. 38 | trainer = QLoRATrainer(model_name="mistralai/Mistral-7B-v0.1", model_max_length=512) 39 | 40 | # `formatted_eval_dataset` contains just one column: "text", 41 | # and is used to calculate eval loss, by checking loss for each next token. 42 | # `eval_dataset` contains two columns: "input_col" and "output_col", 43 | # and is used to calculate eval accuracy, by checking whether the generated output 44 | # exactly matches the expected output. 45 | trained_model, trained_tokenizer = trainer.train_model( 46 | formatted_train_dataset, 47 | formatted_eval_dataset, 48 | eval_dataset, 49 | num_epochs=1, 50 | alpha=qlora_alpha, 51 | r=qlora_r, 52 | lr=qlora_lr, 53 | save_folder_path=save_folder_path, 54 | load_best_model_at_end=load_best_model_at_end, 55 | ) 56 | 57 | # Finally, we save the trained model and tokenizer to disk. 58 | trained_model.save_pretrained(os.path.join(save_folder_path, "demo_final_model")) 59 | trained_tokenizer.save_pretrained( 60 | os.path.join(save_folder_path, "demo_final_tokenizer") 61 | ) 62 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | 4 | [mypy-yaml.*] 5 | ignore_missing_imports = True 6 | 7 | [mypy-termcolor.*] 8 | ignore_missing_imports = True 9 | -------------------------------------------------------------------------------- /prompt2model/__init__.py: -------------------------------------------------------------------------------- 1 | """prompt2model top level package.""" 2 | -------------------------------------------------------------------------------- /prompt2model/dataset_generator/__init__.py: -------------------------------------------------------------------------------- 1 | """Import DatasetGenerator classes.""" 2 | from prompt2model.dataset_generator.base import DatasetGenerator, DatasetSplit 3 | from prompt2model.dataset_generator.mock import MockDatasetGenerator 4 | from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator 5 | 6 | __all__ = ( 7 | "PromptBasedDatasetGenerator", 8 | "MockDatasetGenerator", 9 | "DatasetGenerator", 10 | "DatasetSplit", 11 | ) 12 | -------------------------------------------------------------------------------- /prompt2model/dataset_generator/base.py: -------------------------------------------------------------------------------- 1 | """An interface for dataset generation.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | from abc import ABC, abstractmethod 6 | from enum import Enum 7 | 8 | import datasets 9 | 10 | from prompt2model.prompt_parser import PromptSpec 11 | 12 | 13 | class DatasetSplit(Enum): 14 | """The split of a dataset.""" 15 | 16 | TRAIN = "train" 17 | VAL = "val" 18 | TEST = "test" 19 | 20 | 21 | class DatasetGenerator(ABC): 22 | """A class for generating datasets from a prompt specification.""" 23 | 24 | @abstractmethod 25 | def generate_dataset_split( 26 | self, 27 | prompt_spec: PromptSpec, 28 | num_examples: int, 29 | split: DatasetSplit, 30 | ) -> datasets.Dataset: 31 | """Generate data for a single named split of data. 32 | 33 | Args: 34 | prompt_spec: A prompt spec (containing a system description). 
35 | num_examples: Expected number of examples in split. 36 | split: Name of dataset split to generate. 37 | 38 | Returns: 39 | A single dataset split. 40 | """ 41 | 42 | def generate_dataset_dict( 43 | self, 44 | prompt_spec: PromptSpec, 45 | num_examples: dict[DatasetSplit, int], 46 | ) -> datasets.DatasetDict: 47 | """Generate full dataset splits (e.g. train/dev/test) from a prompt. 48 | 49 | Args: 50 | prompt_spec: A prompt specification. 51 | num_examples: Expected number of 52 | examples per split (train/val/test). 53 | 54 | Returns: 55 | A DatasetDict containing train, val, and test splits. 56 | """ 57 | dataset_dict = datasets.DatasetDict( 58 | { 59 | split.value: self.generate_dataset_split(prompt_spec, num, split=split) 60 | for split, num in num_examples.items() 61 | } 62 | ) 63 | 64 | return dataset_dict 65 | -------------------------------------------------------------------------------- /prompt2model/dataset_generator/mock.py: -------------------------------------------------------------------------------- 1 | """A class for generating empty datasets (for testing purposes).""" 2 | 3 | import datasets 4 | 5 | from prompt2model.dataset_generator.base import DatasetGenerator, DatasetSplit 6 | from prompt2model.prompt_parser import PromptSpec 7 | 8 | 9 | class MockDatasetGenerator(DatasetGenerator): 10 | """A class for generating empty datasets (for testing purposes).""" 11 | 12 | def generate_dataset_split( 13 | self, 14 | prompt_spec: PromptSpec, 15 | expected_num_examples: int, 16 | split: DatasetSplit, 17 | ) -> datasets.Dataset: 18 | """Create empty versions of the datasets, for testing. 19 | 20 | Args: 21 | prompt_spec: A prompt specification. 22 | expected_num_examples: Number of examples in split. 23 | split: Name of dataset split to generate. 24 | 25 | Returns: 26 | A single dataset split. 27 | """ 28 | _ = prompt_spec, split # suppress unused variable warnings 29 | col_values = [""] * expected_num_examples 30 | # Construct empty-valued dataframe with length matching expected_num_examples. 31 | return datasets.Dataset.from_dict( 32 | {"input_col": col_values, "output_col": col_values} 33 | ) 34 | -------------------------------------------------------------------------------- /prompt2model/dataset_generator/readme.md: -------------------------------------------------------------------------------- 1 | # Dataset Generator 2 | 3 | ## Overview 4 | 5 | - `DatasetGenerator`: An abstract class to generate datasets. 6 | - `DatasetSplit`: An enumeration class defining dataset types (`TRAIN`, 7 | `VALIDATION`, `TEST`). 8 | - `PromptBasedDatasetGenerator`: A concrete class 9 | for dataset generation using GPT-3.5 API. 10 | 11 | ## Getting Started 12 | 13 | - **Import the Modules**: 14 | 15 | ```python 16 | from prompt2model.dataset_generator import PromptBasedDatasetGenerator, DatasetSplit 17 | from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType 18 | ``` 19 | 20 | - **Setup API Key**: 21 | 22 | Set an API key as an environment variable. For instance, if using OpenAI: 23 | 24 | ```bash 25 | export OPENAI_API_KEY="" 26 | ``` 27 | 28 | - **Parse the Prompt**: 29 | 30 | ```python 31 | prompt_spec = PromptBasedInstructionParser(task_type=TaskType.) 32 | # Refer the document string of DatasetSplit for more details. 33 | prompt = "" 34 | prompt_spec.parse_from_prompt(prompt) 35 | ``` 36 | 37 | Or you can mock a `PromptSpec` object, as 38 | shown in [PromptParser](./../prompt_parser/readme.md). 
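For example, a minimal mocked spec might look like the sketch below
(it mirrors the usage in the dataset retriever readme; the instruction
text is a placeholder):

```python
from prompt2model.prompt_parser import MockPromptSpec, TaskType

prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION)
# Set the instruction manually (placeholder text shown here).
prompt_spec._instruction = "<your task instruction>"
```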
39 | 40 | **Generate Dataset**: 41 | 42 | For a specific split: 43 | 44 | ```python 45 | expected_num_examples = 100 46 | split = DatasetSplit.TRAIN 47 | dataset = dataset_generator.generate_dataset_split( 48 | prompt_spec 49 | , expected_num_examples 50 | , split 51 | ) 52 | ``` 53 | 54 | Or, for multiple splits: 55 | 56 | ```python 57 | expected_num_examples = { 58 | DatasetSplit.TRAIN: 1000, 59 | DatasetSplit.VALIDATION: 100, 60 | DatasetSplit.TEST: 200 61 | } 62 | dataset_dict = dataset_generator.generate_dataset_dict(prompt_spec, expected_num_examples) 63 | ``` 64 | -------------------------------------------------------------------------------- /prompt2model/dataset_processor/__init__.py: -------------------------------------------------------------------------------- 1 | """Import DatasetProcessor classes.""" 2 | from prompt2model.dataset_processor.base import BaseProcessor 3 | from prompt2model.dataset_processor.mock import MockProcessor 4 | from prompt2model.dataset_processor.textualize import TextualizeProcessor 5 | 6 | __all__ = ("BaseProcessor", "TextualizeProcessor", "MockProcessor") 7 | -------------------------------------------------------------------------------- /prompt2model/dataset_processor/mock.py: -------------------------------------------------------------------------------- 1 | """A mock dataset processor for testing purposes.""" 2 | from __future__ import annotations 3 | 4 | import datasets 5 | 6 | from prompt2model.dataset_processor.base import BaseProcessor 7 | 8 | 9 | class MockProcessor(BaseProcessor): 10 | """A class for retrieving datasets.""" 11 | 12 | def process_dataset_dict( 13 | self, instruction: str, dataset_dicts: list[datasets.DatasetDict] 14 | ) -> list[datasets.DatasetDict]: 15 | """A mock function to post-process a list of DatasetDicts. 16 | 17 | Args: 18 | instruction: The instruction used as a prefix to explain the task. 19 | dataset_dicts: A list of DatasetDicts (generated or retrieved). 20 | 21 | Returns: 22 | A list of DatasetDicts, all examples are converted into text2text fashion. 23 | """ 24 | _ = instruction 25 | return dataset_dicts 26 | 27 | @staticmethod 28 | def _post_process_example( 29 | example: dict, 30 | instruction: str, 31 | task_id: int, 32 | has_encoder: bool, 33 | dataset_split: str, 34 | eos_token: str, 35 | ) -> dict: 36 | """A mock function that modifies a given example dictionary. 37 | 38 | Args: 39 | example: A dictionary representing an example. 40 | instruction: The instruction used as a prefix to explain the task. 41 | task_id: A tag marking which dataset (from dataset_dicts) this example 42 | comes from. Used for multi-task training. 43 | has_encoder: Whether the retrieved model has an encoder. 44 | dataset_split: The split of the example, i.e. train/val/test. 45 | eos_token: The end-of-sentence token of the tokenizer. 46 | 47 | Returns: 48 | A dictionary with `model_input` as the input to models 49 | and `model_output` as the expected output of models. 50 | """ 51 | _ = instruction, task_id 52 | example["model_input"] = example["input_col"] 53 | example["model_output"] = example["output_col"] 54 | return example 55 | -------------------------------------------------------------------------------- /prompt2model/dataset_processor/readme.md: -------------------------------------------------------------------------------- 1 | # Dataset Processor 2 | 3 | ## Overview 4 | 5 | - `BaseProcessor`: The base class for dataset post-processing. 
6 | - `TextualizeProcessor`: Transforms all datasets into a consistent 7 | text-to-text generation format. 8 | 9 | ## Getting Started 10 | 11 | - **Import the Module**: 12 | 13 | ```python 14 | from prompt2model.dataset_processor.textualize import TextualizeProcessor 15 | ``` 16 | 17 | - **Initialize TextualizeProcessor**: 18 | 19 | ```python 20 | processor = TextualizeProcessor(has_encoder=) 21 | # : Whether the model you want to finetune has an encoder. 22 | ``` 23 | 24 | Choose encoder type: 25 | 26 | - `has_encoder=True` for encoder-decoder models (e.g., T5). 27 | - `has_encoder=False` for decoder-only/autoregressive models (e.g., GPT2). 28 | 29 | - **Process Datasets**: 30 | 31 | ```python 32 | instruction = "" 33 | dataset_dicts = [...] # List of DatasetDict 34 | modified_dataset_dicts = processor.process_dataset_dict(instruction, dataset_dicts) 35 | ``` 36 | -------------------------------------------------------------------------------- /prompt2model/dataset_processor/textualize.py: -------------------------------------------------------------------------------- 1 | """A dataset processor to convert datasets into Text2Text fashion.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | from prompt2model.dataset_processor.base import BaseProcessor 6 | from prompt2model.utils.logging_utils import get_formatted_logger 7 | 8 | logger = get_formatted_logger("DatasetProcessor") 9 | 10 | 11 | class TextualizeProcessor(BaseProcessor): 12 | """A class for pre-processing datasets before training.""" 13 | 14 | def __init__(self, has_encoder: bool, eos_token: str | None = None) -> None: 15 | """Initialize the `TextualizeProcessor`. 16 | 17 | Args: 18 | has_encoder: Whether the retrieved model has an encoder. 19 | Encoder-decoder model like T5 has two model inputs. 20 | Decoder-only model like GPT only has one model input, thus 21 | `model_input` should be added with the `output_col`. 22 | eos_token: The end-of-sentence token of the tokenizer. 23 | The T5 tokenizer automatically adds eos token in the end of 24 | sequence. So only TextualizeProcessor for GPT model 25 | requires eos_token. 26 | """ 27 | super().__init__(has_encoder, eos_token) 28 | if has_encoder and eos_token is not None: 29 | logger.info( 30 | ( 31 | "The T5 tokenizer automatically adds eos token in the end of sequence when tokenizing." # noqa E501 32 | " So the eos_token of encoder-decoder model tokenizer is unnecessary." # noqa E501 33 | ) 34 | ) 35 | elif not has_encoder and eos_token is None: 36 | logger.warning( 37 | ( 38 | "The autoregressive model tokenizer does not automatically add eos token in the end of the sequence." # noqa E501 39 | " So the `eos_token` of the autoregressive model is required." # noqa E501 40 | ) 41 | ) 42 | 43 | @staticmethod 44 | def _post_process_example( 45 | example: dict, 46 | instruction: str, 47 | task_id: int, 48 | has_encoder: bool, 49 | dataset_split: str, 50 | eos_token: str | None = None, 51 | ) -> dict: 52 | """Modifies the input column of a given example dictionary. 53 | 54 | Args: 55 | example: A dictionary representing an example. 56 | instruction: The instruction used as a prefix to explain the task. 57 | task_id: A tag marking which dataset (from dataset_dicts) this example 58 | comes from. Used for multi-task training. 59 | has_encoder: Whether the retrieved model has an encoder. 60 | dataset_split: The split of the example, i.e. train/val/test. 61 | eos_token: The end-of-sentence token of the tokenizer. 
62 | 63 | Returns: 64 | A dictionary with `model_input` as the input to models 65 | and `model_output` as the expected output of models. 66 | """ 67 | if dataset_split not in ( 68 | "train", 69 | "val", 70 | "test", 71 | ): 72 | raise ValueError("Datset split must be in train/val/test.") 73 | example["output_col"] = str(example["output_col"]) 74 | if has_encoder: 75 | model_input = f"{instruction}\nExample:\n{example['input_col']}\nLabel:\n" # noqa E501 76 | model_output = example["output_col"] 77 | else: 78 | # The T5 tokenizer automatically adds eos token in `add_eos_if_not_present`. 79 | # On the contrary, model_output of GPT model need eos token in the end. 80 | if dataset_split == "train": 81 | model_output = example["output_col"] + eos_token 82 | model_input = f"{instruction}\nExample:\n{example['input_col']}\nLabel:\n{model_output}" # noqa E501 83 | else: 84 | # The val/test split is only used for evaluation. Since our decode 85 | # method in the ModelExecutor set `skip_special_tokens=True`, 86 | # we do not need to add eos token in the end. 87 | model_output = example["output_col"] 88 | model_input = f"{instruction}\nExample:\n{example['input_col']}\nLabel:\n" # noqa E501 89 | example["model_input"] = model_input 90 | example["model_output"] = model_output 91 | return example 92 | -------------------------------------------------------------------------------- /prompt2model/dataset_retriever/__init__.py: -------------------------------------------------------------------------------- 1 | """Import DatasetRetriever classes.""" 2 | from prompt2model.dataset_retriever.base import DatasetRetriever 3 | from prompt2model.dataset_retriever.description_dataset_retriever import ( 4 | DatasetInfo, 5 | DescriptionDatasetRetriever, 6 | ) 7 | from prompt2model.dataset_retriever.mock import MockRetriever 8 | 9 | __all__ = ( 10 | "DatasetRetriever", 11 | "MockRetriever", 12 | "DescriptionDatasetRetriever", 13 | "DatasetInfo", 14 | ) 15 | -------------------------------------------------------------------------------- /prompt2model/dataset_retriever/base.py: -------------------------------------------------------------------------------- 1 | """An interface for dataset retrieval.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | import dataclasses 6 | from abc import ABC, abstractmethod 7 | 8 | import datasets 9 | 10 | from prompt2model.prompt_parser import PromptSpec 11 | 12 | 13 | @dataclasses.dataclass 14 | class DatasetInfo: 15 | """Store the dataset name, description, and query-dataset score for each dataset. 16 | 17 | Args: 18 | name: The name of the dataset. 19 | description: The description of the dataset. 20 | score: The retrieval score of the dataset. 21 | """ 22 | 23 | name: str 24 | description: str 25 | score: float 26 | 27 | 28 | # pylint: disable=too-few-public-methods 29 | class DatasetRetriever(ABC): 30 | """A class for retrieving datasets.""" 31 | 32 | @abstractmethod 33 | def retrieve_dataset_dict( 34 | self, prompt_spec: PromptSpec 35 | ) -> datasets.DatasetDict | None: 36 | """Retrieve full dataset splits (e.g. train/dev/test) from a prompt. 37 | 38 | Args: 39 | prompt_spec: A prompt spec (containing a system description). 40 | 41 | Returns: 42 | A retrieved DatasetDict containing train/val/test splits. 
43 | """ 44 | -------------------------------------------------------------------------------- /prompt2model/dataset_retriever/column_selection_prompt.py: -------------------------------------------------------------------------------- 1 | """Utilities to construct an LLM "metaprompt" for our column selection.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | METAPROMPT_BASE = """Your objective is to carefully analyze the task and the dataset mentioned, and decide whether the columns are relevant input, relevant output, irrelevant for the given task, or if it is ambiguous. There should be at most one output column. It is possible to have no relevant columns, in which case return the input and output column as empty lists. Answer in a json format, with the following keys: input, output, irrelevant, ambiguous""" # noqa: E501 6 | METAPROMPT_EXAMPLES = [ 7 | ( 8 | """You are tasked with the following process. In this task, you will generate summaries for given texts. For this task, you will use the Scientific Papers dataset from HuggingFace. Dataset_description: Scientific papers datasets contains two sets of long and structured documents. The datasets are obtained from ArXiv and PubMed OpenAccess repositories. 9 | A sample data instance from this dataset is as follows. 10 | { 11 | "abstract": "\" we have studied the leptonic decay @xmath0 , via the decay channel @xmath1 , using a sample of tagged @xmath2 decays collected...", 12 | "article": "\"the leptonic decays of a charged pseudoscalar meson @xmath7 are processes of the type @xmath8 , where @xmath9 , @xmath10 , or @...", 13 | "section_names": "[sec:introduction]introduction\n[sec:detector]data and the cleo- detector\n[sec:analysys]analysis method\n[sec:conclusion]summary" 14 | } 15 | This dataset has the following columns: [abstract, article, section_names]. 16 | """, # noqa: E501 17 | """ 18 | { 19 | "input": ["article"], 20 | "output": ["abstract"], 21 | "irrelevant": ["section_names"], 22 | "ambiguous": [] 23 | }""", 24 | ), 25 | ( 26 | """ 27 | You are tasked with the following process. In this task, you will detect whether some given text uses hateful speech or not. For this task you will use the hate_speech_offensive dataset from HuggingFace. Dataset_description: An annotated dataset for hate speech and offensive language detection on tweets. 28 | A sample data instance from this is as follows: 29 | { 30 | "count": 3, 31 | "hate_speech_count": 0, 32 | "offensive_language_count": 0, 33 | "neither_count": 3, 34 | "label": 2, # "neither" 35 | "tweet": "!!! RT @mayasolovely: As a woman you shouldn't complain about cleaning up your house. & as a man you should always take the trash out...") 36 | }. 37 | This dataset has the following columns: [count, hate_speech_count, offensive_language_count, neither_count, class, tweet]""", # noqa: E501 38 | """ 39 | { 40 | "input": ["tweet"], 41 | "output": ["label"], 42 | "irrelevant": [], 43 | "ambiguous": ["hate_speech_count", "offensive_language_count", "neither_count", "count"] 44 | }""", # noqa: E501 45 | ), 46 | ( 47 | """You are tasked with the following process. Your job is to be able to translate between languages. For this task, you will use a custom dataset. Dataset_description: This dataset is meant to translate between languages. 
48 | A sample data instance from this is as follows: 49 | { 50 | "translation": ["ca: "El department de bombers té el seu propi equip d'investigació.", "en": "Well, the fire department has its own investigative unit."] 51 | } 52 | This dataset has the following columns: [translation]. """, # noqa: E501 53 | """ 54 | { 55 | "input": [], 56 | "output": [], 57 | "irrelevant": [] 58 | "ambiguous": ["translation"] 59 | }""", 60 | ), 61 | ( 62 | """You are tasked with the following process. Your job is to be able to summarize a given text. For this task, you will use the math_qa dataset from HuggingFace. Dataset_description: Our dataset is gathered by using a new representation language to annotate over the AQuA-RAT dataset with fully-specified operational programs. 63 | A sample data instance from this is as follows: 64 | { 65 | "Problem": "a multiple choice test consists of 4 questions , and each question has 5 answer choices . in how many r ways can the test be completed if every question is unanswered ?", 66 | "Rationale": "\"5 choices for each of the 4 questions , thus total r of 5 * 5 * 5 * 5 = 5 ^ 4 = 625 ways to answer all of them . answer : c .\"", 67 | "annotated_formula": "power(5, 4)", 68 | "category": "general", 69 | "correct": "c", 70 | "linear_formula": "power(n1,n0)|", 71 | "options": "a ) 24 , b ) 120 , c ) 625 , d ) 720 , e ) 1024" 72 | } 73 | This dataset has the following columns: [problem, rationale, options, correct, annotated_formula]. """, # noqa: E501 74 | """ 75 | { 76 | "input": [], 77 | "output": [], 78 | "irrelevant": ["problem", "rationale", "options", "correct", "annotated_formula"], 79 | "ambiguous": [] 80 | }""", # noqa: E501 81 | ), 82 | ] 83 | 84 | INPUT_PROMPT_TEMPLATE = """You are tasked with the following process. {instruction} For this task, you will use the {dataset_name} dataset from HuggingFace. Dataset Description: {dataset_description} \nA sample data instance from this is as follows. 
{sample_row}.\nThis dataset has the following columns: [{dataset_columns} ].""" # noqa: E501 85 | SINGLE_DEMONSTRATION_TEMPLATE = ( 86 | 'Task: """\n{prompt}\n"""\n\nRequired Columns :\n{columns}' 87 | ) 88 | ENDING_LINE = "After seeing these examples with the required columns, please provide the relevant columns for this context:" # noqa: E501 89 | 90 | 91 | def build_input( 92 | instruction: str, 93 | dataset_name: str, 94 | dataset_description: str, 95 | dataset_columns: str, 96 | sample_row: dict, 97 | ) -> str: 98 | """Template function to build input based on arguments.""" 99 | input_prompt = INPUT_PROMPT_TEMPLATE.format( 100 | instruction=instruction, 101 | dataset_name=dataset_name, 102 | dataset_description=dataset_description, 103 | dataset_columns=dataset_columns, 104 | sample_row=sample_row, 105 | ) 106 | input_prompt = SINGLE_DEMONSTRATION_TEMPLATE.format( 107 | prompt=input_prompt, columns="" 108 | ) # columns="" because that is what we are trying to predict 109 | return input_prompt 110 | 111 | 112 | def construct_prompt_for_column_selection( 113 | instruction: str, 114 | dataset_name: str, 115 | dataset_description: str, 116 | dataset_columns: str, 117 | sample_row: dict, 118 | ) -> str: 119 | """Generate prompt for column selection.""" 120 | prompt_sections = [METAPROMPT_BASE] 121 | for prompt, columns in METAPROMPT_EXAMPLES: 122 | prompt_sections.append( 123 | SINGLE_DEMONSTRATION_TEMPLATE.format(prompt=prompt, columns=columns) 124 | ) 125 | all_prompts = "\n\n------\n\n".join(prompt_sections) + "\n\n------\n\n" 126 | input_prompt = build_input( 127 | instruction, dataset_name, dataset_description, dataset_columns, sample_row 128 | ) 129 | all_prompts += ENDING_LINE + input_prompt 130 | 131 | return all_prompts 132 | -------------------------------------------------------------------------------- /prompt2model/dataset_retriever/mock.py: -------------------------------------------------------------------------------- 1 | """A mock dataset retriever for testing purposes.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | import datasets 6 | 7 | from prompt2model.dataset_retriever.base import DatasetRetriever 8 | from prompt2model.prompt_parser import PromptSpec 9 | 10 | 11 | class MockRetriever(DatasetRetriever): 12 | """A class for retrieving datasets.""" 13 | 14 | def __init__(self): 15 | """Construct a mock dataset retriever.""" 16 | 17 | def retrieve_dataset_dict( 18 | self, prompt_spec: PromptSpec 19 | ) -> datasets.DatasetDict | None: 20 | """Return a single empty DatasetDict for testing purposes.""" 21 | _ = prompt_spec # suppress unused vaiable warning 22 | mock_dataset = datasets.Dataset.from_dict( 23 | {"input_col": [""], "output_col": [""]} 24 | ) 25 | return [ 26 | datasets.DatasetDict( 27 | {"train": mock_dataset, "val": mock_dataset, "test": mock_dataset} 28 | ) 29 | ] 30 | -------------------------------------------------------------------------------- /prompt2model/dataset_retriever/readme.md: -------------------------------------------------------------------------------- 1 | # Dataset Retriever 2 | 3 | ## Overview 4 | 5 | - `DatasetRetriever`: Interface for retrieving datasets based on a 6 | prompt. 7 | - `DescriptionDatasetRetriever`: Retrieves HuggingFace datasets using 8 | similarity to a given prompt. 
9 | 10 | ## Getting Started 11 | 12 | - Import Modules 13 | 14 | ```python 15 | from prompt2model.dataset_retriever import DescriptionDatasetRetriever 16 | from prompt2model.prompt_parser import MockPromptSpec, TaskType 17 | ``` 18 | 19 | - Initialize Retriever 20 | 21 | ```python 22 | retriever = DescriptionDatasetRetriever() 23 | ``` 24 | 25 | Various parameters like search index path, model name, and search 26 | depth can be customized during initialization. 27 | 28 | - Prepare the Prompt 29 | 30 | ```python 31 | task_type = TaskType.TEXT_GENERATION 32 | prompt_text = "..." 33 | prompt_spec = MockPromptSpec(task_type) 34 | prompt_spec._instruction = prompt_text 35 | ``` 36 | 37 | - Retrieve Dataset 38 | 39 | ```python 40 | dataset_dict = retriever.retrieve_dataset_dict( 41 | prompt_spec, blocklist=[] 42 | ) 43 | ``` 44 | 45 | `dataset_dict` will contain the dataset splits (train/val/test) most 46 | relevant to the given prompt. 47 | -------------------------------------------------------------------------------- /prompt2model/dataset_retriever/reranking_prompt.py: -------------------------------------------------------------------------------- 1 | """This module contains the functions to generate the prompt for dataset reranking.""" 2 | from __future__ import annotations # noqa FI58 3 | 4 | METAPROMPT_BASE_DATASET = """Your objective is to choose the most relevant dataset for a given a task (and few examples of the task). For each dataset, you will be provided with the dataset description, and tags related to the dataset which provide meta-information about the dataset. Please return the most relevant dataset, e.g. squad """ # noqa: E501 5 | 6 | METAPROMPT_BASE_CONFIG = """Your objective is to choose the most relevant config of a dataset for a given a task (and few examples of the task). A config of a dataset is a version of that dataset. You will be provided information about this dataset, followed by information about its configs. For each config, you will be provided with the config name, and columns and rows of that config. The columns of the config could be useful at understanding whether this config is relevant to the given task. Another relevant factor is the config_name, this would give information on a high level about what each config represents. Please return the most relevant config""" # noqa: E501 7 | 8 | INPUT_PROMPT_DATASET_TEMPLATE = """The following is the task \n {instruction} \n and these are some examples of the same: \n{examples} \n 9 | There are {num} datasets available for this task. \n 10 | {datasets}. 11 | The name of the most relevant dataset for this task is:""" # noqa: E501 12 | 13 | INPUT_PROMPT_CONFIG_TEMPLATE = """ The following is the task: \n {instruction} \n 14 | The following is the dataset selected: {dataset_name}: {dataset_description} 15 | There are {num} configs available in this dataset for this task. \n 16 | {configs} 17 | The name of the most relevant config from these for this task is: 18 | """ 19 | 20 | DATASET_TEMPLATE = """[{counter}] **{dataset_name}**:\nDescription-{dataset_description}.\nThis dataset has the following tags:\n {tags} """ # noqa: E501 21 | CONFIG_TEMPLATE = """\t[{counter}] **{config_name}**\n: The columns in this config are {dataset_columns}.\n An example row from this config is {sample_row}.\n """ # noqa: E501 22 | 23 | 24 | def build_datasets_prompt(instruction: str, examples: str, datasets_infos: dict): 25 | """Builds the prompt for dataset reranking. 
26 | 27 | Args: 28 | instruction (str): Task instructions 29 | examples (str): Task Examples 30 | datasets_infos (dict): A dictionary containing information about all datasets. 31 | 32 | Returns: 33 | str: The input prompt for dataset retrieval. 34 | """ 35 | dataset_string = "" 36 | for i, (dataset_name, dataset_info) in enumerate(datasets_infos.items(), start=1): 37 | dataset_string += f"""{DATASET_TEMPLATE.format( 38 | counter = i, 39 | dataset_name=dataset_name, 40 | dataset_description=dataset_info["description"], 41 | tags = dataset_info["tags"] 42 | )}\n\n""" 43 | 44 | input_prompt = INPUT_PROMPT_DATASET_TEMPLATE.format( 45 | instruction=instruction, 46 | examples=examples, 47 | datasets=dataset_string, 48 | num=len(datasets_infos), 49 | ) 50 | return input_prompt 51 | 52 | 53 | def build_configs_prompt(instruction: str, examples: str, dataset_info: dict): 54 | """Builds the prompt for config reranking. 55 | 56 | Args: 57 | instruction (str): Task instructions 58 | examples (str): Task Examples 59 | datasets_infos (dict): A dictionary containing information about 60 | the specific dataset, which includes config information. 61 | 62 | Returns: 63 | str: The input prompt for dataset retrieval. 64 | """ 65 | configs_string = "" 66 | for j, (config_name, config_info) in enumerate(dataset_info["configs"].items()): 67 | configs_string += f"""{CONFIG_TEMPLATE.format( 68 | counter = chr(ord('a')+j), 69 | config_name = config_name, 70 | dataset_columns = config_info["columns"], 71 | sample_row = config_info["sample_row"] 72 | )}\n""" # noqa: E501 73 | 74 | input_prompt = INPUT_PROMPT_CONFIG_TEMPLATE.format( 75 | instruction=instruction, 76 | examples=examples, 77 | dataset_name=dataset_info["dataset_name"], 78 | dataset_description=dataset_info["description"], 79 | configs=configs_string, 80 | num=len(dataset_info["configs"]), 81 | ) 82 | return input_prompt 83 | 84 | 85 | def construct_prompt_for_dataset_reranking( 86 | instruction: str, 87 | examples: str, 88 | datasets_infos: dict, 89 | is_config: bool = False, 90 | ): 91 | """Generate the full prompt for dataset reranking based on the given parameters. 92 | 93 | Args: 94 | instruction (str): Instruction of the task. 95 | examples (str): Examples of the task. 96 | datasets_infos (dict): Dictionary with dataset/config information. Each 97 | dataset_info object also has a configs object 98 | representing the various configs of that dataset 99 | is_config (bool): bool: Whether the prompt is for dataset 100 | reranking or config reranking 101 | 102 | Returns: 103 | str: Builds a comprehensive prompt for dataset reranking. This prompt includes 104 | the base instructions, incontext example and the prompt returned by the 105 | build_input function. 
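
    Example (an illustrative sketch of the dataset-level prompt; the dataset
    name, description, and tags below are hypothetical):
        prompt = construct_prompt_for_dataset_reranking(
            instruction="Answer a question given a context passage.",
            examples='input="Q: ... Context: ..." output="..."',
            datasets_infos={
                "squad": {"description": "Reading comprehension QA.",
                          "tags": ["question-answering"]},
            },
        )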
106 | """ 107 | if is_config: 108 | metaprompt_base = METAPROMPT_BASE_CONFIG 109 | input_prompt = build_configs_prompt(instruction, examples, datasets_infos) 110 | else: 111 | metaprompt_base = METAPROMPT_BASE_DATASET 112 | input_prompt = build_datasets_prompt(instruction, examples, datasets_infos) 113 | 114 | prompt_sections = [metaprompt_base] 115 | all_prompts = "\n\n------\n\n".join(prompt_sections) + "\n\n------\n\n" 116 | all_prompts += input_prompt 117 | 118 | return all_prompts 119 | -------------------------------------------------------------------------------- /prompt2model/dataset_retriever/run_dataset_retriever.py: -------------------------------------------------------------------------------- 1 | """Script to run the dataset retriever in isolation.""" 2 | from prompt2model.dataset_retriever import DescriptionDatasetRetriever 3 | from prompt2model.prompt_parser import MockPromptSpec, TaskType 4 | 5 | if __name__ == "__main__": 6 | prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) 7 | prompt = """Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.""" # noqa E501 8 | prompt_spec._instruction = prompt 9 | 10 | retriever = DescriptionDatasetRetriever() 11 | retriever.retrieve_dataset_dict(prompt_spec) 12 | -------------------------------------------------------------------------------- /prompt2model/dataset_retriever/task_expansion_prompt.py: -------------------------------------------------------------------------------- 1 | """This module contains the functions to construct the prompt for task expansion.""" 2 | METAPROMPT_BASE = "Carefully analyse the task description and examples of the task, and explain the task to give a clearer description. Do not explain each example, but rather capture the general trends. Also place special focus on the format of the input/output examples." # noqa: E501 3 | 4 | TASK = """ 5 | Task Description: {task_description} 6 | 7 | Task Examples: {examples} 8 | """ 9 | 10 | 11 | def construct_prompt_for_task_explanation(instruction: str, demonstrations: str): 12 | """Constructs prompt for task explanation. 13 | 14 | This is useful for clarifying the requirements of a task, 15 | and providing a clearer description of the task. 16 | 17 | Args: 18 | instruction (str): The task instruction. 19 | demonstrations (str): The task demonstrations. 20 | 21 | Returns: 22 | str: The constructed prompt. 
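
    Example (a minimal sketch; the instruction and demonstration strings are
    placeholders):
        prompt = construct_prompt_for_task_explanation(
            "Classify a movie review as positive or negative.",
            'input="A soggy, predictable mess." output="negative"',
        )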
23 | """ 24 | task = TASK.format(task_description=instruction, examples=demonstrations) 25 | prompt = "\n--------\n".join([METAPROMPT_BASE, task]) 26 | return prompt 27 | -------------------------------------------------------------------------------- /prompt2model/dataset_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | """Import DatasetGenerator classes.""" 2 | from prompt2model.dataset_transformer.base import DatasetTransformer 3 | from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer 4 | 5 | __all__ = ( 6 | "PromptBasedDatasetTransformer", 7 | "DatasetTransformer", 8 | ) 9 | -------------------------------------------------------------------------------- /prompt2model/dataset_transformer/base.py: -------------------------------------------------------------------------------- 1 | """An interface for dataset transformation.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | import datasets 8 | 9 | from prompt2model.prompt_parser import PromptSpec 10 | 11 | 12 | class DatasetTransformer(ABC): 13 | """A class for transforming a given dataset to a desired format.""" 14 | 15 | @abstractmethod 16 | def transform_data( 17 | self, 18 | prompt_spec: PromptSpec, 19 | dataset: datasets.Dataset, 20 | ) -> datasets.Dataset: 21 | """Transform a split of data. 22 | 23 | Args: 24 | prompt_spec: A prompt spec (containing a system description). 25 | dataset: A dataset split. 26 | num_points_to_transform: Number of data points you wish to 27 | transform. Number must be greater than zero. If number is greater 28 | than size of dataset, whole dataset will be transformed. Ignored 29 | if data_transform is False. 30 | 31 | Returns: 32 | A single dataset split. 33 | """ 34 | -------------------------------------------------------------------------------- /prompt2model/demo_creator/__init__.py: -------------------------------------------------------------------------------- 1 | """Import DemoCreator functions.""" 2 | 3 | from prompt2model.demo_creator.create import create_gradio 4 | from prompt2model.demo_creator.mock import mock_gradio_create 5 | 6 | __all__ = ( 7 | "mock_gradio_create", 8 | "create_gradio", 9 | ) 10 | -------------------------------------------------------------------------------- /prompt2model/demo_creator/create.py: -------------------------------------------------------------------------------- 1 | """Create a Gradio interface automatically.""" 2 | 3 | import gradio as gr 4 | import mdtex2html 5 | 6 | from prompt2model.dataset_processor import TextualizeProcessor 7 | from prompt2model.model_executor import GenerationModelExecutor 8 | from prompt2model.prompt_parser import PromptBasedInstructionParser 9 | 10 | 11 | def create_gradio( 12 | model_executor: GenerationModelExecutor, prompt_parser: PromptBasedInstructionParser 13 | ) -> gr.Blocks: 14 | """Create a Gradio interface automatically. 15 | 16 | Args: 17 | model_executor: A GenerationModelExecutor to expose via a Gradio interface. 18 | prompt_parser: An instance of PromptBasedInstructionParser to parse the prompt. 19 | 20 | Returns: 21 | A Gradio interface for interacting with the model. 
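
    Example (illustrative; assumes an already-initialized GenerationModelExecutor
    and PromptBasedInstructionParser):
        demo = create_gradio(model_executor, prompt_parser)
        demo.launch()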
22 | 23 | """ 24 | description = prompt_parser.instruction 25 | examples = prompt_parser.examples 26 | 27 | def postprocess(self, y): 28 | if y is None: 29 | return [] 30 | for i, (message, response) in enumerate(y): 31 | y[i] = ( 32 | None if message is None else mdtex2html.convert((message)), 33 | None if response is None else mdtex2html.convert(response), 34 | ) 35 | return y 36 | 37 | gr.Chatbot.postprocess = postprocess 38 | 39 | def response(message: str): 40 | if not message.startswith(""): 41 | dataset_processor = TextualizeProcessor(has_encoder=True) 42 | message = dataset_processor.wrap_single_input( 43 | prompt_parser.instruction, message 44 | ) 45 | response = model_executor.make_single_prediction(message) 46 | prediction = response.prediction 47 | return prediction 48 | 49 | def chat(message, history): 50 | history = history or [] 51 | model_output = ( 52 | response(message) if message != "" else "Please give valid input." 53 | ) 54 | history.append((message, model_output)) 55 | return history, history 56 | 57 | def reset_user_input(): 58 | return gr.update(value="") 59 | 60 | def reset_state(): 61 | return [], [] 62 | 63 | with gr.Blocks() as demo: 64 | gr.HTML("""

Prompt2Model
""") 65 | gr.HTML(f"""
Task Description: {description}
""") 66 | gr.HTML(f"""
Few-shot Examples: {examples}
""") 67 | 68 | chatbot = gr.Chatbot() 69 | with gr.Row(): 70 | with gr.Column(scale=4): 71 | with gr.Column(scale=12): 72 | user_input = gr.Textbox( 73 | show_label=False, placeholder="Input...", lines=10 74 | ).style(container=False) 75 | with gr.Column(min_width=32, scale=2): 76 | submitBtn = gr.Button("Submit", variant="primary") 77 | emptyBtn = gr.Button("Clear History") 78 | 79 | history = gr.State([]) 80 | 81 | submitBtn.click( 82 | chat, [user_input, history], [chatbot, history], show_progress=True 83 | ) 84 | submitBtn.click(reset_user_input, [], [user_input]) 85 | emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) 86 | 87 | return demo 88 | -------------------------------------------------------------------------------- /prompt2model/demo_creator/mock.py: -------------------------------------------------------------------------------- 1 | """An interface for creating Gradio demos automatically.""" 2 | 3 | import gradio as gr 4 | import transformers 5 | 6 | from prompt2model.prompt_parser.base import PromptSpec 7 | 8 | 9 | def mock_gradio_create( 10 | model: transformers.PreTrainedModel, prompt_spec: PromptSpec 11 | ) -> gr.Interface: 12 | """Create a Gradio interface automatically. 13 | 14 | Args: 15 | model: A trained model to expose via a Gradio interface. 16 | prompt_spec: A PromptSpec to help choose the visual interface. 17 | 18 | Returns: 19 | A Gradio interface for interacting with the model. 20 | 21 | """ 22 | _ = model, prompt_spec # suppress unused variable warnings 23 | dummy_interface = gr.Interface(lambda input: None, "textbox", "label") 24 | return dummy_interface 25 | -------------------------------------------------------------------------------- /prompt2model/demo_creator/readme.md: -------------------------------------------------------------------------------- 1 | # Demo Creator 2 | 3 | ## Overview 4 | 5 | - **create_gradio**: A function to set up and return a Gradio 6 | interface for model interactions. 7 | 8 | ## Getting Started 9 | 10 | - **Import the Necessary Modules**: 11 | 12 | ```python 13 | from prompt2model.model_executor import GenerationModelExecutor 14 | from prompt2model.prompt_parser import PromptBasedInstructionParser 15 | from prompt2model.gradio_interface import create_gradio 16 | ``` 17 | 18 | - **Initialize Components**: 19 | 20 | ```python 21 | model_executor = GenerationModelExecutor(...) 22 | prompt_parser = PromptBasedInstructionParser(...) 23 | # Refer to the documentation of ModelExecutor and PromptParser for details. 
24 | ``` 25 | 26 | - **Create and Run the Gradio Interface**: 27 | 28 | ```python 29 | interface = create_gradio(model_executor, prompt_parser) 30 | interface.launch(shared=True) 31 | ``` 32 | -------------------------------------------------------------------------------- /prompt2model/model_evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | """Import evaluator classes.""" 2 | from prompt2model.model_evaluator.base import ModelEvaluator 3 | from prompt2model.model_evaluator.mock import MockEvaluator 4 | from prompt2model.model_evaluator.seq2seq import Seq2SeqEvaluator 5 | 6 | __all__ = ("MockEvaluator", "ModelEvaluator", "Seq2SeqEvaluator") 7 | -------------------------------------------------------------------------------- /prompt2model/model_evaluator/base.py: -------------------------------------------------------------------------------- 1 | """An interface for automatic model evaluation.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | import json 6 | from abc import ABC, abstractmethod 7 | from typing import Any 8 | 9 | import datasets 10 | import evaluate 11 | 12 | from prompt2model.model_executor import ModelOutput 13 | 14 | 15 | class ModelEvaluator(ABC): 16 | """An interface for automatic model evaluation.""" 17 | 18 | @abstractmethod 19 | def evaluate_model( 20 | self, 21 | dataset: datasets.Dataset, 22 | gt_column: str, 23 | predictions: list[ModelOutput], 24 | model_input_column: str | None = None, 25 | metrics: list[evaluate.Metric] | None = None, 26 | encoder_model_name: str = "xlm-roberta-base", 27 | ) -> dict[str, Any]: 28 | """Evaluate a model on a test set.. 29 | 30 | Args: 31 | dataset: The dataset to evaluate metrics on. 32 | gt_column: The dataset column to use as ground truth. 33 | predictions: Model outputs to evaluate. 34 | metrics: (Optional) The metrics to use, defaults to using chr_f, 35 | exact_match, and bert_score. 36 | prompt_spec: (Optional) A PromptSpec to infer the metrics from. 37 | 38 | Returns: 39 | A dictionary of metric values to return. 40 | """ 41 | 42 | def write_metrics(self, metrics_dict: dict[str, Any], metrics_path: str) -> None: 43 | """This function writes metrics to a file. 44 | 45 | Args: 46 | metrics_dict: A dictionary of metrics to write. 47 | metrics_path: The file path to write metrics to. 48 | 49 | """ 50 | with open(metrics_path, "w") as f: 51 | json.dump(metrics_dict, f) 52 | -------------------------------------------------------------------------------- /prompt2model/model_evaluator/mock.py: -------------------------------------------------------------------------------- 1 | """A dummy evaluator for testing purposes.""" 2 | from __future__ import annotations # noqa FI58 3 | 4 | from typing import Any 5 | 6 | import datasets 7 | import evaluate 8 | 9 | from prompt2model.model_evaluator.base import ModelEvaluator 10 | from prompt2model.model_executor import ModelOutput 11 | 12 | 13 | class MockEvaluator(ModelEvaluator): 14 | """A dummy evaluator that always returns the same metric value.""" 15 | 16 | def __init__(self) -> None: 17 | """Initialize the evaluation setting.""" 18 | 19 | def evaluate_model( 20 | self, 21 | dataset: datasets.Dataset, 22 | gt_column: str, 23 | predictions: list[ModelOutput], 24 | model_input_column: str | None = None, 25 | metrics: list[evaluate.Metric] | None = None, 26 | encoder_model_name: str = "xlm-roberta-base", 27 | ) -> dict[str, Any]: 28 | """Return empty metrics dictionary. 
29 | 30 | Args: 31 | dataset: The dataset to evaluate metrics on. 32 | gt_column: The dataset column to use as ground truth. 33 | predictions: Corresponding model outputs to evaluate. 34 | metrics: (Optional) The metrics to use, defaults to using 35 | chr_f, exact_match, and bert_score. 36 | prompt_spec: (Optional) A PromptSpec to infer the metrics from. 37 | 38 | Returns: 39 | An empty dictionary (for testing purposes). 40 | """ 41 | return {} 42 | -------------------------------------------------------------------------------- /prompt2model/model_evaluator/readme.md: -------------------------------------------------------------------------------- 1 | # Model Evaluator 2 | 3 | ## Overview 4 | 5 | - **ModelEvaluator**: Interface for evaluating models‘ outputs. 6 | - **Seq2SeqEvaluator**: Offers metrics (`ChrF++`, `Exact Match`, 7 | `BERTScore`) for evaluating conditional generation models on specific 8 | datasets. 9 | 10 | ## Getting Started 11 | 12 | - **Import Required Modules**: 13 | 14 | ```python 15 | from prompt2model.evaluator import Seq2SeqEvaluator 16 | from prompt2model.model_executor import ModelOutput 17 | ``` 18 | 19 | - **Instantiate Seq2SeqEvaluator**: 20 | 21 | ```python 22 | evaluator = Seq2SeqEvaluator() 23 | ``` 24 | 25 | - **Prepare Dataset & Predictions**: 26 | 27 | 1. Autoregressive models might include the input in their outputs. For 28 | such evaluations, refer to the document string of 29 | `model_input_column` in `Seq2SeqEvaluator.evaluate_model`. 30 | 2. Default metrics: `ChrF++`, `Exact Match`, `BERTScore`. 31 | 32 | ```python 33 | PREDICTIONS = [...] 34 | VALIDATION_DATASET = ... 35 | # Dataset with ground truth column `gt_column` and optionally input column `model_input_column`. 36 | ``` 37 | 38 | - **Evaluate**: 39 | 40 | ```python 41 | metric_values = evaluator.evaluate_model( 42 | dataset=VALIDATION_DATASET, 43 | gt_column="model_ouput", 44 | predictions=PREDICTIONS, 45 | ) 46 | ``` 47 | -------------------------------------------------------------------------------- /prompt2model/model_evaluator/seq2seq.py: -------------------------------------------------------------------------------- 1 | """An interface for automatically evaluate Seq2Seq generation model.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | from typing import Any 6 | 7 | import datasets 8 | import evaluate 9 | import numpy as np 10 | 11 | from prompt2model.model_evaluator.base import ModelEvaluator 12 | from prompt2model.model_executor import ModelOutput 13 | from prompt2model.utils import get_formatted_logger 14 | 15 | logger = get_formatted_logger("ModelEvaluator") 16 | 17 | 18 | class Seq2SeqEvaluator(ModelEvaluator): 19 | """An evaluator computing `chr_f++`, `Exact Match` and `Embedding Distance`.""" 20 | 21 | def evaluate_model( 22 | self, 23 | dataset: datasets.Dataset, 24 | gt_column: str, 25 | predictions: list[ModelOutput], 26 | model_input_column: str | None = None, 27 | metrics: list[evaluate.Metric] | None = None, 28 | encoder_model_name: str = "xlm-roberta-base", 29 | ) -> dict[str, Any]: 30 | """Evaluate a model on a test set. 31 | 32 | Args: 33 | dataset: The dataset to evaluate metrics on. 34 | gt_column: The dataset column to use as ground truth. 35 | predictions: Model outputs to evaluate. 36 | model_input_column: (Optional) For autoregistered models, 37 | the prediction sometimes contains the model input. 38 | So we need to delete the model input if it's in the predictions. 
39 | metrics: (Optional) The metrics to use, defaults to using 40 | chr_f, exact_match, and bert_score. 41 | 42 | Returns: 43 | A dictionary of metric values to return. 44 | """ 45 | if metrics is not None: 46 | metric_names = [each.name for each in metrics] 47 | metric_names = sorted(metric_names, key=lambda name: name.lower()) 48 | if not ( 49 | set(metric_names) 50 | < { 51 | "chr_f", 52 | "exact_match", 53 | "bert_score", 54 | } 55 | ): 56 | raise ValueError( 57 | "Metrics must be within chr_f, exact_match, and bert_score." 58 | ) 59 | logger.info(f"Using selected metrics: {', '.join(metric_names)}.") 60 | else: 61 | logger.info("Using default metrics of chr_f, exact_match and bert_score.") 62 | metrics = [ 63 | evaluate.load("chrf"), 64 | evaluate.load("exact_match"), 65 | evaluate.load("bertscore"), 66 | ] 67 | # Get the ground truth from the dataset 68 | ground_truths = dataset[gt_column] 69 | # Extract the predicted strings from ModelOutput 70 | predicted_strings = [each.prediction for each in predictions] 71 | if len(ground_truths) != len(predicted_strings): 72 | raise ValueError( 73 | "The length of input dataset and predictions are not equal." 74 | ) 75 | # Initialize the metric values dictionary 76 | metric_values = {} 77 | 78 | if model_input_column is not None: 79 | # Some of the autoregregistered models' output always contains 80 | # the input. So we need to delete the model input if it's in the 81 | # predictions when necessary. 82 | logger.info( 83 | "The model_input_column is not None. The model input will be detached from predictions if necessary." # noqa E501 84 | ) 85 | model_inputs = dataset[model_input_column] 86 | for idx, model_input in enumerate(model_inputs): 87 | ground_truth = ground_truths[idx] 88 | predicted_string = predicted_strings[idx] 89 | if (model_input in predicted_string) and ( 90 | model_input not in ground_truth 91 | ): 92 | predicted_string = predicted_string.replace(model_input, "") 93 | predicted_strings[idx] = predicted_string 94 | 95 | # Compute and store metric values 96 | for metric in metrics: 97 | metric_name = metric.name 98 | metric.add_batch(predictions=predicted_strings, references=ground_truths) 99 | if metric_name == "chr_f": 100 | metric_values["chr_f++"] = metric.compute(word_order=2)["score"] 101 | elif metric_name == "exact_match": 102 | metric_values[metric_name] = metric.compute()["exact_match"] 103 | elif metric_name == "bert_score": 104 | metric_values["average_bert_score"] = np.average( 105 | metric.compute(model_type=encoder_model_name)["f1"] 106 | ) 107 | 108 | return metric_values 109 | -------------------------------------------------------------------------------- /prompt2model/model_executor/__init__.py: -------------------------------------------------------------------------------- 1 | """Import all the model executor classes.""" 2 | 3 | from prompt2model.model_executor.base import ModelExecutor, ModelOutput 4 | from prompt2model.model_executor.generate import GenerationModelExecutor 5 | from prompt2model.model_executor.mock import MockModelExecutor 6 | 7 | __all__ = ( 8 | "ModelExecutor", 9 | "ModelOutput", 10 | "MockModelExecutor", 11 | "GenerationModelExecutor", 12 | ) 13 | -------------------------------------------------------------------------------- /prompt2model/model_executor/base.py: -------------------------------------------------------------------------------- 1 | """An interface for generating model outputs.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | from abc import ABC, 
abstractmethod 6 | from dataclasses import dataclass 7 | from typing import Any 8 | 9 | import datasets 10 | import transformers 11 | 12 | from prompt2model.utils import get_formatted_logger 13 | 14 | logger = get_formatted_logger("ModelExecutor") 15 | 16 | 17 | @dataclass(frozen=False) 18 | class ModelOutput: 19 | """A model output for a single example. 20 | 21 | Attributes: 22 | prediction: The prediction by the model. 23 | auxiliary_info: Any other auxiliary information provided by the model. 24 | """ 25 | 26 | prediction: Any 27 | auxiliary_info: dict[str, Any] 28 | 29 | 30 | class ModelExecutor(ABC): 31 | """An interface for automatic model evaluation.""" 32 | 33 | def __init__( 34 | self, 35 | model: transformers.PreTrainedModel, 36 | tokenizer: transformers.PreTrainedTokenizer, 37 | batch_size: int = 10, 38 | tokenizer_max_length: int = 256, 39 | sequence_max_length: int = 512, 40 | ) -> None: 41 | """Initializes a new instance of ModelExecutor. 42 | 43 | Args: 44 | model: The model to evaluate. 45 | tokenizer: The model's associated tokenizer. 46 | batch_size: The batch size to use when making predictions. 47 | tokenizer_max_length: The maximum number of tokens that 48 | tokenizer is allowed to generate. 49 | sequence_max_length: The maximum number of tokens in 50 | the input and output. 51 | """ 52 | self.model = model 53 | self.tokenizer = tokenizer 54 | self.batch_size = batch_size 55 | if self.tokenizer.pad_token is None: 56 | logger.warning( 57 | "Trying to init an ModelExecutor's tokenizer without pad_token." 58 | ) 59 | self.tokenizer.pad_token = self.tokenizer.eos_token 60 | self.model.config.pad_token_id = self.model.config.eos_token_id 61 | self.tokenizer_max_length = tokenizer_max_length 62 | self.sequence_max_length = sequence_max_length 63 | if self.sequence_max_length is None: 64 | max_length = self.model.config.max_length 65 | logger.warning( 66 | ( 67 | "The `max_length` in `self.model.generate` will default to " 68 | f"`self.model.config.max_length` ({max_length})" 69 | " if `sequence_max_length` is `None`." 70 | ) 71 | ) 72 | self.sequence_max_length = max_length 73 | if hasattr(self.model.config, "max_position_embeddings"): 74 | max_embeddings = self.model.config.max_position_embeddings 75 | if sequence_max_length is not None and sequence_max_length > max_embeddings: 76 | logger.warning( 77 | ( 78 | f"The sequence_max_length ({sequence_max_length})" 79 | f" is larger than the max_position_embeddings ({max_embeddings})." # noqa: E501 80 | f" So the sequence_max_length will be set to {max_embeddings}." # noqa: E501 81 | ) 82 | ) 83 | self.sequence_max_length = max_embeddings 84 | 85 | @abstractmethod 86 | def make_prediction( 87 | self, 88 | test_set: datasets.Dataset, 89 | input_column: str, 90 | ) -> list[ModelOutput]: 91 | """Evaluate a model on a test set. 92 | 93 | Args: 94 | test_set: The dataset to make predictions on. 95 | input_column: The dataset column to use as input to the model. 96 | 97 | Returns: 98 | A list of model outputs, one for each element in the test set. 99 | """ 100 | 101 | @abstractmethod 102 | def make_single_prediction(self, model_input: str) -> ModelOutput: 103 | """Make prediction on one example. 104 | 105 | Args: 106 | model_input: The input string to the model. 107 | 108 | Returns: 109 | A single model output, useful for exposing a model to a user interface. 
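
        Example (illustrative; `executor` is an instance of any concrete
        subclass, e.g. GenerationModelExecutor):
            output = executor.make_single_prediction("some model input")
            print(output.prediction)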
110 | """ 111 | -------------------------------------------------------------------------------- /prompt2model/model_executor/generate.py: -------------------------------------------------------------------------------- 1 | """Model executor for generative models, including T5-type and GPT-type.""" 2 | from __future__ import annotations # noqa FI58 3 | 4 | from typing import Any 5 | 6 | import datasets 7 | import torch 8 | 9 | from prompt2model.model_executor import ModelExecutor, ModelOutput 10 | from prompt2model.utils import get_formatted_logger 11 | 12 | logger = get_formatted_logger("ModelExecutor") 13 | 14 | 15 | class GenerationModelExecutor(ModelExecutor): 16 | """Model executor for T5-type and GPT-type models.""" 17 | 18 | def generate( 19 | self, 20 | input_ids: list[torch.Tensor], 21 | attention_mask: list[torch.Tensor], 22 | hyperparameter_choices: dict[str, Any], 23 | ) -> list[torch.Tensor]: 24 | """Generates sequences of token IDs using the model. 25 | 26 | Args: 27 | input_ids: A list of token ID sequences. 28 | attention_mask: A list of binary masks indicating attended tokens. 29 | hyperparameter_choices: A dictionary of hyperparameters for inference. 30 | 31 | Returns: 32 | A list of model output tensors, one for each element in input_ids. 33 | """ 34 | generate_strategy = hyperparameter_choices.get("generate_strategy", "greedy") 35 | if generate_strategy not in [ 36 | "beam", # beam search. 37 | "top_k", # top_k sampling. 38 | "top_p", # top_p sampling. 39 | "greedy", # greedy search. 40 | "intersect", # If both top_k and top_p are set, the model will 41 | # sample from the intersection of the top-k tokens and the top-p tokens. 42 | ]: 43 | raise ValueError( 44 | "Only top_k/top_p/intersect sampling and beam/greedy " 45 | "search are supported for inference." 46 | ) 47 | if generate_strategy == "greedy": 48 | output = self.model.generate( 49 | input_ids=input_ids, 50 | attention_mask=attention_mask, 51 | max_length=self.sequence_max_length, 52 | eos_token_id=self.model.config.eos_token_id, 53 | early_stopping=True, 54 | repetition_penalty=hyperparameter_choices.get( 55 | "repetition_penalty", 2.0 56 | ), 57 | ) 58 | elif generate_strategy == "beam": 59 | output = self.model.generate( 60 | input_ids=input_ids, 61 | attention_mask=attention_mask, 62 | max_length=self.sequence_max_length, 63 | eos_token_id=self.model.config.eos_token_id, 64 | early_stopping=True, 65 | do_sample=False, 66 | repetition_penalty=hyperparameter_choices.get( 67 | "repetition_penalty", 2.0 68 | ), 69 | num_beams=hyperparameter_choices.get("num_beams", 3), 70 | ) 71 | elif generate_strategy == "top_k": 72 | output = self.model.generate( 73 | input_ids=input_ids, 74 | attention_mask=attention_mask, 75 | max_length=self.sequence_max_length, 76 | eos_token_id=self.model.config.eos_token_id, 77 | early_stopping=True, 78 | do_sample=True, 79 | repetition_penalty=hyperparameter_choices.get( 80 | "repetition_penalty", 2.0 81 | ), 82 | top_k=hyperparameter_choices.get("top_k", 20), 83 | ) 84 | elif generate_strategy == "top_p": 85 | output = self.model.generate( 86 | input_ids=input_ids, 87 | attention_mask=attention_mask, 88 | max_length=self.sequence_max_length, 89 | eos_token_id=self.model.config.eos_token_id, 90 | early_stopping=True, 91 | do_sample=True, 92 | repetition_penalty=hyperparameter_choices.get( 93 | "repetition_penalty", 2.0 94 | ), 95 | top_p=hyperparameter_choices.get("top_p", 0.95), 96 | ) 97 | else: 98 | # For intersect sampling. 
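            # Passing both top_k and top_p below makes `generate` apply the two
            # filters together, so sampling only draws from tokens that survive
            # both the top-k cutoff and the nucleus (top-p) cutoff.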
99 | output = self.model.generate( 100 | input_ids=input_ids, 101 | attention_mask=attention_mask, 102 | max_length=self.sequence_max_length, 103 | eos_token_id=self.model.config.eos_token_id, 104 | early_stopping=True, 105 | do_sample=True, 106 | repetition_penalty=hyperparameter_choices.get( 107 | "repetition_penalty", 2.0 108 | ), 109 | top_k=hyperparameter_choices.get("top_k", 20), 110 | top_p=hyperparameter_choices.get("top_p", 0.95), 111 | ) 112 | return output 113 | 114 | def make_prediction( 115 | self, 116 | test_set: datasets.Dataset, 117 | input_column: str, 118 | hyperparameter_choices: dict[str, Any] = {}, 119 | ) -> list[ModelOutput]: 120 | """Make predictions with a T5-type or GPT-type model on a test set. 121 | 122 | Args: 123 | test_set: The dataset to make predictions on. Note that 124 | make_single_prediction will warp single_model_input 125 | into a inference_dataset with only one element. 126 | input_column: The dataset column to use as input to the model. 127 | hyperparameter_choices: A dictionary of hyperparameter for generate. 128 | 129 | Returns: 130 | A list of model outputs, one for each element in the test set. 131 | """ 132 | expected_num_examples = len(test_set) 133 | model_outputs = [] 134 | longest_input = max(test_set[input_column], key=len) 135 | if ( 136 | self.tokenizer_max_length is not None 137 | and len(self.tokenizer.tokenize(longest_input)) > self.tokenizer_max_length 138 | ): 139 | logger.warning( 140 | ( 141 | "Truncation happened when tokenizing dataset / input string." 142 | " You should consider increasing the tokenizer_max_length." 143 | " Otherwise the truncation may lead to unexpected results." 144 | ) 145 | ) 146 | 147 | for start_idx in range(0, expected_num_examples, self.batch_size): 148 | end_idx = min(start_idx + self.batch_size, expected_num_examples) 149 | batch = datasets.Dataset.from_dict(test_set[start_idx:end_idx]) 150 | 151 | input_texts = batch[input_column] 152 | encoded_inputs = self.tokenizer.batch_encode_plus( 153 | input_texts, 154 | truncation=True, 155 | max_length=self.tokenizer_max_length, 156 | padding=True, 157 | return_tensors="pt", 158 | ) 159 | device = self.model.device 160 | input_ids = encoded_inputs["input_ids"].to(device) 161 | attention_mask = encoded_inputs["attention_mask"].to(device) 162 | output = self.generate( 163 | input_ids=input_ids, 164 | attention_mask=attention_mask, 165 | hyperparameter_choices=hyperparameter_choices, 166 | ) 167 | 168 | for idx, input_text in enumerate(input_texts): 169 | logits = output[idx] 170 | decoded_output = self.tokenizer.decode(logits, skip_special_tokens=True) 171 | model_output = ModelOutput( 172 | prediction=decoded_output, 173 | auxiliary_info={ 174 | "input_text": input_text, 175 | "logits": logits, 176 | }, 177 | ) 178 | model_outputs.append(model_output) 179 | 180 | return model_outputs 181 | 182 | def make_single_prediction( 183 | self, model_input: str, hyperparameter_choices: dict[str, Any] = {} 184 | ) -> ModelOutput: 185 | """Mock evaluation on one example. 186 | 187 | Args: 188 | model_input: The input string to the model. 189 | hyperparameter_choices: A dictionary of hyperparameter for inference. 190 | 191 | Returns: 192 | A single model output, useful for exposing a model to a user interface. 
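
        Example (illustrative sketch):
            output = model_executor.make_single_prediction(
                "summarize: The quick brown fox jumped over the lazy dog.",
                {"generate_strategy": "top_p", "top_p": 0.9},
            )
            print(output.prediction)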
193 | """ 194 | expected_num_examples = 1 195 | inference_dataset = datasets.Dataset.from_dict({"model_input": [model_input]}) 196 | inference_column = "model_input" 197 | if len(inference_dataset) != expected_num_examples: 198 | raise ValueError( 199 | f"Expected {expected_num_examples} examples, " 200 | f"but got {len(inference_dataset)}." 201 | ) 202 | model_output = self.make_prediction( 203 | inference_dataset, 204 | inference_column, 205 | hyperparameter_choices, 206 | )[0] 207 | return model_output 208 | -------------------------------------------------------------------------------- /prompt2model/model_executor/mock.py: -------------------------------------------------------------------------------- 1 | """A dummy class to generate model outputs (for testing purposes).""" 2 | from __future__ import annotations 3 | 4 | import datasets 5 | 6 | from prompt2model.model_executor import ModelExecutor, ModelOutput 7 | 8 | 9 | class MockModelExecutor(ModelExecutor): 10 | """An interface for automatic model evaluation.""" 11 | 12 | def make_prediction( 13 | self, 14 | test_set: datasets.Dataset, 15 | input_column: str, 16 | ) -> list[ModelOutput]: 17 | """Mock the execution of a model on a test set. 18 | 19 | Args: 20 | test_set: The dataset to make predictions on. 21 | input_column: The dataset column to use as input to the model. 22 | 23 | Returns: 24 | An object containing model outputs. 25 | """ 26 | predictions = [] 27 | for _ in test_set[input_column]: 28 | model_output = ModelOutput(prediction="", auxiliary_info={}) 29 | predictions.append(model_output) 30 | return predictions 31 | 32 | def make_single_prediction(self, model_input: str) -> ModelOutput: 33 | """Mock evaluation on one example. 34 | 35 | Args: 36 | model_input: The input string to the model. 37 | 38 | Returns: 39 | A single model output, useful for exposing a model to a user interface. 40 | """ 41 | _ = model_input 42 | model_output = ModelOutput(prediction="", auxiliary_info={}) 43 | return model_output 44 | -------------------------------------------------------------------------------- /prompt2model/model_executor/readme.md: -------------------------------------------------------------------------------- 1 | # Model Executor 2 | 3 | ## Overview 4 | 5 | - **ModelExecutor**: An interface for executing predictions across 6 | various models. 7 | - **GenerationModelExecutor**: Tailored for generative models such as 8 | T5 (encoder-decoder) and GPT (autoregressive). It supports multiple 9 | generation strategies. 10 | - **ModelOutput**: Denotes the output from the model for a given 11 | example, consolidating the prediction and any auxiliary information. 12 | 13 | ## Getting Started 14 | 15 | - **Import the Necessary Modules**: 16 | 17 | ```python 18 | from prompt2model.model_executor import GenerationModelExecutor, ModelOutput 19 | ``` 20 | 21 | - **Prepare the Input Data**: 22 | 23 | ```python 24 | input_dataset = ... # A dataset containing input examples. 25 | input_example = ... # A singular input in string format. 26 | ``` 27 | 28 | - **Initialize the ModelExecutor**: 29 | 30 | ```python 31 | model = ... # A HuggingFace model instance. 32 | tokenizer = ... # A corresponding HuggingFace tokenizer. 33 | model_executor = GenerationModelExecutor(model, tokenizer) 34 | ``` 35 | 36 | - **Generate Predictions**: 37 | 38 | For multiple inputs: 39 | 40 | ```python 41 | outputs = model_executor.make_prediction( 42 | test_set=input_dataset, # A dataset object. 43 | input_column="..." 
44 | # The input column is the name of the column containing the input in the input_dataset. 45 | ) 46 | ``` 47 | 48 | For a single input: 49 | 50 | ```python 51 | test_input = "..." 52 | output = model_executor.make_single_prediction(test_input) 53 | ``` 54 | 55 | - **Choose a Generation Strategy**: 56 | 57 | Specify the desired decoding strategy. For more details, see the 58 | document string `GenerationModelExecutor.generate`. 59 | 60 | ```python 61 | hyperparameters = {"generate_strategy": "beam", "num_beams": 4} 62 | model_output = model_executor.make_single_prediction(test_input, hyperparameters) 63 | ``` 64 | -------------------------------------------------------------------------------- /prompt2model/model_retriever/__init__.py: -------------------------------------------------------------------------------- 1 | """Import all the model executor classes.""" 2 | 3 | from prompt2model.model_retriever.base import ModelRetriever 4 | from prompt2model.model_retriever.description_based_retriever import ( 5 | DescriptionModelRetriever, 6 | ) 7 | from prompt2model.model_retriever.mock import MockModelRetriever 8 | 9 | __all__ = ("ModelRetriever", "DescriptionModelRetriever", "MockModelRetriever") 10 | -------------------------------------------------------------------------------- /prompt2model/model_retriever/base.py: -------------------------------------------------------------------------------- 1 | """An interface for model selection.""" 2 | from __future__ import annotations 3 | 4 | from abc import ABC, abstractmethod 5 | 6 | from prompt2model.prompt_parser import PromptSpec 7 | 8 | 9 | # pylint: disable=too-few-public-methods 10 | class ModelRetriever(ABC): 11 | """Retrieve several models from HuggingFace.""" 12 | 13 | @abstractmethod 14 | def retrieve( 15 | self, 16 | prompt: PromptSpec, 17 | ) -> list[str]: 18 | """Retrieve relevant models from HuggingFace. 19 | 20 | Args: 21 | prompt: A prompt to use to select relevant models. 22 | 23 | Return: 24 | A list of relevant models' HuggingFace names. 25 | """ 26 | -------------------------------------------------------------------------------- /prompt2model/model_retriever/generate_hypothetical_document.py: -------------------------------------------------------------------------------- 1 | """Tools for generating hypothetical documents from prompts.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | import logging 6 | 7 | from prompt2model.prompt_parser import PromptSpec 8 | from prompt2model.utils import API_ERRORS, api_tools, handle_api_error 9 | 10 | PROMPT_PREFIX = """HuggingFace contains models, which are each given a user-generated description. The first section of the description, delimited with two "---" lines, consists of a YAML description of the model. This may contain fields like "language" (supported by model), "datasets" (used to train the model), "tags" (e.g. tasks relevant to the model), and "metrics" (used to evaluate the model). Create a hypothetical HuggingFace model description that would satisfy a given user instruction. Here are some examples: 11 | 12 | Instruction: "Give me some translation from English to Vietnamese. Input English and output Vietnamese." 
13 | Hypothetical model description: 14 | --- 15 | language: 16 | - en 17 | - vi 18 | 19 | tags: 20 | - translation 21 | 22 | license: apache-2.0 23 | --- 24 | 25 | ### eng-vie 26 | 27 | * source group: English 28 | * target group: Vietnamese 29 | * OPUS readme: [eng-vie](https://github.com/Helsinki-NLP/Tatoeba-Challenge/tree/master/models/eng-vie/README.md) 30 | 31 | * model: transformer-align 32 | * source language(s): eng 33 | * target language(s): vie vie_Hani 34 | * model: transformer-align 35 | * pre-processing: normalization + SentencePiece (spm32k,spm32k) 36 | * a sentence initial language token is required in the form of `>>id<<` (id = valid target language ID) 37 | * download original weights: [opus-2020-06-17.zip](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-vie/opus-2020-06-17.zip) 38 | * test set translations: [opus-2020-06-17.test.txt](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-vie/opus-2020-06-17.test.txt) 39 | * test set scores: [opus-2020-06-17.eval.txt](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-vie/opus-2020-06-17.eval.txt) 40 | 41 | ## Benchmarks 42 | 43 | | testset | BLEU | chr-F | 44 | |-----------------------|-------|-------| 45 | | Tatoeba-test.eng.vie | 37.2 | 0.542 | 46 | 47 | 48 | ### System Info: 49 | - hf_name: eng-vie 50 | 51 | - source_languages: eng 52 | 53 | - target_languages: vie 54 | 55 | - opus_readme_url: https://github.com/Helsinki-NLP/Tatoeba-Challenge/tree/master/models/eng-vie/README.md 56 | 57 | - original_repo: Tatoeba-Challenge 58 | 59 | - tags: ['translation'] 60 | 61 | - languages: ['en', 'vi'] 62 | 63 | - src_constituents: {'eng'} 64 | 65 | - tgt_constituents: {'vie', 'vie_Hani'} 66 | 67 | - src_multilingual: False 68 | 69 | - tgt_multilingual: False 70 | 71 | - prepro: normalization + SentencePiece (spm32k,spm32k) 72 | 73 | - src_alpha3: eng 74 | 75 | - tgt_alpha3: vie 76 | 77 | - short_pair: en-vi 78 | 79 | - chrF2_score: 0.542 80 | 81 | - bleu: 37.2 82 | 83 | - brevity_penalty: 0.973 84 | 85 | - ref_len: 24427.0 86 | 87 | - src_name: English 88 | 89 | - tgt_name: Vietnamese 90 | 91 | - train_date: 2020-06-17 92 | 93 | - src_alpha2: en 94 | 95 | - tgt_alpha2: vi 96 | 97 | - prefer_old: False 98 | 99 | - long_pair: eng-vie 100 | 101 | 102 | Instruction: "I want to summarize things like news articles." 
103 | Hypothetical model description: 104 | --- 105 | language: en 106 | license: apache-2.0 107 | tags: 108 | - pegasus 109 | - seq2seq 110 | - summarization 111 | model-index: 112 | - name: tuner007/pegasus_summarizer 113 | results: 114 | - task: 115 | type: summarization 116 | name: Summarization 117 | dataset: 118 | name: cnn_dailymail 119 | type: cnn_dailymail 120 | config: 3.0.0 121 | split: train 122 | metrics: 123 | - name: ROUGE-1 124 | type: rouge 125 | value: 36.604 126 | verified: true 127 | - name: ROUGE-2 128 | type: rouge 129 | value: 14.6398 130 | verified: true 131 | - name: ROUGE-L 132 | type: rouge 133 | value: 23.8845 134 | verified: true 135 | --- 136 | 137 | ## Model description 138 | [PEGASUS](https://github.com/google-research/pegasus) fine-tuned for summarization 139 | 140 | > Created by [Arpit Rajauria](https://twitter.com/arpit_rajauria) 141 | [![Twitter icon](https://cdn0.iconfinder.com/data/icons/shift-logotypes/32/Twitter-32.png)](https://twitter.com/arpit_rajauria) 142 | 143 | 144 | ### Framework versions 145 | 146 | - Transformers 4.31.0 147 | - Pytorch 2.0.1+cu118 148 | - Datasets 2.13.1 149 | - Tokenizers 0.13.3 150 | 151 | 152 | Instruction: "I want to classify sentences by their sentiment (positive/negative/neutral)." 153 | Hypothetical model description: 154 | --- 155 | language: en 156 | license: apache-2.0 157 | datasets: 158 | - sst2 159 | - glue 160 | model-index: 161 | - name: distilbert-base-uncased-finetuned-sst-2-english 162 | results: 163 | - task: 164 | type: text-classification 165 | name: Text Classification 166 | dataset: 167 | name: glue 168 | type: glue 169 | config: sst2 170 | split: validation 171 | metrics: 172 | - type: accuracy 173 | value: 0.9105504587155964 174 | name: Accuracy 175 | verified: true 176 | verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiN2YyOGMxYjY2Y2JhMjkxNjIzN2FmMjNiNmM2ZWViNGY3MTNmNWI2YzhiYjYxZTY0ZGUyN2M1NGIxZjRiMjQwZiIsInZlcnNpb24iOjF9.uui0srxV5ZHRhxbYN6082EZdwpnBgubPJ5R2-Wk8HTWqmxYE3QHidevR9LLAhidqGw6Ih93fK0goAXncld_gBg 177 | --- 178 | 179 | # DistilBERT base uncased finetuned SST-2 180 | 181 | ## Table of Contents 182 | - [Model Details](#model-details) 183 | - [How to Get Started With the Model](#how-to-get-started-with-the-model) 184 | - [Uses](#uses) 185 | - [Risks, Limitations and Biases](#risks-limitations-and-biases) 186 | - [Training](#training) 187 | 188 | ## Model Details 189 | **Model Description:** This model is a fine-tune checkpoint of [DistilBERT-base-uncased](https://huggingface.co/distilbert-base-uncased), fine-tuned on SST-2. 190 | This model reaches an accuracy of 91.3 on the dev set (for comparison, Bert bert-base-uncased version reaches an accuracy of 92.7). 191 | - **Developed by:** Hugging Face 192 | - **Model Type:** Text Classification 193 | - **Language(s):** English 194 | - **License:** Apache-2.0 195 | - **Parent Model:** For more details about DistilBERT, we encourage users to check out [this model card](https://huggingface.co/distilbert-base-uncased). 196 | - **Resources for more information:** 197 | - [Model Documentation](https://huggingface.co/docs/transformers/main/en/model_doc/distilbert#transformers.DistilBertForSequenceClassification) 198 | - [DistilBERT paper](https://arxiv.org/abs/1910.01108) 199 | 200 | ## Uses 201 | 202 | #### Direct Use 203 | 204 | This model can be used for topic classification. You can use the raw model for either masked language modeling or next sentence prediction, but it's mostly intended to be fine-tuned on a downstream task. 
See the model hub to look for fine-tuned versions on a task that interests you. 205 | 206 | #### Misuse and Out-of-scope Use 207 | The model should not be used to intentionally create hostile or alienating environments for people. In addition, the model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model. 208 | 209 | 210 | ## Risks, Limitations and Biases 211 | 212 | Based on a few experimentations, we observed that this model could produce biased predictions that target underrepresented populations. 213 | 214 | For instance, for sentences like `This film was filmed in COUNTRY`, this binary classification model will give radically different probabilities for the positive label depending on the country (0.89 if the country is France, but 0.08 if the country is Afghanistan) when nothing in the input indicates such a strong semantic shift. 215 | 216 | # Training 217 | 218 | 219 | #### Training Data 220 | 221 | 222 | The authors use the following Stanford Sentiment Treebank([sst2](https://huggingface.co/datasets/sst2)) corpora for the model. 223 | ``` 224 | :""" # noqa: E501 225 | 226 | 227 | def generate_hypothetical_model_description( 228 | prompt: PromptSpec, max_api_calls: int = None 229 | ) -> str: 230 | """Generate a hypothetical model description for the user's instruction. 231 | 232 | This method is based on HyDE by Gao et al 2022 (https://arxiv.org/abs/2212.10496). 233 | 234 | Args: 235 | prompt: PromptSpec object containing the user's instruction. 236 | 237 | Returns: 238 | a hypothetical model description for the user's instruction. 239 | """ 240 | if max_api_calls and max_api_calls <= 0: 241 | raise ValueError("max_api_calls must be > 0.") 242 | api_call_counter = 0 243 | 244 | instruction = prompt.instruction 245 | api_agent = api_tools.default_api_agent 246 | chatgpt_prompt = ( 247 | PROMPT_PREFIX 248 | + "\n" 249 | + f'Instruction: "{instruction}"\nHypothetical model description:\n' 250 | ) 251 | while True: 252 | try: 253 | chatgpt_completion = api_agent.generate_one_completion( 254 | chatgpt_prompt, 255 | temperature=0.0, 256 | presence_penalty=0.0, 257 | frequency_penalty=0.0, 258 | ) 259 | return chatgpt_completion.choices[0]["message"]["content"] 260 | except API_ERRORS as e: 261 | handle_api_error(e) 262 | api_call_counter += 1 263 | if max_api_calls and api_call_counter >= max_api_calls: 264 | logging.error("Maximum number of API calls reached.") 265 | raise ValueError("Maximum number of API calls reached.") from e 266 | -------------------------------------------------------------------------------- /prompt2model/model_retriever/mock.py: -------------------------------------------------------------------------------- 1 | """An interface for model selection.""" 2 | from __future__ import annotations 3 | 4 | from prompt2model.model_retriever import ModelRetriever 5 | from prompt2model.prompt_parser import PromptSpec 6 | 7 | 8 | class MockModelRetriever(ModelRetriever): 9 | """Select a fixed model from among a set of hyperparameter choices.""" 10 | 11 | def __init__(self, fixed_model_name: str): 12 | """Initialize a dummy retriever that returns a fixed model name.""" 13 | self.fixed_model_name = fixed_model_name 14 | 15 | def retrieve( 16 | self, 17 | prompt: PromptSpec, 18 | ) -> list[str]: 19 | """Select an arbitrary, fixed model from HuggingFace. 20 | 21 | Args: 22 | prompt: A prompt to use to select relevant models. 
23 | 24 | Return: 25 | A relevant model's HuggingFace name. 26 | """ 27 | return [self.fixed_model_name] 28 | -------------------------------------------------------------------------------- /prompt2model/model_retriever/readme.md: -------------------------------------------------------------------------------- 1 | # Model Retriever 2 | 3 | ## Overview 4 | 5 | - `ModelRetriever`: An interface for retrieving several models from 6 | HuggingFace. 7 | - `DescriptionModelRetriever` offers functions to vectorize 8 | descriptions, adjust relevance scores, build a BM25 index, and 9 | retrieve models based on a prompt. 10 | - `ModelInfo` stores a model's HuggingFace name, description, 11 | identifier, relevance score, disk size, and download count. 12 | 13 | ## Getting Started 14 | 15 | - **Import Required Modules**: 16 | 17 | ```python 18 | from prompt2model.model_retriever import DescriptionModelRetriever 19 | from prompt2model.prompt_parser import MockPromptSpec, TaskType 20 | ``` 21 | 22 | - **Initialize the Prompt**: 23 | 24 | ```python 25 | prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) 26 | prompt = "..." 27 | prompt_spec._instruction = prompt 28 | ``` 29 | 30 | - **Initialize and Run the Retriever**: 31 | 32 | ```python 33 | retriever = DescriptionModelRetriever( 34 | search_index_path="path_to_bm25_search_index.pkl", 35 | model_descriptions_index_path="path_to_model_info_directory", 36 | use_bm25=True, 37 | use_HyDE=True, 38 | ) 39 | top_model_name = retriever.retrieve(prompt_spec) 40 | ``` 41 | 42 | ## Notes 43 | 44 | - Ensure that the paths provided to the retriever (e.g., 45 | `search_index_path` and `model_descriptions_index_path`) point to the 46 | correct locations. 47 | -------------------------------------------------------------------------------- /prompt2model/model_retriever/run_model_retriever.py: -------------------------------------------------------------------------------- 1 | """Script to run the model retriever in isolation.""" 2 | 3 | from prompt2model.model_retriever import DescriptionModelRetriever 4 | from prompt2model.prompt_parser import MockPromptSpec, TaskType 5 | 6 | if __name__ == "__main__": 7 | prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) 8 | prompt = """Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. 
And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.""" # noqa E501 9 | prompt_spec._instruction = prompt 10 | retriever = DescriptionModelRetriever( 11 | search_index_path="huggingface_data/huggingface_models/bm25_search_index.pkl", 12 | model_descriptions_index_path="huggingface_data/huggingface_models/model_info/", 13 | use_bm25=True, 14 | use_HyDE=True, 15 | ) 16 | top_model_name = retriever.retrieve(prompt_spec) 17 | -------------------------------------------------------------------------------- /prompt2model/model_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | """Import BaseTrainer classes.""" 2 | from prompt2model.model_trainer.base import BaseTrainer 3 | from prompt2model.model_trainer.generate import GenerationModelTrainer 4 | from prompt2model.model_trainer.mock import MockTrainer 5 | 6 | __all__ = ("MockTrainer", "BaseTrainer", "GenerationModelTrainer") 7 | -------------------------------------------------------------------------------- /prompt2model/model_trainer/base.py: -------------------------------------------------------------------------------- 1 | """An base class for trainers.""" 2 | from __future__ import annotations # noqa FI58 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import Any 6 | 7 | import datasets 8 | import transformers 9 | from transformers import AutoModel, AutoTokenizer 10 | 11 | 12 | # pylint: disable=too-few-public-methods 13 | class BaseTrainer(ABC): 14 | """Train a model with a fixed set of hyperparameters.""" 15 | 16 | def __init__(self, pretrained_model_name: str): 17 | """Initialize a model trainer. 18 | 19 | Args: 20 | pretrained_model_name: A HuggingFace model name to use for training. 21 | """ 22 | self.model = AutoModel.from_pretrained(pretrained_model_name) 23 | self.tokenizer = AutoTokenizer.from_pretrained( 24 | pretrained_model_name, padding_side="left" 25 | ) 26 | self.wandb = None 27 | 28 | @abstractmethod 29 | def train_model( 30 | self, 31 | hyperparameter_choices: dict[str, Any], 32 | training_datasets: list[datasets.Dataset], 33 | validation_datasets: list[datasets.Dataset] | None = None, 34 | ) -> tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]: 35 | """Train a model with the given hyperparameters and return it.""" 36 | -------------------------------------------------------------------------------- /prompt2model/model_trainer/callback.py: -------------------------------------------------------------------------------- 1 | """The real evaluation will be conduted after each mock evaluation of Trainer.""" 2 | 3 | 4 | from transformers import TrainerCallback 5 | 6 | from prompt2model.model_evaluator import Seq2SeqEvaluator 7 | from prompt2model.model_executor import GenerationModelExecutor 8 | from prompt2model.utils import get_formatted_logger 9 | 10 | logger = get_formatted_logger("ModelTrainer") 11 | 12 | 13 | class ValidationCallback(TrainerCallback): 14 | """The real evaluation will be conduted after each mock evaluation of Trainer.""" 15 | 16 | def __init__( 17 | self, 18 | trainer, 19 | tokenizer, 20 | val_dataset, 21 | executor_batch_size=10, 22 | tokenizer_max_length=256, 23 | sequence_max_length=512, 24 | ) -> None: 25 | """Initializes a new instance of Model Trainer Callback. 26 | 27 | Args: 28 | trainer: Trainer instance. 29 | After each epoch of Training, this callback will be called. 30 | tokenizer: Tokenizer to initialize model executor. 
31 | val_dataset: Validation dataset to be evaluated on. 32 | executor_batch_size: The batch size for model executor to 33 | make predictions. 34 | tokenizer_max_length: The maximum number of tokens that 35 | tokenizer is allowed to generate. 36 | sequence_max_length: The maximum number of tokens in 37 | the input and output. 38 | """ 39 | super().__init__() 40 | self.trainer = trainer 41 | self.tokenizer = tokenizer 42 | self.val_dataset = val_dataset 43 | self.epoch_count = 0 44 | self.val_dataset_size = len(self.val_dataset) 45 | self.executor_batch_size = executor_batch_size 46 | self.tokenizer_max_length = tokenizer_max_length 47 | self.sequence_max_length = sequence_max_length 48 | 49 | def on_epoch_end(self, args, state, control, **kwargs): 50 | """After each evaluation, this function will be called.""" 51 | _ = (args, state, control, kwargs) 52 | # Suppress the unused parameters warning. 53 | self.epoch_count += 1 54 | logger.info( 55 | f"Epoch: {self.epoch_count}. Evaluate on { self.val_dataset_size} examples." 56 | ) 57 | # For multi-GPU training, the training processor will be segmented 58 | # into multi-threads with data paralyzation, so the validation dataset 59 | # used in the callback is also segmented. 60 | model_executor = GenerationModelExecutor( 61 | model=self.trainer.model, 62 | tokenizer=self.tokenizer, 63 | batch_size=self.executor_batch_size, 64 | tokenizer_max_length=self.tokenizer_max_length, 65 | sequence_max_length=self.sequence_max_length, 66 | ) 67 | model_outputs = model_executor.make_prediction( 68 | self.val_dataset, 69 | "model_input", 70 | ) 71 | evaluator = Seq2SeqEvaluator() 72 | metric_values = evaluator.evaluate_model( 73 | self.val_dataset, 74 | "model_output", 75 | model_outputs, 76 | encoder_model_name="xlm-roberta-base", 77 | ) 78 | logger.info(metric_values) 79 | -------------------------------------------------------------------------------- /prompt2model/model_trainer/mock.py: -------------------------------------------------------------------------------- 1 | """This module provides a dummy trainer for testing purposes.""" 2 | from __future__ import annotations # noqa FI58 3 | 4 | from typing import Any 5 | 6 | import datasets 7 | from transformers import PreTrainedModel # noqa 8 | from transformers import PreTrainedTokenizer 9 | 10 | from prompt2model.model_trainer import BaseTrainer 11 | 12 | 13 | class MockTrainer(BaseTrainer): 14 | """This dummy trainer does not actually train anything.""" 15 | 16 | def train_model( 17 | self, 18 | hyperparameter_choices: dict[str, Any], 19 | training_datasets: list[datasets.Dataset], 20 | validation_datasets: list[datasets.Dataset] | None = None, 21 | ) -> tuple[PreTrainedModel, PreTrainedTokenizer]: 22 | """This dummy trainer returns the given model without any training. 23 | 24 | Args: 25 | training_datasets: A list of training datasets. 26 | hyperparameter_choices: A dictionary of hyperparameters for training. 27 | 28 | Returns: 29 | A HuggingFace model and tokenizer. 30 | """ 31 | _ = training_datasets, hyperparameter_choices, validation_datasets 32 | return self.model, self.tokenizer 33 | -------------------------------------------------------------------------------- /prompt2model/model_trainer/readme.md: -------------------------------------------------------------------------------- 1 | # Model Trainer 2 | 3 | ## Overview 4 | 5 | - `BaseTrainer`: A base class for model training. 
6 | - `GenerationModelTrainer`: A specialized class for training T5-type 7 | (encoder-decoder) and GPT-type (decoder-only) models. 8 | - `ValidationCallback`: A class for performing validation during 9 | training. 10 | 11 | ## Getting Started 12 | 13 | ### Import Modules 14 | 15 | ```python 16 | from prompt2model.model_trainer.generate import GenerationModelTrainer 17 | ``` 18 | 19 | ### Initialize the Trainer 20 | 21 | ```python 22 | pretrained_model_name = "..." # Replace with a HuggingFace pretrained model name. 23 | has_encoder = True 24 | # Set to True if the model has an encoder, otherwise False. 25 | executor_batch_size = 10 # Set the batch size for model validation. 26 | tokenizer_max_length = 256 # Set the maximum length for the tokenizer. 27 | sequence_max_length = 512 # Set the maximum sequence length. 28 | trainer = GenerationModelTrainer( 29 | pretrained_model_name, 30 | has_encoder, 31 | executor_batch_size, 32 | tokenizer_max_length, 33 | sequence_max_length 34 | ) 35 | # For more details, refer to the docstring of GenerationModelTrainer. 36 | ``` 37 | 38 | ### Prepare Dataset and Hyperparameters 39 | 40 | ```python 41 | training_datasets = [] # A list of training datasets. 42 | validation_datasets = [] # A list of validation datasets. 43 | hyperparameter_choices = {...} # A dictionary of hyperparameters. 44 | # For more details, refer to the docstring of GenerationModelTrainer.train_model. 45 | ``` 46 | 47 | ### Train the Model 48 | 49 | ```python 50 | trained_model, trained_tokenizer = trainer.train_model( 51 | hyperparameter_choices, 52 | training_datasets, 53 | validation_datasets 54 | ) 55 | ``` 56 | -------------------------------------------------------------------------------- /prompt2model/param_selector/__init__.py: -------------------------------------------------------------------------------- 1 | """Import model selector classes.""" 2 | from prompt2model.param_selector.base import ParamSelector 3 | from prompt2model.param_selector.mock import MockParamSelector 4 | from prompt2model.param_selector.search_with_optuna import OptunaParamSelector 5 | 6 | __all__ = ("MockParamSelector", "ParamSelector", "OptunaParamSelector") 7 | -------------------------------------------------------------------------------- /prompt2model/param_selector/base.py: -------------------------------------------------------------------------------- 1 | """An interface for model selection.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | from abc import ABC, abstractmethod 6 | from typing import Any 7 | 8 | import datasets 9 | import transformers 10 | 11 | 12 | # pylint: disable=too-few-public-methods 13 | class ParamSelector(ABC): 14 | """Select a good model from among a set of hyperparameter choices.""" 15 | 16 | @abstractmethod 17 | def select_from_hyperparameters( 18 | self, 19 | training_sets: list[datasets.Dataset], 20 | validation: datasets.Dataset, 21 | hyperparameters: dict[str, list[Any]], 22 | ) -> tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]: 23 | """Select a model among a set of hyperparameters (given or inferred). 24 | 25 | Args: 26 | training_sets: One or more training datasets for the trainer. 27 | validation: A dataset for computing validation metrics. 28 | hyperparameters: A dictionary of hyperparameter choices. 29 | 30 | Return: 31 | A model and tokenizer (with hyperparameters from given range).
32 | """ 33 | -------------------------------------------------------------------------------- /prompt2model/param_selector/mock.py: -------------------------------------------------------------------------------- 1 | """Mock model selector for testing purposes.""" 2 | from __future__ import annotations 3 | 4 | from typing import Any 5 | 6 | import datasets 7 | import transformers 8 | 9 | from prompt2model.model_trainer import BaseTrainer 10 | from prompt2model.param_selector.base import ParamSelector 11 | from prompt2model.prompt_parser import PromptSpec 12 | 13 | 14 | class MockParamSelector(ParamSelector): 15 | """Uses a default set of parameters.""" 16 | 17 | def __init__(self, trainer: BaseTrainer): 18 | """Initialize with train/val datasets and a prompt specification. 19 | 20 | Args: 21 | trainer: A trainer to use for training models during model selection. 22 | """ 23 | self.trainer = trainer 24 | 25 | def _example_hyperparameter_choices(self) -> dict[str, Any]: 26 | """Example hyperparameters (for testing only).""" 27 | return { 28 | "optimizer": "AdamW", 29 | "learning_rate": 1e-4, 30 | } 31 | 32 | def select_from_hyperparameters( 33 | self, 34 | training_sets: list[datasets.Dataset], 35 | validation: datasets.Dataset, 36 | hyperparameters: dict[str, list[Any]], 37 | ) -> tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]: 38 | """Use a pre-defined default set of hyperparameters. 39 | 40 | Args: 41 | training_sets: One or more training datasets for the trainer. 42 | validation: A dataset for computing validation metrics. 43 | prompt_spec: (Optional) A prompt to infer hyperparameters from. 44 | hyperparameters: (Optional) A dictionary of hyperparameter choices. 45 | 46 | Return: 47 | A model and tokenizer (trained using default hyperparameters). 48 | """ 49 | single_model = self.trainer.train_model( 50 | self._example_hyperparameter_choices(), training_sets 51 | ) 52 | return single_model 53 | 54 | def select_from_spec( 55 | self, 56 | training_sets: list[datasets.Dataset], 57 | validation: datasets.Dataset, 58 | prompt_spec: PromptSpec, 59 | ) -> tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]: 60 | """The MockParamSelector cannot infer hyperparameters from the spec. 61 | 62 | Args: 63 | training_sets: One or more training datasets for the trainer. 64 | validation: A dataset for computing validation metrics. 65 | prompt_spec: (Optional) A prompt to infer hyperparameters from. 66 | hyperparameters: (Optional) A dictionary of hyperparameter choices. 
67 | """ 68 | raise NotImplementedError 69 | -------------------------------------------------------------------------------- /prompt2model/param_selector/search_with_optuna.py: -------------------------------------------------------------------------------- 1 | """This module provides automatic hyperparameter selection using Optuna.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | import os 6 | from pathlib import Path 7 | from typing import Any, Optional 8 | 9 | import optuna 10 | import transformers 11 | from datasets import Dataset, concatenate_datasets 12 | from optuna.trial import Trial 13 | from transformers import PreTrainedModel # noqa 14 | from transformers import PreTrainedTokenizer 15 | 16 | from prompt2model.model_trainer.generate import GenerationModelTrainer 17 | from prompt2model.param_selector.base import ParamSelector 18 | from prompt2model.utils.config import DEFAULT_HYPERPARAMETERS_SPACE 19 | 20 | 21 | class OptunaParamSelector(ParamSelector): 22 | """Uses Optuna for searching for hyperparameters.""" 23 | 24 | def __init__(self, trainer: GenerationModelTrainer, n_trials: int): 25 | """Initializes a new instance of OptunaParamSelector. 26 | 27 | Args: 28 | trainer (GenerationModelTrainer): Trainer object used for each trial. 29 | n_trials (int): The maximum number of parameter configurations to evaluate 30 | during hyperparameter search. 31 | """ 32 | self.generation_model_trainer = trainer 33 | self.n_trials = n_trials 34 | 35 | def optimize_hyperparameters( 36 | self, 37 | training_datasets: list[Dataset], 38 | validation: Dataset, 39 | hyperparameters: Optional[dict[str, Any]] = None, 40 | ) -> dict[str, Any]: 41 | """Search for the best hyperparameters within the given (or default) space. 42 | 43 | Args: 44 | training_datasets (list[Dataset]): One or more training datasets 45 | to use for training models. 46 | validation (Dataset): A dataset for computing validation metrics. 47 | hyperparameters (Optional[dict[str, Any]], optional): The set 48 | of possible hyperparameter values to search over for 49 | optimal hyperparameter selection. Defaults to None. 50 | 51 | Returns: 52 | A dict containing the best hyperparameters found. 53 | """ 54 | supported_hp_space_keys = set(DEFAULT_HYPERPARAMETERS_SPACE.keys()) 55 | if hyperparameters is not None: 56 | assert set(hyperparameters.keys()).issubset( 57 | supported_hp_space_keys 58 | ), f"Only support {supported_hp_space_keys} as training parameters."
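        # _build_hp_space merges any user-supplied ranges with DEFAULT_HYPERPARAMETERS_SPACE; keys outside the default space are ignored (see _build_hp_space below).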
59 | hyperparameter_space = self._build_hp_space(hyperparameters) 60 | 61 | concatenated_training_dataset = concatenate_datasets(training_datasets) 62 | train_dataset = self.generation_model_trainer.tokenize_dataset( 63 | concatenated_training_dataset 64 | ) 65 | 66 | if isinstance(validation, list): 67 | validation = concatenate_datasets(validation) 68 | validation = self.generation_model_trainer.tokenize_dataset(validation) 69 | 70 | def objective(trial: Trial) -> float: 71 | model = self.generation_model_trainer.model 72 | training_args = transformers.TrainingArguments( 73 | output_dir="./checkpoint", 74 | learning_rate=trial.suggest_loguniform( 75 | "learning_rate", 76 | low=hyperparameter_space["min_learning_rate"], 77 | high=hyperparameter_space["max_learning_rate"], 78 | ), 79 | weight_decay=trial.suggest_loguniform( 80 | "weight_decay", 81 | low=hyperparameter_space["min_weight_decay"], 82 | high=hyperparameter_space["max_weight_decay"], 83 | ), 84 | num_train_epochs=trial.suggest_int( 85 | "num_train_epochs", 86 | low=hyperparameter_space["min_num_train_epochs"], 87 | high=hyperparameter_space["max_num_train_epochs"], 88 | ), 89 | ) 90 | objective_trainer = transformers.Trainer( 91 | model=model, 92 | args=training_args, 93 | data_collator=transformers.DataCollatorForSeq2Seq( 94 | tokenizer=self.generation_model_trainer.tokenizer 95 | ), 96 | train_dataset=train_dataset, 97 | eval_dataset=validation, 98 | ) 99 | 100 | _ = objective_trainer.train() 101 | optimization_targets = objective_trainer.evaluate() 102 | return optimization_targets["eval_loss"] 103 | 104 | study = optuna.create_study( 105 | study_name="automatic_hyperparameter_search", direction="minimize" 106 | ) 107 | 108 | study.optimize(func=objective, n_trials=self.n_trials, gc_after_trial=True) 109 | best_hyperparameters = { 110 | "learning_rate": float(study.best_params["learning_rate"]), 111 | "weight_decay": float(study.best_params["weight_decay"]), 112 | "num_train_epochs": int(study.best_params["num_train_epochs"]), 113 | } 114 | return best_hyperparameters 115 | 116 | def select_from_hyperparameters( 117 | self, 118 | training_datasets: list[Dataset], 119 | validation: Dataset, 120 | hyperparameters: Optional[dict[str, Any]] = None, 121 | ) -> tuple[PreTrainedModel, PreTrainedTokenizer]: 122 | """Select a model among a set of hyperparameters (given or inferred). # noqa D410 123 | 124 | Args: 125 | training_datasets: One or more training datasets for the trainer. 126 | validation: A dataset for computing validation metrics. 127 | hyperparameters: A dictionary of hyperparameter choices. 128 | 129 | If no hyperparameter_space is specified, then the default hyperparameter_space 130 | will be chosen. Here is an example of what the space looks like: 131 | hyperparameter_space = { 132 | "min_num_train_epochs": 5, 133 | "max_num_train_epochs": 10, 134 | "save_strategy": ["epoch", "steps", "no"], 135 | "evaluation_strategy": ["epoch", "no"], 136 | "per_device_train_batch_size": [4, 8, 16, 32], 137 | "min_weight_decay": 1e-5, 138 | "max_weight_decay": 1e-1, 139 | "min_learning_rate": 1e-5, 140 | "max_learning_rate": 1e-1, 141 | } 142 | Return: 143 | A model and tokenizer (with hyperparameters from given range).
144 | """ 145 | model = self.generation_model_trainer.model 146 | tokenizer = self.generation_model_trainer.tokenizer 147 | best_model_path = Path("result/trained_model") 148 | if not os.path.exists(best_model_path): 149 | os.makedirs(best_model_path, exist_ok=True) 150 | 151 | best_hyperparameters = self.optimize_hyperparameters( 152 | training_datasets=training_datasets, 153 | validation=validation, 154 | hyperparameters=hyperparameters, 155 | ) 156 | final_hyperparameters = { 157 | "output_dir": "./best_model_checkpoint", 158 | **best_hyperparameters, 159 | } 160 | 161 | model, tokenizer = self.generation_model_trainer.train_model( 162 | hyperparameter_choices=final_hyperparameters, 163 | training_datasets=training_datasets, 164 | ) 165 | 166 | model.save_pretrained(best_model_path) 167 | tokenizer.save_pretrained(best_model_path) 168 | return model, tokenizer 169 | 170 | def _build_hp_space( 171 | self, hyperparameter_space: Optional[dict[str, Any]] = None 172 | ) -> dict[str, Any]: 173 | if hyperparameter_space is None: 174 | return DEFAULT_HYPERPARAMETERS_SPACE 175 | hp_space = {} 176 | 177 | default_keys = list(DEFAULT_HYPERPARAMETERS_SPACE.keys()) 178 | for key in list(hyperparameter_space.keys()): 179 | if key not in default_keys: 180 | print( 181 | f"Key {key} is not present in DEFAULT_HYPERPARAMETERS_SPACE. Hence, it will be ignored.", # noqa E501 182 | "However, you can expose the key to the Trainer by adding it to DEFAULT_HYPERPARAMETERS_SPACE.", # noqa E501 183 | ) 184 | 185 | for key, default_value in DEFAULT_HYPERPARAMETERS_SPACE.items(): 186 | hp_space[key] = hyperparameter_space.get(key, default_value) 187 | return hp_space 188 | -------------------------------------------------------------------------------- /prompt2model/prompt_parser/__init__.py: -------------------------------------------------------------------------------- 1 | """Import PromptSpec classes.""" 2 | from prompt2model.prompt_parser.base import PromptSpec, TaskType 3 | from prompt2model.prompt_parser.instr_parser import PromptBasedInstructionParser 4 | from prompt2model.prompt_parser.mock import MockPromptSpec 5 | 6 | __all__ = ( 7 | "PromptSpec", 8 | "TaskType", 9 | "MockPromptSpec", 10 | "PromptBasedInstructionParser", 11 | ) 12 | -------------------------------------------------------------------------------- /prompt2model/prompt_parser/base.py: -------------------------------------------------------------------------------- 1 | """An interface for prompt parsing.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | from abc import ABC, abstractmethod 6 | from enum import Enum 7 | 8 | 9 | class TaskType(Enum): 10 | """High-level taxonomy of possible NLP model outputs.""" 11 | 12 | TEXT_GENERATION = 1 13 | CLASSIFICATION = 2 14 | SEQUENCE_TAGGING = 3 15 | SPAN_EXTRACTION = 4 16 | 17 | 18 | class PromptSpec(ABC): 19 | """Parse and store structured information about the prompt.""" 20 | 21 | task_type: TaskType 22 | _instruction: str | None 23 | _examples: str | None 24 | 25 | @abstractmethod 26 | def parse_from_prompt(self, prompt: str) -> None: 27 | """Populate this class by parsing a prompt.""" 28 | 29 | @property 30 | def instruction(self) -> str: 31 | """Return the natural language instruction parsed from the prompt.""" 32 | if self._instruction is None: 33 | raise ValueError("Instruction hasn't been parsed from the prompt.") 34 | return self._instruction 35 | 36 | @property 37 | def examples(self) -> str: 38 | """Return the natural language examples parsed from the prompt.""" 39 | 
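        # Unlike the `instruction` property, examples are optional, so fall back to an empty string instead of raising an error.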
return self._examples or "" 40 | -------------------------------------------------------------------------------- /prompt2model/prompt_parser/instr_parser.py: -------------------------------------------------------------------------------- 1 | """An interface for prompt parsing.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | import os 6 | 7 | from prompt2model.prompt_parser.base import PromptSpec, TaskType 8 | 9 | from prompt2model.prompt_parser.instr_parser_prompt import ( # isort: split 10 | construct_prompt_for_instruction_parsing, 11 | ) 12 | 13 | from prompt2model.utils.parse_responses import parse_prompt_to_fields 14 | 15 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 16 | 17 | 18 | class PromptBasedInstructionParser(PromptSpec): 19 | """Parse the prompt to separate instructions from task demonstrations.""" 20 | 21 | def __init__(self, task_type: TaskType, max_api_calls: int = 5): 22 | """Initialize the prompt spec with empty parsed fields. 23 | 24 | We initialize the "instruction" and "examples" fields with None. 25 | These fields can be populated with the parse_from_prompt method. 26 | 27 | Args: 28 | task_type: Set a constant task type to use for all prompts. 29 | max_api_calls: The maximum number of API calls allowed, 30 | or None for unlimited. 31 | """ 32 | self.task_type = task_type 33 | self._instruction: str | None = None 34 | self._examples: str | None = None 35 | self.max_api_calls = max_api_calls 36 | 37 | def parse_from_prompt(self, prompt: str) -> None: 38 | """Parse prompt into specific fields, stored as class member variables. 39 | 40 | This function directly stores the parsed fields into the class's member 41 | variables `instruction` and `examples`. So it has no return value. 42 | 43 | Args: 44 | prompt: User prompt to parse into two specific fields: 45 | "instruction" and "demonstrations". 46 | """ 47 | parsing_prompt_for_chatgpt = construct_prompt_for_instruction_parsing(prompt) 48 | required_keys = ["Instruction", "Demonstrations"] 49 | 50 | extraction = parse_prompt_to_fields( 51 | parsing_prompt_for_chatgpt, 52 | required_keys, 53 | max_api_calls=self.max_api_calls, 54 | ) 55 | self._instruction = extraction["Instruction"] 56 | self._examples = extraction["Demonstrations"] 57 | 58 | def set_instruction_and_examples( 59 | self, instruction: str = "", examples: str = "" 60 | ) -> None: 61 | """Set the instruction and examples directly.""" 62 | self._instruction = instruction 63 | self._examples = examples 64 | -------------------------------------------------------------------------------- /prompt2model/prompt_parser/mock.py: -------------------------------------------------------------------------------- 1 | """An interface for prompt parsing.""" 2 | 3 | from prompt2model.prompt_parser.base import PromptSpec, TaskType 4 | 5 | 6 | class MockPromptSpec(PromptSpec): 7 | """Mock the behavior of PromptSpec.""" 8 | 9 | def __init__( 10 | self, task_type: TaskType, instruction: str = None, examples: str = None 11 | ): 12 | """Mock the elements of PromptSpec.""" 13 | self.task_type = task_type 14 | if instruction is None: 15 | self._instruction = ( 16 | "Give me some translation from Chinese to English." 17 | " Input Chinese and output English." 18 | ) 19 | else: 20 | self._instruction = instruction 21 | if examples is None: 22 | self._examples = ( 23 | "input: '人生苦短,我用 Python', output: 'Life is short, I use Python.
'" 24 | "input: '明天是周末', output: 'Tomorrow is weekend.'" 25 | ) 26 | else: 27 | self._examples = examples 28 | 29 | def parse_from_prompt(self, prompt: str) -> None: 30 | """Don't parse anything.""" 31 | self._instruction = prompt 32 | return None 33 | -------------------------------------------------------------------------------- /prompt2model/prompt_parser/readme.md: -------------------------------------------------------------------------------- 1 | # Prompt Parser 2 | 3 | ## Overview 4 | 5 | - `PromptSpec`: Interface for parsing prompts into instructions and 6 | examples. 7 | - `TaskType`: Enum for classifying NLP tasks like text generation, 8 | classification, etc. 9 | - `PromptBasedInstructionParser`: A `PromptSpec` subclass that uses GPT-3.5 10 | API for parsing. 11 | 12 | ## Getting Started 13 | 14 | - Import Modules: 15 | 16 | ```python 17 | from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType 18 | ``` 19 | 20 | - Setup API Key and Initialize Parser. For instance, if using OpenAI: 21 | 22 | ```bash 23 | export OPENAI_API_KEY="" 24 | ``` 25 | 26 | And then initialize the Parser: 27 | 28 | ```python 29 | task_type = TaskType.TEXT_GENERATION # Choose the TaskType that fits your task. 30 | prompt_spec = PromptBasedInstructionParser(task_type) 31 | ``` 32 | 33 | ### Parse the Prompt 34 | 35 | ```python 36 | prompt = "" 37 | prompt_spec.parse_from_prompt(prompt) 38 | ``` 39 | 40 | ### Access Parsed Fields 41 | 42 | ```python 43 | instruction = prompt_spec.instruction # Retrieves parsed instruction. 44 | demonstrations = prompt_spec.examples # Retrieves parsed examples. 45 | ``` 46 | 47 | ### Mock 48 | 49 | If you want to mock a `PromptSpec` object without parsing from a prompt, 50 | you can use the `MockPromptSpec` class. 51 | 52 | ```python 53 | prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) 54 | instruction = """...""" # A string indicating the task description. 55 | examples = """...""" # A string indicating the examples. 56 | prompt_spec._instruction = instruction 57 | prompt_spec._examples = examples 58 | ``` 59 | -------------------------------------------------------------------------------- /prompt2model/run_locally.py: -------------------------------------------------------------------------------- 1 | """A script to run the prompt2model pipeline locally.""" 2 | from __future__ import annotations 3 | 4 | import argparse 5 | 6 | from prompt2model.dataset_generator import DatasetSplit, MockDatasetGenerator 7 | from prompt2model.dataset_processor import MockProcessor 8 | from prompt2model.dataset_retriever import MockRetriever 9 | from prompt2model.demo_creator import mock_gradio_create 10 | from prompt2model.model_evaluator import MockEvaluator 11 | from prompt2model.model_executor import MockModelExecutor 12 | from prompt2model.model_retriever import MockModelRetriever 13 | from prompt2model.model_trainer import MockTrainer 14 | from prompt2model.param_selector import MockParamSelector 15 | from prompt2model.prompt_parser import MockPromptSpec, PromptSpec, TaskType 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--prompt", 20 | type=str, 21 | nargs="+", 22 | required=True, 23 | help="Prompt (with optional few-shot examples) for language model", 24 | ) 25 | parser.add_argument( 26 | "--metrics-output-path", 27 | type=str, 28 | help="Path to JSON file where we store model metrics", 29 | default="/tmp/metrics.json", 30 | ) 31 | 32 | 33 | def process_input_prompt(prompt_tokens: list[str]) -> PromptSpec: 34 | """Preprocess the input prompt given by the user and parse.
35 | 36 | Args: 37 | prompt_tokens: Tokens in the prompt. 38 | 39 | Returns: 40 | A PromptSpec parsed from the processed prompt tokens. 41 | 42 | """ 43 | prompt_str = " ".join(prompt_tokens).strip() 44 | start_quotations_present = False 45 | end_quotations_present = False 46 | quotation_marks = ['"', "“", "‟", "”"] 47 | for start_quote in quotation_marks: 48 | if prompt_str.startswith(start_quote): 49 | start_quotations_present = True 50 | break 51 | for end_quote in quotation_marks: 52 | if prompt_str.endswith(end_quote): 53 | end_quotations_present = True 54 | break 55 | if start_quotations_present and end_quotations_present: 56 | prompt_str = prompt_str[1:-1] 57 | 58 | prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) 59 | prompt_spec.parse_from_prompt(prompt_str) 60 | return prompt_spec 61 | 62 | 63 | def run_skeleton(prompt_tokens: list[str], metrics_output_path: str) -> None: 64 | """Run the prompt2model pipeline locally using base/stub components.""" 65 | prompt_spec = process_input_prompt(prompt_tokens) 66 | 67 | # Retrieve and generate datasets 68 | retriever = MockRetriever() 69 | retrieved_dataset_dicts = retriever.retrieve_dataset_dict(prompt_spec) 70 | 71 | generator = MockDatasetGenerator() 72 | expected_num_examples = { 73 | DatasetSplit.TRAIN: 40, 74 | DatasetSplit.VAL: 5, 75 | DatasetSplit.TEST: 5, 76 | } 77 | generated_dataset_dicts = generator.generate_dataset_dict( 78 | prompt_spec, expected_num_examples 79 | ) 80 | 81 | processor = MockProcessor(has_encoder=True, eos_token="") 82 | retrieved_dataset_dicts, generated_dataset_dicts = processor.process_dataset_dict( 83 | instruction="", dataset_dicts=[retrieved_dataset_dicts, generated_dataset_dicts] 84 | ) 85 | 86 | retrieved_training = [ 87 | dataset_dict["train"] for dataset_dict in retrieved_dataset_dicts 88 | ] 89 | 90 | generated_training = generated_dataset_dicts[DatasetSplit.TRAIN.value] 91 | validation = generated_dataset_dicts[DatasetSplit.VAL.value] 92 | testing = generated_dataset_dicts[DatasetSplit.TEST.value] 93 | all_training = retrieved_training + [generated_training] 94 | 95 | model_retriever = MockModelRetriever("cardiffnlp/twitter-roberta-base-sentiment") 96 | retrieved_model_name = model_retriever.retrieve(prompt_spec) 97 | 98 | trainer = MockTrainer(retrieved_model_name[0]) 99 | selector = MockParamSelector(trainer) 100 | model, tokenizer = selector.select_from_hyperparameters( 101 | all_training, validation, {} 102 | ) 103 | 104 | model_executor = MockModelExecutor(model, tokenizer) 105 | predictions = model_executor.make_prediction(testing, "input_col") 106 | 107 | evaluator = MockEvaluator() 108 | metrics_dict = evaluator.evaluate_model( 109 | testing, "output_col", predictions, "input_col", [] 110 | ) 111 | evaluator.write_metrics(metrics_dict, metrics_output_path) 112 | mock_gradio_create(model_executor, prompt_spec) 113 | 114 | 115 | if __name__ == "__main__": 116 | args = parser.parse_args() 117 | run_skeleton(args.prompt, args.metrics_output_path) 118 | -------------------------------------------------------------------------------- /prompt2model/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Import utility functions.""" 2 | from prompt2model.utils.api_tools import ( 3 | API_ERRORS, 4 | APIAgent, 5 | count_tokens_from_string, 6 | handle_api_error, 7 | ) 8 | from prompt2model.utils.logging_utils import get_formatted_logger 9 | from prompt2model.utils.rng import seed_generator 10 | from prompt2model.utils.tevatron_utils import 
encode_text, retrieve_objects 11 | 12 | __all__ = ( # noqa: F401 13 | "APIAgent", 14 | "encode_text", 15 | "handle_api_error", 16 | "API_ERRORS", 17 | "retrieve_objects", 18 | "seed_generator", 19 | "count_tokens_from_string", 20 | "get_formatted_logger", 21 | ) 22 | -------------------------------------------------------------------------------- /prompt2model/utils/config.py: -------------------------------------------------------------------------------- 1 | """Place to store all the default configurations.""" 2 | MAX_SUPPORTED_BATCH_SIZE = 4 3 | 4 | DEFAULT_HYPERPARAMETERS_SPACE = { 5 | "min_num_train_epochs": 5, 6 | "max_num_train_epochs": 15, 7 | "save_strategy": ["no"], 8 | "evaluation_strategy": ["no"], 9 | "per_device_train_batch_size": MAX_SUPPORTED_BATCH_SIZE, 10 | "min_weight_decay": 4e-5, 11 | "max_weight_decay": 1e-1, 12 | "min_learning_rate": 4e-5, 13 | "max_learning_rate": 1e-1, 14 | } 15 | -------------------------------------------------------------------------------- /prompt2model/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | """Util functions for datasets.""" 2 | 3 | 4 | import datasets 5 | import requests 6 | 7 | from prompt2model.utils.logging_utils import get_formatted_logger 8 | 9 | logger = get_formatted_logger("dataset_utils") 10 | 11 | 12 | def query(API_URL): 13 | """Returns a response json for a URL.""" 14 | try: 15 | response = requests.get(API_URL) 16 | if response.status_code == 200: 17 | return response.json() 18 | else: 19 | logger.error(f"Error occurred in fetching size: {response.status_code}") 20 | except requests.exceptions.RequestException as e: 21 | logger.error("Error occurred in making the request: " + str(e)) 22 | 23 | return {} 24 | 25 | 26 | def get_dataset_size(dataset_name): 27 | """Fetches dataset size for a dataset in MB from the Hugging Face API.""" 28 | API_URL = f"https://datasets-server.huggingface.co/size?dataset={dataset_name}" 29 | data = query(API_URL) 30 | size_dict = data.get("size", {}) 31 | return ( 32 | "NA" 33 | if size_dict == {} 34 | else "{:.2f}".format(size_dict["dataset"]["num_bytes_memory"] / 1024 / 1024) 35 | ) 36 | 37 | 38 | def make_combined_datasets( 39 | dataset_list: list[datasets.Dataset], dataset_type: str = "input_output" 40 | ) -> datasets.Dataset: 41 | """Combine multiple datasets into one. 42 | 43 | Args: 44 | dataset_list: List of datasets to combine. 45 | dataset_type: Type of dataset to combine. Can be "text" or "input_output". 46 | "text" is for combining datasets with a single column "text". 47 | "input_output" is for combining datasets with 2 columns "input_col" 48 | and "output_col". 49 | 50 | Returns: 51 | A combined dataset. 52 | Single column "text" if dataset_type is "text". 53 | Two columns "input_col" and "output_col" if dataset_type is "input_output". 54 | ValueError if dataset_type is not "text" or "input_output".
55 | """ 56 | if dataset_type == "text": 57 | text_col = [] 58 | for dataset in dataset_list: 59 | text_col.extend(dataset["text"]) 60 | return datasets.Dataset.from_dict({"text": text_col}) 61 | elif dataset_type == "input_output": 62 | input_col = [] 63 | output_col = [] 64 | for dataset in dataset_list: 65 | input_col.extend(dataset["input_col"]) 66 | output_col.extend(dataset["output_col"]) 67 | 68 | dataset = datasets.Dataset.from_dict( 69 | {"input_col": input_col, "output_col": output_col} 70 | ) 71 | return dataset 72 | else: 73 | raise ValueError( 74 | f"dataset_type can be either 'text' or 'input_output' but got {dataset_type}" # noqa E501 75 | ) 76 | 77 | 78 | def format_train_data(train_dataset: datasets.Dataset): 79 | """Formats the train dataset for training.""" 80 | final_texts = [] 81 | for row in train_dataset: 82 | final_texts.append(f"{row['input_col'].strip()} {row['output_col'].strip()}") 83 | return datasets.Dataset.from_dict({"text": final_texts}) 84 | -------------------------------------------------------------------------------- /prompt2model/utils/dataset_utils_test.py: -------------------------------------------------------------------------------- 1 | """Testing dataset utility functions.""" 2 | from unittest.mock import patch 3 | 4 | from prompt2model.utils import dataset_utils 5 | 6 | 7 | @patch("prompt2model.utils.dataset_utils.query") 8 | def test_get_dataset_size(mock_request): 9 | """Test function for get_dataset_size.""" 10 | mock_request.return_value = { 11 | "size": { 12 | "dataset": { 13 | "dataset": "rotten_tomatoes", 14 | "num_bytes_original_files": 487770, 15 | "num_bytes_parquet_files": 881052, 16 | "num_bytes_memory": 1345449, 17 | "num_rows": 10662, 18 | }, 19 | "configs": [ 20 | { 21 | "dataset": "rotten_tomatoes", 22 | "config": "default", 23 | "num_bytes_original_files": 487770, 24 | "num_bytes_parquet_files": 881052, 25 | "num_bytes_memory": 1345449, 26 | "num_rows": 10662, 27 | "num_columns": 2, 28 | } 29 | ], 30 | "splits": [ 31 | { 32 | "dataset": "rotten_tomatoes", 33 | "config": "default", 34 | "split": "train", 35 | "num_bytes_parquet_files": 698845, 36 | "num_bytes_memory": 1074806, 37 | "num_rows": 8530, 38 | "num_columns": 2, 39 | }, 40 | { 41 | "dataset": "rotten_tomatoes", 42 | "config": "default", 43 | "split": "validation", 44 | "num_bytes_parquet_files": 90001, 45 | "num_bytes_memory": 134675, 46 | "num_rows": 1066, 47 | "num_columns": 2, 48 | }, 49 | { 50 | "dataset": "rotten_tomatoes", 51 | "config": "default", 52 | "split": "test", 53 | "num_bytes_parquet_files": 92206, 54 | "num_bytes_memory": 135968, 55 | "num_rows": 1066, 56 | "num_columns": 2, 57 | }, 58 | ], 59 | }, 60 | "pending": [], 61 | "failed": [], 62 | "partial": False, 63 | } 64 | assert dataset_utils.get_dataset_size("rotten_tomatoes") == "1.28" 65 | -------------------------------------------------------------------------------- /prompt2model/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for creating formatted logger.""" 2 | 3 | import logging 4 | 5 | 6 | def get_formatted_logger(logger_name: str): 7 | """Create a formatted logger. 8 | 9 | Args: 10 | logger_name: The name of the logger, usually the name 11 | of the component that uses the logger. 12 | 13 | Returns: 14 | A logger object. 15 | """ 16 | logger = logging.getLogger(logger_name) 17 | # Check if the logger already has a StreamHandler to prevent adding another one. 
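    # logging.getLogger returns the same logger instance for a given name, so adding a handler on every call would duplicate each log line.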
18 | if not any( 19 | isinstance(handler, logging.StreamHandler) for handler in logger.handlers 20 | ): 21 | ch = logging.StreamHandler() 22 | formatter = logging.Formatter( 23 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 24 | ) 25 | ch.setFormatter(formatter) 26 | logger.addHandler(ch) 27 | return logger 28 | -------------------------------------------------------------------------------- /prompt2model/utils/parse_responses.py: -------------------------------------------------------------------------------- 1 | """Utility file for parsing OpenAI json responses.""" 2 | from __future__ import annotations 3 | 4 | import json 5 | import re 6 | from typing import Any 7 | 8 | import openai 9 | 10 | from prompt2model.utils import api_tools, get_formatted_logger 11 | from prompt2model.utils.api_tools import API_ERRORS, handle_api_error 12 | 13 | logger = get_formatted_logger("ParseJsonResponses") 14 | 15 | 16 | def find_and_parse_json( 17 | response: openai.Completion, required_keys: list, optional_keys: list = [] 18 | ) -> dict | None: 19 | """Parse structured fields from the API response. 20 | 21 | In case there are multiple JSON objects in the response, take the final one. 22 | 23 | Args: 24 | response: API response. 25 | required_keys: Required keys from the response 26 | optional_keys: Optional keys from the response 27 | 28 | Returns: 29 | If the API response is a valid JSON object and contains the 30 | required and optional keys, then returns the 31 | final response as a dictionary. 32 | Otherwise returns None. 33 | """ 34 | if not isinstance(response, str) and hasattr(response, "choices"): 35 | response = response.choices[0]["message"]["content"] 36 | correct_json = find_rightmost_brackets(response) 37 | 38 | if correct_json is None: 39 | logger.warning("No valid JSON found in the response.") 40 | return None 41 | 42 | try: 43 | response_json = json.loads(correct_json, strict=False) 44 | except json.decoder.JSONDecodeError: 45 | logger.warning(f"API response was not a valid JSON: {correct_json}") 46 | return None 47 | 48 | missing_keys = [key for key in required_keys if key not in response_json] 49 | if len(missing_keys) != 0: 50 | logger.warning(f'API response must contain {", ".join(required_keys)} keys') 51 | return None 52 | 53 | final_response = {} 54 | for key in required_keys + optional_keys: 55 | if key not in response_json: 56 | # This is an optional key, so exclude it from the final response. 57 | continue 58 | if type(response_json[key]) == str: 59 | final_response[key] = response_json[key].strip() 60 | else: 61 | final_response[key] = response_json[key] 62 | return final_response 63 | 64 | 65 | def find_rightmost_brackets(text: str) -> str | None: 66 | """Find the rightmost complete set of brackets in a string.""" 67 | stack = [] 68 | for i, char in enumerate(reversed(text)): 69 | if char == "}": 70 | stack.append(len(text) - i - 1) 71 | elif char == "{" and stack: 72 | start = len(text) - i - 1 73 | end = stack.pop() 74 | if not stack: # Found the rightmost complete set 75 | return text[start : end + 1] 76 | return None 77 | 78 | 79 | def parse_dataset_config_responses(response: openai.ChatCompletion) -> dict: 80 | """Parse the response to extract relevant information from dataset/configuration. 81 | 82 | LLMs can return the dataset configuration in different formats - 83 | usually either between ** ** or as a sentence. 84 | 85 | Args: 86 | response: The response containing the dataset configuration.
87 | 88 | Returns: 89 | The extracted relevant information from the dataset configuration. 90 | """ 91 | if not isinstance(response, str) and hasattr(response, "choices"): 92 | response_str = response.choices[0]["message"]["content"] 93 | else: 94 | response_str = response 95 | 96 | pattern = r"\*\*(.*?)\*\*" 97 | 98 | match = re.search(pattern, response_str) 99 | dataset_config = "" 100 | if match: 101 | dataset_config = match.group(1) 102 | elif len(response_str.split()) >= 1: 103 | dataset_config = response_str.split()[-1].replace(".", "") 104 | return {"name": dataset_config} 105 | 106 | 107 | def parse_prompt_to_fields( 108 | prompt: str, 109 | required_keys: list = [], 110 | optional_keys: list = [], 111 | max_api_calls: int = 5, 112 | module_name: str = "col_selection", 113 | ) -> dict[str, Any]: 114 | """Parse prompt into specific fields, and return to the calling function. 115 | 116 | This function calls the required api, has the logic for the retrying, 117 | passes the response to the parsing function, and return the 118 | response back or throws an error 119 | 120 | Args: 121 | prompt: User prompt into specific fields 122 | required_keys: Fields that need to be present in the response 123 | optional_keys: Field that may/may not be present in the response 124 | max_api_calls: Max number of retries, defaults to 5 to avoid 125 | being stuck in an infinite loop 126 | module_name: The module this is to be used for. Currently supports 127 | rerank and col_selection 128 | 129 | Returns: 130 | Parsed Response as a dictionary. 131 | 132 | Raises: 133 | ValueError: If max_api_calls is not greater than 0. 134 | RuntimeError: If the maximum number of API calls is reached. 135 | 136 | """ 137 | chat_api = api_tools.default_api_agent 138 | if max_api_calls <= 0: 139 | raise ValueError("max_api_calls must be > 0.") 140 | 141 | api_call_counter = 0 142 | last_error = None 143 | while True: 144 | api_call_counter += 1 145 | try: 146 | response: openai.ChatCompletion | Exception = ( 147 | chat_api.generate_one_completion( 148 | prompt, 149 | temperature=0.01, 150 | presence_penalty=0, 151 | frequency_penalty=0, 152 | ) 153 | ) 154 | extraction: dict[str, Any] | None = None 155 | if module_name == "col_selection": 156 | extraction = find_and_parse_json(response, required_keys, optional_keys) 157 | 158 | elif module_name == "rerank": 159 | extraction = parse_dataset_config_responses(response) 160 | if extraction is not None: 161 | return extraction 162 | except API_ERRORS as e: 163 | last_error = e 164 | handle_api_error(e, backoff_duration=2**api_call_counter) 165 | 166 | if api_call_counter >= max_api_calls: 167 | # In case we reach maximum number of API calls, we raise an error. 168 | logger.error("Maximum number of API calls reached.") 169 | raise RuntimeError("Maximum number of API calls reached.") from last_error 170 | 171 | 172 | def make_single_api_request(prompt: str, max_api_calls: int = 10) -> str: 173 | """Prompts an LLM using the APIAgent, and returns the response. 
174 | 175 | This function calls the required API, has the logic for retrying, 176 | and returns the response or raises an error. 177 | Args: 178 | prompt: User prompt into specific fields 179 | max_api_calls: Max number of retries, defaults to 10 to avoid 180 | being stuck in an infinite loop 181 | Returns: 182 | Response text, or raises an error 183 | """ 184 | chat_api = api_tools.default_api_agent 185 | if max_api_calls <= 0: 186 | raise ValueError("max_api_calls must be > 0.") 187 | 188 | api_call_counter = 0 189 | last_error = None 190 | while True: 191 | api_call_counter += 1 192 | try: 193 | response: openai.ChatCompletion = chat_api.generate_one_completion( 194 | prompt=prompt, temperature=0.01, presence_penalty=0, frequency_penalty=0 195 | ) 196 | if response is not None: 197 | return response.choices[0]["message"]["content"] 198 | 199 | except API_ERRORS as e: 200 | last_error = e 201 | handle_api_error(e, backoff_duration=2**api_call_counter) 202 | 203 | if api_call_counter >= max_api_calls: 204 | # In case we reach maximum number of API calls, we raise an error. 205 | logger.error("Maximum number of API calls reached.") 206 | raise RuntimeError("Maximum number of API calls reached.") from last_error 207 | -------------------------------------------------------------------------------- /prompt2model/utils/retrieve_model_info.py: -------------------------------------------------------------------------------- 1 | """Retrieve HuggingFace model's size, description, and downloads.""" 2 | import json 3 | import os 4 | import re 5 | import subprocess 6 | from pathlib import Path 7 | 8 | from huggingface_hub import HfApi 9 | 10 | 11 | def main(pretrained_model_name: str, cache_dir: str = None) -> None: 12 | """Downloads and caches a Hugging Face model's metadata. 13 | 14 | Args: 15 | pretrained_model_name: HuggingFace pretrained_model_name. 16 | cache_dir: A directory to cache the metadata.
17 | """ 18 | if cache_dir is None: 19 | cache_dir = "model_info" 20 | cache_path = Path.cwd() / cache_dir 21 | cache_path.mkdir(parents=True, exist_ok=True) 22 | 23 | if len(pretrained_model_name.split("/")) == 2: 24 | _, model_name = pretrained_model_name.split("/") 25 | else: 26 | model_name = pretrained_model_name 27 | 28 | subprocess.run( 29 | ["git", "clone", f"https://huggingface.co/{pretrained_model_name}"], 30 | env=dict(os.environ, GIT_LFS_SKIP_SMUDGE="1"), 31 | stdout=subprocess.DEVNULL, 32 | stderr=subprocess.DEVNULL, 33 | ) 34 | 35 | model_bin_file = Path(f"{model_name}/pytorch_model.bin") 36 | 37 | try: 38 | if model_bin_file.exists(): 39 | with open(model_bin_file, "r") as file: 40 | content = file.read() 41 | size = re.search(r"size (\d+)", content).group(1) # type: ignore 42 | print(size) 43 | else: 44 | model_bin_file = Path(f"{model_name}/pytorch_model.bin.index.json") 45 | with open(model_bin_file, "r") as file: 46 | content = json.loads(file.read()) # type: ignore 47 | size = content["metadata"]["total_size"] # type: ignore 48 | print(size) 49 | except ( 50 | FileNotFoundError, 51 | PermissionError, 52 | IOError, 53 | json.decoder.JSONDecodeError, 54 | ) as e: 55 | raise Exception(f"Failed to read {model_name} in {model_bin_file}: {e}") 56 | 57 | with open(f"{model_name}/README.md", "r", encoding="utf-8") as f: 58 | readme_content = f.read() 59 | print(readme_content) 60 | 61 | api = HfApi() 62 | model_meta = api.model_info(pretrained_model_name) 63 | downloads = model_meta.downloads 64 | print(pretrained_model_name, downloads) 65 | 66 | model_info = { 67 | "pretrained_model_name": pretrained_model_name, 68 | "description": readme_content, 69 | "size_bytes": size, 70 | "downloads": downloads, 71 | } 72 | model_info_path = Path(f"{cache_dir}/{model_name}.json") 73 | model_info_path.touch() 74 | with open(model_info_path, "w") as file: 75 | file.write(json.dumps(model_info)) 76 | 77 | # The model must exist because it can be found by 78 | # `model_meta = api.model_info(pretrained_model_name)` 79 | subprocess.run(["rm", "-rf", model_name]) 80 | 81 | 82 | if __name__ == "__main__": 83 | main("facebook/roscoe-512-roberta-base") 84 | main("gpt2") 85 | -------------------------------------------------------------------------------- /prompt2model/utils/rng.py: -------------------------------------------------------------------------------- 1 | """Classes for setting random seeds.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | 8 | class SeedGenerator(ABC): 9 | """Select a good model from among a set of hyperparameter choices.""" 10 | 11 | @abstractmethod 12 | def get_seed(self) -> int: 13 | """Generate a random seed.""" 14 | 15 | 16 | class ConstantSeedGenerator(SeedGenerator): 17 | """A seed generator that always returns the same seed.""" 18 | 19 | def __init__(self, seed: int = 2023): 20 | """Initialize with a constant seed (by default, 2023).""" 21 | self.seed = seed 22 | 23 | def get_seed(self) -> int: 24 | """Return a constant random seed.""" 25 | return self.seed 26 | 27 | 28 | seed_generator = ConstantSeedGenerator() 29 | -------------------------------------------------------------------------------- /prompt2model/utils/tevatron_utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Import Tevatron utility functions.""" 2 | from prompt2model.utils.tevatron_utils.encode import encode_text 3 | from prompt2model.utils.tevatron_utils.retrieve import 
retrieve_objects 4 | 5 | __all__ = ("encode_text", "retrieve_objects") 6 | -------------------------------------------------------------------------------- /prompt2model/utils/tevatron_utils/encode.py: -------------------------------------------------------------------------------- 1 | """Tools for encoding and serializing a search index with a contextual encoder.""" 2 | 3 | from __future__ import annotations # noqa FI58 4 | 5 | import json 6 | import os 7 | import pickle 8 | import tempfile 9 | from contextlib import nullcontext 10 | 11 | import numpy as np 12 | import torch 13 | from tevatron.arguments import DataArguments 14 | from tevatron.data import EncodeCollator, EncodeDataset 15 | from tevatron.datasets import HFCorpusDataset, HFQueryDataset 16 | from tevatron.modeling import DenseModelForInference 17 | from torch.utils.data import DataLoader 18 | from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase 19 | 20 | 21 | def load_tevatron_model( 22 | model_name_or_path: str, model_cache_dir: str | None = None 23 | ) -> tuple[DenseModelForInference, PreTrainedTokenizerBase]: 24 | """Load a Tevatron model from a model name/path. 25 | 26 | Args: 27 | model_name_or_path: The HuggingFace model name or path to the model. 28 | model_cache_dir: The directory to cache the model. 29 | 30 | Returns: 31 | A Tevatron dense retrieval model and its associated tokenizer. 32 | """ 33 | config = AutoConfig.from_pretrained( 34 | model_name_or_path, 35 | cache_dir=model_cache_dir, 36 | ) 37 | tokenizer = AutoTokenizer.from_pretrained( 38 | model_name_or_path, 39 | cache_dir=model_cache_dir, 40 | use_fast=False, 41 | ) 42 | model = DenseModelForInference.build( 43 | model_name_or_path=model_name_or_path, 44 | config=config, 45 | cache_dir=model_cache_dir, 46 | ) 47 | return model, tokenizer 48 | 49 | 50 | def encode_text( 51 | model_name_or_path: str, 52 | file_to_encode: str | None = None, 53 | text_to_encode: list[str] | str | None = None, 54 | encode_query: bool = False, 55 | encoding_file: str | None = None, 56 | max_len: int = 400, 57 | device: torch.device = torch.device("cpu"), 58 | dataloader_num_workers: int = 0, 59 | model_cache_dir: str | None = None, 60 | data_cache_dir: str = "~/.cache/huggingface/datasets", 61 | batch_size=8, 62 | fp16: bool = False, 63 | ) -> np.ndarray: 64 | """Encode a query or documents. 65 | 66 | This code is mostly duplicated from tevatron/driver/encode.py in the Tevatron 67 | repository. 68 | 69 | Args: 70 | model_name_or_path: The HuggingFace model name or path to the model. 71 | file_to_encode: JSON or JSONL file containing `"text"` fields to encode. 72 | text_to_encode: String or list of strings to encode. 73 | encode_query: Whether or not we are encoding a query or documents. 74 | encoding_file: If given, store the encoded data in this file. 75 | max_len: Truncate the input to this length (in tokens). 76 | device: Device that Torch will use to encode the text. 77 | dataloader_num_workers: Number of workers to use for the dataloader. 78 | model_cache_dir: The directory to cache the model. 79 | data_cache_dir: The directory to cache the tokenized dataset. 80 | batch_size: Batch size to use for encoding. 81 | fp16: Whether or not to run inference in fp16 for more-efficient encoding. 82 | 83 | Returns: 84 | A numpy array of shape `(expected_num_examples, embedding_dim)` containing text 85 | encoded by the specified model. 
86 | """ 87 | model, tokenizer = load_tevatron_model(model_name_or_path, model_cache_dir) 88 | 89 | if file_to_encode is None and text_to_encode is None: 90 | raise ValueError("Must provide either a dataset file or text to encode.") 91 | elif file_to_encode is not None and text_to_encode is not None: 92 | raise ValueError("Provide either a dataset file or text to encode, not both.") 93 | 94 | with tempfile.TemporaryDirectory() as temp_dir: 95 | if text_to_encode is not None: 96 | if isinstance(text_to_encode, str): 97 | text_to_encode = [text_to_encode] 98 | with open( 99 | os.path.join(temp_dir, "text_to_encode.json"), "w" 100 | ) as temporary_file: 101 | text_rows = [ 102 | {"text_id": i, "text": text} 103 | for i, text in enumerate(text_to_encode) 104 | ] 105 | json.dump(text_rows, temporary_file) 106 | file_to_encode = temporary_file.name 107 | temporary_file.close() 108 | 109 | data_args = DataArguments( 110 | encoded_save_path=encoding_file, 111 | encode_in_path=file_to_encode, 112 | encode_is_qry=encode_query, 113 | data_cache_dir=data_cache_dir, 114 | ) 115 | if encode_query: 116 | data_args.q_max_len = max_len 117 | hf_dataset = HFQueryDataset( 118 | tokenizer=tokenizer, 119 | data_args=data_args, 120 | cache_dir=data_args.data_cache_dir or model_cache_dir, 121 | ) 122 | else: 123 | data_args.p_max_len = max_len 124 | hf_dataset = HFCorpusDataset( 125 | tokenizer=tokenizer, 126 | data_args=data_args, 127 | cache_dir=data_args.data_cache_dir or model_cache_dir, 128 | ) 129 | 130 | encode_dataset = EncodeDataset( 131 | hf_dataset.process(1, 0), tokenizer, max_len=max_len 132 | ) 133 | 134 | encode_loader = DataLoader( 135 | encode_dataset, 136 | batch_size=batch_size, 137 | collate_fn=EncodeCollator( 138 | tokenizer, max_length=max_len, padding="max_length" 139 | ), 140 | shuffle=False, 141 | drop_last=False, 142 | num_workers=dataloader_num_workers, 143 | ) 144 | encoded = [] 145 | lookup_indices = [] 146 | model = model.to(device) 147 | model.eval() 148 | 149 | for batch_ids, batch in encode_loader: 150 | lookup_indices.extend(batch_ids) 151 | with torch.cuda.amp.autocast() if fp16 else nullcontext(): 152 | with torch.no_grad(): 153 | for k, v in batch.items(): 154 | batch[k] = v.to(device) 155 | if data_args.encode_is_qry: 156 | model_output = model(query=batch) 157 | encoded.append(model_output.q_reps.cpu().detach().numpy()) 158 | else: 159 | model_output = model(passage=batch) 160 | encoded.append(model_output.p_reps.cpu().detach().numpy()) 161 | 162 | encoded = np.concatenate(encoded) 163 | 164 | if encoding_file: 165 | with open(encoding_file, "wb") as f: 166 | pickle.dump((encoded, lookup_indices), f) 167 | 168 | return encoded 169 | -------------------------------------------------------------------------------- /prompt2model/utils/tevatron_utils/retrieve.py: -------------------------------------------------------------------------------- 1 | """Tools for doing efficient similarity search via the Tevatron/faiss libraries.""" 2 | from __future__ import annotations 3 | 4 | import pickle 5 | 6 | import numpy as np 7 | from tevatron.faiss_retriever import BaseFaissIPRetriever 8 | 9 | 10 | def retrieve_objects( 11 | query_vector: np.ndarray, 12 | encoded_datasets_path: str, 13 | document_names: list[str], 14 | depth: int, 15 | ) -> list[tuple[str, float]]: 16 | """Return a ranked list of object indices and their scores. 17 | 18 | Args: 19 | query vector: Vector representation of query. 20 | encoded_datasets_path: Path to file containing encoded dataset index. 
21 | depth: Number of documents to return. 22 | 23 | Returns: 24 | Ranked list of object names and their inner product similarity to the query. 25 | """ 26 | if query_vector.shape[0] != 1: 27 | raise ValueError("Only a single query vector is expected.") 28 | if len(query_vector.shape) != 2: 29 | raise ValueError("Query vector must be a 2-D array of shape (1, dim).") 30 | 31 | with open(encoded_datasets_path, "rb") as f: 32 | passage_reps, passage_lookup = pickle.load(f) 33 | retriever = BaseFaissIPRetriever(passage_reps) 34 | retriever.add(passage_reps) 35 | 36 | all_scores, all_indices = retriever.search(query_vector, depth) 37 | if not (len(all_scores) == len(all_indices) == 1): 38 | raise ValueError("Only one query's ranking should be returned.") 39 | 40 | psg_scores = all_scores[0] 41 | ranked_document_names = [document_names[passage_lookup[x]] for x in all_indices[0]] 42 | score_tuples = list(zip(ranked_document_names, psg_scores)) 43 | return score_tuples 44 | -------------------------------------------------------------------------------- /prompt2model/version.py: -------------------------------------------------------------------------------- 1 | """A version template file. 2 | 3 | This doesn't actually specify the version; the version is specified 4 | through tags on the repository. 5 | """ 6 | 7 | VERSION = "0.0.0a0" 8 | -------------------------------------------------------------------------------- /prompt_examples.md: -------------------------------------------------------------------------------- 1 | # Tips and Examples to Write a Good Prompt 2 | 3 | ## How to Write a Good Prompt 4 | 5 | A good prompt can make the generated dataset 6 | follow exactly the format of the demonstrations. 7 | It contains an instruction and few-shot examples. 8 | 9 | The instruction should contain the following: 10 | 11 | 1. The exact format description for the input 12 | and output, i.e., a string, a dictionary, or whatever. 13 | 2. The exact contents of each part of the 14 | input and their relationship, described as precisely as you can. 15 | 3. The range of possible inputs. For example, 16 | "And the question can range from Math, Cultural, 17 | Social, Geometry, Biology, History, Sports, Technology, 18 | Science, and so on." 19 | 20 | The few-shot examples should contain the following: 21 | 22 | 1. Use `=` rather than other ambiguous symbols like `:`. 23 | 2. Avoid unnecessary line breaks at the beginning. 24 | For example, `input=""` is better than breaking 25 | the line after `=`. 26 | 3. Use `input` rather than `Input`; likewise, prefer 27 | `output` over `Output`. 28 | 4. Wrap the `input` and `output` values in quotation marks (`“”`). 29 | 30 | Though the examples are optional, we strongly 31 | suggest including them to guide the format and 32 | content for the generator. 33 | 34 | Also, we recommend providing several precise examples 35 | in the specified format and inquiring with ChatGPT 36 | about the format and scope of your examples. 37 | 38 | ## Examples of Good Prompts 39 | 40 | Here are some examples of good prompts: 41 | 42 | ### Question Answering 43 | 44 | ```text 45 | """Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.
46 | 47 | Here are examples with input questions and context passages, along with their expected outputs: 48 | 49 | input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50." 50 | output="Santa Clara" 51 | 52 | input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)." 53 | output="Vistula River" 54 | 55 | input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries." 56 | output="Europe" 57 | """ 58 | ``` 59 | 60 | ### Temporal Expression Normalization 61 | 62 | ```text 63 | """Temporal date expressions are commonly used to refer to specific time periods. Your task is to identify these temporal date expressions and provide the exact dates they refer to. 64 | 65 | For this task, the input is a string containing two specific elements: a posted date in the format "[Posted: YYYY-MM-DD]" and a sentence or statement that contains various temporal date references (e.g., early December, the end of the year, today, August, last Christmas, next Month, etc). 66 | 67 | Your program should output a string that maps the time period references mentioned in the input to their corresponding dates, following these strict rules: 68 | 69 | 1. If temporal date references are found, the output should use either "YYYY-MM-DD", "YYYY-MM", or "YYYY" to represent the exact date. 70 | - If multiple time period references are found, separate them using '|'. 71 | 2. If no temporal date reference is found or the referred date is ambiguous, the output should just be 'N/A', i.e., output="N/A". 
72 | 73 | Here are some examples: 74 | 75 | input="[Posted: 1998-09-07] Tourism industry revenues reportedly dropped to $300 million last year, down from $450 million the year before." 76 | output="last year == 1997" 77 | 78 | input="[Posted: 2013-09-27] Eat! @mepangilinan" 79 | output="N/A" 80 | 81 | input="[Posted: 1989-10-30] Rated single-B-1 by Moody's Investors Service Inc. and single-B-plus by Standard amp Poor's Corp., the issue will be sold through underwriters led by Goldman, Sachs amp Co. Hertz Corp. -- $100 million of senior notes due Nov. 1, 2009, priced at par to yield 9%." 82 | output="Nov. 1, 2009 == 2009-11-01" 83 | 84 | input="[Posted: 2014-07-11] So out of place with this early transfer business." 85 | output="N/A" 86 | 87 | input="[Posted: 2013-10-25] Quote of the Day: '#Yoga is what you learn on your way down!" 88 | output="the Day == 2013-10-25" 89 | 90 | input="[Posted: 2021-06-15] Google plans to develop PALM 2 model in the first quarter of next year." 91 | output="N/A" 92 | 93 | input="[Posted: 2013-03-22] We will release a new github repository in the next three months." 94 | output="the next three month == 2013-04" 95 | 96 | input="[Posted: 2022-05-17] The company's fiscal year starts on July 1st and ends on June 30th." 97 | output="July 1st == 2022-07-01 | June 30th == 2022-06-30" 98 | 99 | input="[Posted: 2013-03-22] This flu season started in early December, a month earlier than usual, and peaked by the end of year." 100 | output="N/A" 101 | 102 | input="[Posted: 1989-10-30] The issue, which is puttable back to the company in 1999, was priced at a spread of 110 basis points above the Treasury's 10-year note." 103 | output="1999 == 1999" 104 | 105 | input="[Posted: 2022-04-15] The company announced that they will release their new product at the end of next month." 106 | output="the end of next month == 2022-05-31" 107 | 108 | input="[Posted: 2022-03-15] The teacher is going to release a new assignment in a few days." 109 | output="N/A" 110 | """ 111 | ``` 112 | 113 | ### Japanese-to-Python Generation 114 | 115 | ```text 116 | """Pythonで1行のコードを生成し、StackOverflowの日本語の質問を解決してください。コメントや式は含めないでください。インポート文も不要です。 117 | 118 | このタスクでは、入力は日本語のテキストで、変数名や操作が記述されています。出力は、そのタスクを達成するためのPythonの1行のコードです。コメントや式は含めないでください。インポート文も不要です。 119 | 120 | input="スペースで区切られた入力`stdin`を変数に格納して表示する" 121 | output="for line in stdin: a = line.rstrip().split(' ') print(a)" 122 | 123 | input="リスト`word_list'内に出現する単語を数える" 124 | output="Counter(word_list)" 125 | 126 | input="tweepyインスタンス`api`を使い、文字列`word`を含んだツイートを検索し、結果をリストとして得る" 127 | output="search = api.search(q=word)" 128 | 129 | input="データベースの設定を表示する" 130 | output="print(settings.DATABASES)" 131 | 132 | input="ネストされているリスト`li`を見やすく表示する" 133 | output="pprint.pprint(li)" 134 | 135 | input="HTMLファイル'test.html'を開き、テキストオブジェクト'text'をutf-8で保存する" 136 | output="f = open('test.html', 'w') f.write(text.encode('utf-8'))" 137 | """ 138 | ``` 139 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "hatchling", 4 | ] 5 | build-backend = "hatchling.build" 6 | 7 | [project] 8 | name = "prompt2model" 9 | authors = [ 10 | {name = "Vijay Viswanathan", email = "vijayv@andrew.cmu.edu"}, 11 | {name = "Chenyang Zhao", email = "zhaochen20@mails.tsinghua.edu.cn"}, 12 | ] 13 | description = "A library for distilling models from prompts." 
14 | readme = "README.md" 15 | repository = "https://github.com/neulab/prompt2model" 16 | requires-python = ">=3.9" 17 | license = {file = "LICENSE"} 18 | classifiers = [ 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | ] 23 | dependencies = [ 24 | "transformers", 25 | "datasets", 26 | "pandas", 27 | "fastapi", 28 | "gradio==3.38.0", 29 | "torch", 30 | "pytest", 31 | "openai", 32 | "sentencepiece", 33 | "bert_score", 34 | "sacrebleu", 35 | "evaluate", 36 | "tevatron", 37 | "faiss-cpu", 38 | "mdtex2html", 39 | "scikit-learn", 40 | "retriv", 41 | "tiktoken", 42 | "aiolimiter", 43 | "pyfiglet", 44 | "termcolor", 45 | "psutil", 46 | "protobuf", 47 | "nest-asyncio", 48 | "litellm", 49 | "peft" 50 | ] 51 | 52 | dynamic = ["version"] 53 | 54 | [project.optional-dependencies] 55 | test = [ 56 | "pytest>=6.0.0", 57 | ] 58 | dev = [ 59 | "pytest", 60 | "pre-commit" 61 | ] 62 | 63 | [tool.hatch.build] 64 | include = [ 65 | "*.py", 66 | ] 67 | exclude = [ 68 | "*_test.py", 69 | "test_*.py", 70 | ] 71 | only-packages = true 72 | 73 | [tool.hatch.build.targets.wheel] 74 | packages = ["prompt2model"] 75 | 76 | [tool.hatch.version] 77 | path = "prompt2model/version.py" 78 | 79 | [tool.pytest.ini_options] 80 | pythonpath = ["prompt2model/test_helpers"] 81 | testpaths = ["prompt2model/tests"] 82 | -------------------------------------------------------------------------------- /scripts/dataset_index/preprocessing.py: -------------------------------------------------------------------------------- 1 | """Filtering out datasets before we do heavy processing on them.""" 2 | import argparse 3 | import json 4 | import os 5 | from typing import Any 6 | 7 | from huggingface_hub import list_datasets 8 | 9 | 10 | def parse_arguments(): 11 | """Parse command line arguments for the script. 12 | 13 | Returns: 14 | argparse.Namespace: An object with the following attributes: 15 | - unprocessed_datasets_file (str): Filename for unprocessed datasets. 16 | - preprocessed_datasets_file (str): Filename for preprocessed datasets. 17 | - min_words_in_desc (int): Minimum words in a dataset description. 18 | - min_downloads (int): Minimum downloads for a dataset. 19 | 20 | """ 21 | parser = argparse.ArgumentParser(description="Process dataset files.") 22 | parser.add_argument( 23 | "--unprocessed_datasets_file", 24 | type=str, 25 | default="unprocessed.json", 26 | help="File name for unprocessed datasets", 27 | ) 28 | parser.add_argument( 29 | "--preprocessed_datasets_file", 30 | type=str, 31 | default="preprocessed_datasets.json", 32 | help="File name for preprocessed datasets", 33 | ) 34 | parser.add_argument( 35 | "--min_words_in_desc", 36 | type=int, 37 | default=4, 38 | help="Minimum number of words in description", 39 | ) 40 | parser.add_argument( 41 | "--min_downloads", type=int, default=10, help="Minimum number of downloads" 42 | ) 43 | 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def load_datasets(file_path: str) -> list[dict[str, Any]]: 49 | """Load all huggingface datasets from a JSON file. 50 | 51 | Check if the unfiltered dataset file exists, if not generate it. 52 | Generating it is helpful for multiple iterations of preprocessing. 53 | 54 | Args: 55 | file_path: File path of unfiltered datasets. 56 | 57 | Returns: 58 | List of datasets from HuggingFace. This doesn't have configs or rows. 
59 | """ 60 | if not os.path.exists(file_path): 61 | all_datasets = list(list_datasets()) 62 | with open(file_path, "w") as f: 63 | ds = json.dumps([ob.__dict__ for ob in all_datasets]) 64 | f.write(ds) 65 | return json.loads(ds) 66 | else: 67 | with open(file_path, "r") as file: 68 | return json.load(file) 69 | 70 | 71 | def filter_datasets( 72 | datasets: list[dict[str, Any]], min_downloads, min_words_in_desc 73 | ) -> list[dict[str, Any]]: 74 | """Filter datasets based on specific criteria. 75 | 76 | Filter out a dataset if its description is None, if the description has fewer 77 | than min_words_in_desc words, if its download count is below min_downloads, 78 | or if it is a duplicate (i.e., its description has already been seen). 79 | 80 | Args: 81 | datasets: List of datasets from HuggingFace. 82 | 83 | Returns: 84 | Datasets filtered based on the above criteria. 85 | 86 | """ 87 | filtered_datasets = [] 88 | descr_none = descr_small = downloads_less = common_descr = 0 89 | unique_descriptions: set[str] = set() 90 | 91 | for dataset_info in datasets: 92 | description = dataset_info.get("description") 93 | 94 | if not description: 95 | descr_none += 1 96 | continue 97 | if len(description.split()) < min_words_in_desc: 98 | descr_small += 1 99 | continue 100 | if dataset_info.get("downloads", 0) < min_downloads: 101 | downloads_less += 1 102 | continue 103 | if description in unique_descriptions: 104 | common_descr += 1 105 | continue 106 | 107 | filtered_datasets.append(dataset_info) 108 | unique_descriptions.add(description) 109 | 110 | print(f"{descr_none=}, {descr_small=}, {downloads_less=}, {common_descr=}") 111 | 112 | return filtered_datasets 113 | 114 | 115 | def main(args): 116 | """Main function to load and filter datasets.""" 117 | datasets = load_datasets(args.unprocessed_datasets_file) 118 | filtered_datasets = filter_datasets( 119 | datasets=datasets, 120 | min_downloads=args.min_downloads, 121 | min_words_in_desc=args.min_words_in_desc, 122 | ) 123 | with open(args.preprocessed_datasets_file, "w") as f: 124 | json.dump(filtered_datasets, f) 125 | 126 | 127 | if __name__ == "__main__": 128 | args = parse_arguments() 129 | main(args) 130 | -------------------------------------------------------------------------------- /test_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | """Import mock classes used in unit tests.""" 2 | from test_helpers.mock_api import ( 3 | MockBatchDifferentCompletions, 4 | MockCompletion, 5 | UnknownGpt3Exception, 6 | mock_batch_api_response_identical_completions, 7 | ) 8 | from test_helpers.mock_retrieval import create_test_search_index 9 | from test_helpers.model_and_tokenizer import ( 10 | create_gpt2_model_and_tokenizer, 11 | create_t5_model_and_tokenizer, 12 | ) 13 | 14 | __all__ = ( 15 | "MockCompletion", 16 | "UnknownGpt3Exception", 17 | "MockBatchDifferentCompletions", 18 | "create_gpt2_model_and_tokenizer", 19 | "create_t5_model_and_tokenizer", 20 | "create_test_search_index", 21 | "mock_batch_api_response_identical_completions", 22 | "are_dataset_dicts_identical", 23 | "are_datasets_identical", 24 | "MockBatchResponseDifferentCompletions", 25 | ) 26 | -------------------------------------------------------------------------------- /test_helpers/dataset_index_tiny.json: -------------------------------------------------------------------------------- 1 | {"squad": {"name": "squad", "description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a
set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.", "evaluation_metadata": [{"config": "plain_text", "task": "question-answering", "task_id": "extractive_question_answering", "splits": {"train_split": "train", "eval_split": "validation"}, "col_mapping": {"question": "question", "context": "context", "answers": {"text": "text", "answer_start": "answer_start"}}, "metrics": [{"type": "squad", "name": "SQuAD"}]}]}, "trivia_qa": {"name": "trivia_qa", "description": "TriviaqQA is a reading comprehension dataset containing over 650K\nquestion-answer-evidence triples. TriviaqQA includes 95K question-answer\npairs authored by trivia enthusiasts and independently gathered evidence\ndocuments, six per question on average, that provide high quality distant\nsupervision for answering the questions.", "evaluation_metadata": {}}, "search_qa": {"name": "search_qa", "description": "We publicly release a new large-scale dataset, called SearchQA, for machine comprehension, or question-answering. Unlike recently released datasets, such as DeepMind\nCNN/DailyMail and SQuAD, the proposed SearchQA was constructed to reflect a full pipeline of general question-answering. That is, we start not from an existing article\nand generate a question-answer pair, but start from an existing question-answer pair, crawled from J! Archive, and augment it with text snippets retrieved by Google.\nFollowing this approach, we built SearchQA, which consists of more than 140k question-answer pairs with each pair having 49.6 snippets on average. Each question-answer-context\n tuple of the SearchQA comes with additional meta-data such as the snippet's URL, which we believe will be valuable resources for future research. We conduct human evaluation\n as well as test two baseline methods, one simple word selection and the other deep learning based, on the SearchQA. We show that there is a meaningful gap between the human\n and machine performances. This suggests that the proposed dataset could well serve as a benchmark for question-answering.", "evaluation_metadata": {}}} -------------------------------------------------------------------------------- /test_helpers/mock_api.py: -------------------------------------------------------------------------------- 1 | """Tools for mocking API responses (for testing purposes).""" 2 | 3 | from __future__ import annotations 4 | 5 | import openai 6 | 7 | from prompt2model.utils.api_tools import APIAgent 8 | 9 | 10 | class MockCompletion: 11 | """Mock completion object.""" 12 | 13 | def __init__(self, content: str | None = None, responses_per_request: int = 1): 14 | """Initialize a new instance of `MockCompletion` class. 15 | 16 | Args: 17 | content: The mocked content to be returned, i.e., 18 | `json.dumps({"comment": "This is a great movie!", 19 | "label": 1})`. 20 | responses_per_request: Number of responses 21 | for each request. 22 | """ 23 | # We generate 5 identical responses for each API call by default. 24 | if content is not None: 25 | # Mock a ChatCompletion with identical responses. 26 | self.choices = [{"message": {"content": content}}] * responses_per_request 27 | else: 28 | # Mock a ChatCompletion with different responses. 29 | # Only used in mock_batch_api_response_with_different_completion. 30 | # The choice will be replaced later in the function. 31 | self.choices = [] 32 | 33 | def __repr__(self): 34 | """Return a string representation. 
35 | 36 | Returns: 37 | _string: A string representation of the object, including its choices. 38 | """ 39 | _string = f"<MockCompletion choices: {self.choices}>" 40 | return _string 41 | 42 | 43 | class MockBatchDifferentCompletions: 44 | """Mock batch completion object.""" 45 | 46 | def __init__(self, length: int = 4) -> None: 47 | """Init a new instance of `MockBatchDifferentCompletions`. 48 | 49 | Args: 50 | length: Length of the batch completions. 51 | 52 | This class is designed to simulate the response of APIAgent and test the 53 | generation process of the PromptBasedDatasetGenerator with 54 | `filter_duplicated_examples` set to True in 55 | `dataset_generator_with_filter_test`. 56 | 57 | The class works in conjunction with PromptBasedDatasetGenerator with 58 | batch_size = 2, responses_per_request = 3, expected_num_examples 59 | = 5, and filter_duplicated_examples = True. 60 | 61 | Explanation of the generation process: 62 | 63 | In the first API call, the generator produces 2 * 3 = 6 responses. 64 | After filtering duplicates, the generated_dataset will be: 65 | 66 | Dataset.from_dict({ 67 | "input_col": ["1", "2"], 68 | "output_col": ["a", "a"], 69 | }) 70 | 71 | The second API call reduces batch_size to 1 and generates 3 more 72 | responses. After filtering duplicates, the generated_dataset will be: 73 | 74 | Dataset.from_dict({ 75 | "input_col": ["1", "2", "3"], 76 | "output_col": ["a", "a", "a"], 77 | }) 78 | 79 | The third API call again uses batch_size = 1 and generates another 80 | 3 responses. After filtering duplicates, the generated_dataset will be: 81 | 82 | Dataset.from_dict({ 83 | "input_col": ["1", "2", "3"], 84 | "output_col": ["b", "a", "a"], 85 | }) 86 | 87 | The fourth API call also uses batch_size = 1 and generates the final 88 | 3 responses. After filtering duplicates, the generated_dataset will be: 89 | 90 | Dataset.from_dict({ 91 | "input_col": ["1", "2", "3", "4", "5"], 92 | "output_col": ["b", "a", "a", "c", "a"], 93 | }) 94 | 95 | The fifth API call is specifically designed for 96 | testing the generation of a dataset_dict.
97 | """ 98 | assert length == 4 or length == 5 99 | self.mock_completions: list[list[MockCompletion]] = [] 100 | self.current_index = 0 101 | mock_completion_1 = MockCompletion() 102 | mock_completion_1.choices = [ 103 | {"message": {"content": '{"input": "1", "output": "a"}'}}, 104 | {"message": {"content": '{"input": "1", "output": "b"}'}}, 105 | {"message": {"content": '{"input": "1", "output": "a"}'}}, 106 | ] 107 | mock_completion_2 = MockCompletion() 108 | mock_completion_2.choices = [ 109 | {"message": {"content": '{"input": "1", "output": "c"}'}}, 110 | {"message": {"content": '{"input": "2", "output": "a"}'}}, 111 | {"message": {"content": '{"input": "2", "output": "b"}'}}, 112 | ] 113 | self.mock_completions.append( 114 | [ 115 | mock_completion_1, 116 | mock_completion_2, 117 | ] 118 | ) 119 | mock_completion_3 = MockCompletion() 120 | mock_completion_3.choices = [ 121 | {"message": {"content": '{"input": "3", "output": "a"}'}}, 122 | {"message": {"content": '{"input": "3", "output": "a"}'}}, 123 | {"message": {"content": '{"input": "3", "output": "b"}'}}, 124 | ] 125 | self.mock_completions.append([mock_completion_3]) 126 | 127 | mock_completion_4 = MockCompletion() 128 | mock_completion_4.choices = [ 129 | {"message": {"content": '{"input": "1", "output": "b"}'}}, 130 | {"message": {"content": '{"input": "1", "output": "b"}'}}, 131 | {"message": {"content": '{"input": "1", "output": "b"}'}}, 132 | ] 133 | self.mock_completions.append([mock_completion_4]) 134 | mock_completion_5 = MockCompletion() 135 | mock_completion_5.choices = [ 136 | {"message": {"content": '{"input": "4", "output": "c"}'}}, 137 | {"message": {"content": '{"input": "4", "output": "c"}'}}, 138 | {"message": {"content": '{"input": "5", "output": "a"}'}}, 139 | ] 140 | self.mock_completions.append([mock_completion_5]) 141 | if length == 5: 142 | self.mock_completions.append( 143 | [ 144 | mock_completion_1, 145 | mock_completion_2, 146 | ] 147 | ) 148 | 149 | 150 | def mock_batch_api_response_identical_completions( 151 | prompts: list[str], 152 | content: str, 153 | temperature: float, 154 | presence_penalty: float = 0, 155 | frequency_penalty: float = 0, 156 | responses_per_request: int = 5, 157 | requests_per_minute: int = 80, 158 | ) -> list[MockCompletion]: 159 | """Generate a batch of mock completion objects. 160 | 161 | This function creates a batch of `MockCompletion` 162 | objects, each with a `content` attribute set to an LLM completion string. 163 | 164 | Args: 165 | prompts: A batch of mocked prompts that won't be used. 166 | content: The example string to be returned. 167 | temperature: A mocked temperature. 168 | presence_penalty: A mocked presence penalty. 169 | frequency_penalty: A mocked frequency penalty. 170 | responses_per_request: Number of responses for each request. 171 | requests_per_minute: Number of requests per minute to allow. 172 | 173 | Returns: 174 | A list of mock completion objects, each simulating a ChatCompletion API response.
175 | """ 176 | _ = prompts, temperature, presence_penalty, frequency_penalty, requests_per_minute 177 | mock_completions = [ 178 | MockCompletion(content=content, responses_per_request=responses_per_request) 179 | for _ in prompts 180 | ] 181 | return mock_completions 182 | 183 | 184 | class MockAPIAgent(APIAgent): 185 | """A mock API agent that always returns the same content.""" 186 | 187 | def __init__(self, default_content): 188 | """Initialize the API agent.""" 189 | self.generate_one_call_counter = 0 190 | self.generate_batch_call_counter = 0 191 | self.default_content = default_content 192 | 193 | def generate_one_completion( 194 | self, 195 | prompt: str, 196 | temperature: float = 0, 197 | presence_penalty: float = 0, 198 | frequency_penalty: float = 0, 199 | token_buffer: int = 300, 200 | ) -> openai.Completion: 201 | """Return a mocked object and increment the counter.""" 202 | self.generate_one_call_counter += 1 203 | return MockCompletion(content=self.default_content) 204 | 205 | async def generate_batch_completion( 206 | self, 207 | prompts: list[str], 208 | temperature: float = 1, 209 | responses_per_request: int = 5, 210 | requests_per_minute: int = 80, 211 | token_buffer: int = 300, 212 | ) -> list[openai.Completion]: 213 | """Return a list of mocked objects and increment the counter.""" 214 | self.generate_batch_call_counter += 1 215 | return [MockCompletion(content=self.default_content) for _ in prompts] 216 | 217 | 218 | class UnknownGpt3Exception(Exception): 219 | """This is a newly-defined exception for testing purposes.""" 220 | 221 | pass 222 | -------------------------------------------------------------------------------- /test_helpers/mock_retrieval.py: -------------------------------------------------------------------------------- 1 | """Tools for creating a mock search index.""" 2 | 3 | import pickle 4 | 5 | import numpy as np 6 | 7 | 8 | def create_test_search_index(index_file_name: str) -> None: 9 | """Utility function to create a test search index. 10 | 11 | This search index contains 3 models, each represented by a hand-written vector. 12 | Given a query of [0, 0, 1], the 3rd model will be the most similar. 13 | 14 | Args: 15 | index_file_name: The name of the file to which a pickled index will be written. 16 | """ 17 | mock_model_encodings = np.array([[0.9, 0, 0], [0, 0.9, 0], [0, 0, 0.9]]) 18 | mock_lookup_indices = [0, 1, 2] 19 | with open(index_file_name, "wb") as f: 20 | pickle.dump((mock_model_encodings, mock_lookup_indices), f) 21 | -------------------------------------------------------------------------------- /test_helpers/model_and_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tools for creating GPT-2 and T5 models with padding tokenizers.""" 2 | 3 | from collections import namedtuple 4 | 5 | from transformers import AutoModelForCausalLM, AutoTokenizer, T5ForConditionalGeneration 6 | 7 | ModelAndTokenizer = namedtuple("ModelAndTokenizer", ["model", "tokenizer"]) 8 | 9 | 10 | def create_gpt2_model_and_tokenizer(full_size: bool = False) -> ModelAndTokenizer: 11 | """Create a GPT2 model with its padding tokenizer for batched input. 12 | 13 | Args: 14 | full_size: Whether to use the full size of the GPT-2 model. Defaults to False. 15 | Note that the full-size GPT-2 model may occupy too much memory 16 | and lead to out-of-memory errors. 17 | 18 | Returns: 19 | gpt2_model_and_tokenizer: A namedtuple with gpt2 model and tokenizer.
20 | """ 21 | if not full_size: 22 | gpt2_model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2") 23 | gpt2_tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2") 24 | else: 25 | gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2") 26 | gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2") 27 | if gpt2_tokenizer.pad_token is None: 28 | gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token 29 | if gpt2_model.config.pad_token_id is None: 30 | gpt2_model.config.pad_token_id = gpt2_tokenizer.eos_token_id 31 | gpt2_model_and_tokenizer = ModelAndTokenizer(gpt2_model, gpt2_tokenizer) 32 | return gpt2_model_and_tokenizer 33 | 34 | 35 | def create_t5_model_and_tokenizer(full_size: bool = False) -> ModelAndTokenizer: 36 | """Create a T5 model with its padding tokenizer for batched input. 37 | 38 | Args: 39 | full_size: Whether to use the full size of the T5 model. Defaults to False. 40 | Note that the full-size T5 model may occupy too much memory 41 | and lead to out-of-memory errors. 42 | 43 | Returns: 44 | t5_model_and_tokenizer: A namedtuple with t5 model and tokenizer. 45 | """ 46 | if not full_size: 47 | t5_model = T5ForConditionalGeneration.from_pretrained( 48 | "google/t5-efficient-tiny" 49 | ) 50 | t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-efficient-tiny") 51 | else: 52 | t5_model = T5ForConditionalGeneration.from_pretrained("t5-small") 53 | t5_tokenizer = AutoTokenizer.from_pretrained("t5-small") 54 | t5_model_and_tokenizer = ModelAndTokenizer(t5_model, t5_tokenizer) 55 | return t5_model_and_tokenizer 56 | -------------------------------------------------------------------------------- /test_helpers/model_info_tiny/gpt2.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_model_name": "gpt2", 3 | "description": "unidirectional autoregressive language model", 4 | "size_bytes": "800000000", 5 | "downloads": 100, 6 | "_comment": "THIS FILE SHOULD ONLY BE USED FOR TESTING PURPOSES." 7 | } -------------------------------------------------------------------------------- /test_helpers/model_info_tiny/t5-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_model_name": "t5-base", 3 | "description": "text to text generator", 4 | "size_bytes": "800000000", 5 | "downloads": 10000, 6 | "_comment": "THIS FILE SHOULD ONLY BE USED FOR TESTING PURPOSES." 7 | } -------------------------------------------------------------------------------- /test_helpers/model_info_tiny/xlnet-base-cased.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_model_name": "xlnet-base-cased", 3 | "description": "bidirectional autoregressive language model", 4 | "size_bytes": "1600000000", 5 | "downloads": 100000, 6 | "_comment": "THIS FILE SHOULD ONLY BE USED FOR TESTING PURPOSES."
7 | } -------------------------------------------------------------------------------- /test_helpers/test_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for testing.""" 2 | from contextlib import contextmanager 3 | 4 | 5 | @contextmanager 6 | def temp_setattr(obj, attr, value): 7 | """Temporarily set an attribute on an object.""" 8 | original = getattr(obj, attr, None) 9 | setattr(obj, attr, value) 10 | try: 11 | yield 12 | finally: 13 | if original is not None: 14 | setattr(obj, attr, original) 15 | else: 16 | delattr(obj, attr) 17 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Init file for test scripts.""" 2 | -------------------------------------------------------------------------------- /tests/dataset_transformer_test.py: -------------------------------------------------------------------------------- 1 | """Tests for dataset_transformer.""" 2 | 3 | from functools import partial 4 | from unittest.mock import patch 5 | 6 | from datasets import Dataset, DatasetDict 7 | 8 | from prompt2model.dataset_transformer import PromptBasedDatasetTransformer 9 | from prompt2model.prompt_parser import MockPromptSpec, TaskType 10 | from test_helpers import MockCompletion, mock_batch_api_response_identical_completions 11 | 12 | TASK_EXPLANATION = MockCompletion( 13 | """This is a question answering task. The model is given a context and a question, and it must provide the answer.""" # noqa E501 14 | ) 15 | PLAN_RESPONSE = MockCompletion( 16 | """plan: 17 | 1. Combine "context" and "question" into "input". 18 | 2. Combine "answer" into "output".""" 19 | ) 20 | 21 | TRANSFORMED_DATA = partial( 22 | mock_batch_api_response_identical_completions, 23 | content='{"input": "context: Albany, the state capital, is the sixth-largest city in the State of New York.\nquestion: What is the capital of New York?", "output": "Albany"}', # noqa E501 24 | ) 25 | 26 | 27 | @patch( 28 | "prompt2model.utils.APIAgent.generate_batch_completion", 29 | side_effect=TRANSFORMED_DATA, 30 | ) 31 | @patch( 32 | "prompt2model.utils.APIAgent.generate_one_completion", 33 | side_effect=[TASK_EXPLANATION, PLAN_RESPONSE], 34 | ) 35 | def test_transform_data(mock_batch_completion, mock_one_completion): 36 | """Test transform_data.""" 37 | dataset_transformer = PromptBasedDatasetTransformer( 38 | num_points_to_transform=1, max_allowed_failed_transforms=0 39 | ) # noqa: E501 40 | prompt_spec = MockPromptSpec( 41 | TaskType.TEXT_GENERATION, 42 | instruction="instruction", 43 | examples="example1\nexample2", 44 | ) 45 | mock_dataset = Dataset.from_dict( 46 | { 47 | "question": ["What is the capital of New York?"], 48 | "context": [ 49 | "Albany, the state capital, is the sixth-largest city in the State of New York." # noqa E501 50 | ], 51 | "answer": ["Albany"], 52 | } 53 | ) 54 | # Create a mock DatasetDict consisting of the same example in each split. 55 | dataset = DatasetDict( 56 | {"train": mock_dataset, "val": mock_dataset, "test": mock_dataset} 57 | ) 58 | inputs, outputs = dataset_transformer.transform_data( 59 | prompt_spec=prompt_spec, 60 | dataset=dataset["train"], 61 | ) 62 | assert inputs == [ 63 | "context: Albany, the state capital, is the sixth-largest city in the State of New York.\nquestion: What is the capital of New York?" 
# noqa E501 64 | ] # noqa E501 65 | assert outputs == ["Albany"] 66 | -------------------------------------------------------------------------------- /tests/demo_creator_test.py: -------------------------------------------------------------------------------- 1 | """Test the create_gradio function with two configurations.""" 2 | 3 | import gc 4 | 5 | import gradio as gr 6 | 7 | from prompt2model.demo_creator import create_gradio 8 | from prompt2model.model_executor import GenerationModelExecutor 9 | from prompt2model.prompt_parser import MockPromptSpec, TaskType 10 | from test_helpers import create_gpt2_model_and_tokenizer, create_t5_model_and_tokenizer 11 | 12 | 13 | def test_create_gradio_with_gpt2(): 14 | """Test the `create_gradio` method with a GPT2 model.""" 15 | # Create a GPT-2 model and tokenizer. 16 | gpt2_model_and_tokenizer = create_gpt2_model_and_tokenizer() 17 | gpt2_model = gpt2_model_and_tokenizer.model 18 | gpt2_tokenizer = gpt2_model_and_tokenizer.tokenizer 19 | 20 | # Create a GenerationModelExecutor. 21 | gpt2_executor = GenerationModelExecutor( 22 | model=gpt2_model, 23 | tokenizer=gpt2_tokenizer, 24 | batch_size=1, 25 | ) 26 | 27 | # Create PromptBasedInstructionParser. 28 | gpt2_prompt_parser = MockPromptSpec(TaskType.TEXT_GENERATION) 29 | 30 | # Create Gradio interface. 31 | interface_gpt2 = create_gradio(gpt2_executor, gpt2_prompt_parser) 32 | 33 | # Perform assertions. 34 | assert isinstance(interface_gpt2, gr.Blocks) 35 | gc.collect() 36 | 37 | 38 | def test_create_gradio_with_t5(): 39 | """Test the `create_gradio` method with a T5 model.""" 40 | # Create T5 model and tokenizer 41 | t5_model_and_tokenizer = create_t5_model_and_tokenizer() 42 | t5_model = t5_model_and_tokenizer.model 43 | t5_tokenizer = t5_model_and_tokenizer.tokenizer 44 | 45 | # Create a GenerationModelExecutor. 46 | t5_executor = GenerationModelExecutor( 47 | model=t5_model, 48 | tokenizer=t5_tokenizer, 49 | batch_size=1, 50 | ) 51 | 52 | # Create PromptBasedInstructionParser. 53 | t5_prompt_parser = MockPromptSpec(task_type="generation") 54 | 55 | # Create Gradio interface. 56 | interface_t5 = create_gradio(t5_executor, t5_prompt_parser) 57 | 58 | # Perform assertions. 59 | assert isinstance(interface_t5, gr.Blocks) 60 | gc.collect() 61 | -------------------------------------------------------------------------------- /tests/param_selector_test.py: -------------------------------------------------------------------------------- 1 | """Testing hyperparameter optimization with different configurations.""" 2 | 3 | import gc 4 | import logging 5 | import os 6 | 7 | import datasets 8 | 9 | from prompt2model.model_trainer.generate import GenerationModelTrainer 10 | from prompt2model.param_selector.search_with_optuna import OptunaParamSelector 11 | 12 | os.environ["WANDB_MODE"] = "dryrun" 13 | logger = logging.getLogger("AutoHyperparamOptimization") 14 | 15 | 16 | def test_optimize_hyperparameters(): 17 | """Tests whether the hyperparameter optimization is working correctly or not.""" 18 | # Create a simple training dataset. 19 | training_datasets = [ 20 | datasets.Dataset.from_dict( 21 | { 22 | "model_input": [ 23 | "Given a product review, predict the sentiment score associated with it.\nExample:\nIt isn’t my fav lip balm, but it’s up there. 
It moisturises really well and the lemon isn’t strong or over powering.\nLabel:\n", # noqa E501 24 | ], 25 | "model_output": ["4"], 26 | } 27 | ), 28 | ] 29 | 30 | validation_datasets = [ 31 | datasets.Dataset.from_dict( 32 | { 33 | "model_input": [ 34 | "Given a product review, predict the sentiment score associated with it.\nExample:\nBroke me out and gave me awful texture all over my face. I typically have clear skin and after using this product my skin HATED it. Could work for you though.\nLabel:\n", # noqa E501 35 | ], 36 | "model_output": ["2"], 37 | } 38 | ), 39 | ] 40 | 41 | param_selector = OptunaParamSelector( 42 | n_trials=1, 43 | trainer=GenerationModelTrainer( 44 | "patrickvonplaten/t5-tiny-random", has_encoder=True 45 | ), 46 | ) 47 | best_hyperparameters = param_selector.optimize_hyperparameters( 48 | training_datasets=training_datasets, 49 | validation=validation_datasets, 50 | hyperparameters={ 51 | "min_num_train_epochs": 1, 52 | "max_num_train_epochs": 1, 53 | "save_strategy": ["epoch"], 54 | "evaluation_strategy": ["epoch"], 55 | "per_device_train_batch_size": [2], 56 | "min_weight_decay": 4e-5, 57 | "max_weight_decay": 1e-1, 58 | "min_learning_rate": 4e-5, 59 | "max_learning_rate": 1e-1, 60 | }, 61 | ) 62 | assert isinstance(best_hyperparameters, dict) 63 | gc.collect() 64 | -------------------------------------------------------------------------------- /tests/run_locally_test.py: -------------------------------------------------------------------------------- 1 | """Testing integration of components locally.""" 2 | 3 | import gc 4 | import os 5 | import tempfile 6 | 7 | from prompt2model.run_locally import run_skeleton 8 | 9 | 10 | def test_integration(): 11 | """Check that an end-to-end run with a single prompt doesn't throw an error.""" 12 | prompt = ["Test prompt"] 13 | with tempfile.TemporaryDirectory() as tmpdirname: 14 | metrics_output_path = os.path.join(tmpdirname, "metrics.json") 15 | run_skeleton(prompt, metrics_output_path) 16 | gc.collect() 17 | -------------------------------------------------------------------------------- /tests/tevatron_utils_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the tevatron_utils encoding and retrieval helpers.""" 2 | 3 | import gc 4 | import json 5 | import os 6 | import pickle 7 | import tempfile 8 | 9 | import numpy as np 10 | import pytest 11 | from tevatron.modeling import DenseModelForInference 12 | from transformers import PreTrainedTokenizerBase 13 | 14 | from prompt2model.utils.tevatron_utils import encode_text, retrieve_objects 15 | from prompt2model.utils.tevatron_utils.encode import load_tevatron_model 16 | 17 | TINY_MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2" 18 | 19 | 20 | def test_load_tevatron_model(): 21 | """Test loading a small Tevatron model.""" 22 | model, tokenizer = load_tevatron_model(TINY_MODEL_NAME) 23 | assert isinstance(model, DenseModelForInference) 24 | assert isinstance(tokenizer, PreTrainedTokenizerBase) 25 | gc.collect() 26 | 27 | 28 | def test_encode_text_from_string(): 29 | """Test encoding text from a string into a vector.""" 30 | text = "This is an example sentence" 31 | encoded = encode_text(TINY_MODEL_NAME, text_to_encode=text) 32 | assert encoded.shape == (1, 128) 33 | gc.collect() 34 | 35 | 36 | def test_encode_text_from_file():
 37 | """Test encoding text from a file into a vector.""" 38 | text_rows = [ 39 | {"text_id": 0, "text": "This is an example sentence"}, 40 | {"text_id": 1, "text": "This is another example
sentence"}, 41 | ] 42 | with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as f: 43 | json.dump(text_rows, f) 44 | f.seek(0) 45 | encoded = encode_text(TINY_MODEL_NAME, file_to_encode=f.name) 46 | assert encoded.shape == (2, 128) 47 | gc.collect() 48 | 49 | 50 | def test_encode_text_from_file_store_to_file(): 51 | """Test encoding text from a file into a vector, then stored to file.""" 52 | text_rows = [ 53 | {"text_id": 0, "text": "This is an example sentence"}, 54 | {"text_id": 1, "text": "This is another example sentence"}, 55 | ] 56 | with tempfile.TemporaryDirectory() as tempdir: 57 | with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as f: 58 | json.dump(text_rows, f) 59 | f.seek(0) 60 | encoding_file_path = os.path.join(tempdir, "encoding.pkl") 61 | encoded = encode_text( 62 | TINY_MODEL_NAME, file_to_encode=f.name, encoding_file=encoding_file_path 63 | ) 64 | assert encoded.shape == (2, 128) 65 | encoded_vectors, encoded_indices = pickle.load( 66 | open(encoding_file_path, "rb") 67 | ) 68 | assert (encoded == encoded_vectors).all() 69 | assert encoded_indices == [0, 1] 70 | gc.collect() 71 | 72 | 73 | def test_encode_text_error_from_no_string_or_file(): 74 | """Test that either a string or a file must be passed to encode.""" 75 | with pytest.raises(ValueError): 76 | _ = encode_text(TINY_MODEL_NAME) 77 | gc.collect() 78 | 79 | 80 | def test_encode_text_error_from_both_string_and_file(): 81 | """Test that either a string or a file, but not both, must be passed to encode.""" 82 | text = "This is an example sentence" 83 | file = "/tmp/test.txt" 84 | with pytest.raises(ValueError): 85 | _ = encode_text(TINY_MODEL_NAME, file_to_encode=file, text_to_encode=text) 86 | gc.collect() 87 | 88 | 89 | def test_retrieve_objects(): 90 | """Test retrieval against a list of vectors.""" 91 | mock_query_vector = np.array([[0.0, 0.0, 1.0, 0.0]]) 92 | # The query vector matches the third row in the search collection. 93 | mock_search_collection = np.array( 94 | [ 95 | [1.0, 0.0, 0.0, 0.0], 96 | [0.0, 1.0, 0.0, 0.0], 97 | [0.0, 0.0, 1.0, 0.0], 98 | [0.0, 0.0, 0.0, 1.0], 99 | ] 100 | ) 101 | document_names = ["a", "b", "c", "d"] 102 | mock_vector_indices = [0, 1, 2, 3] 103 | with tempfile.TemporaryDirectory() as tmpdir: 104 | search_index_pickle = os.path.join(tmpdir, "search_index.pkl") 105 | pickle.dump( 106 | (mock_search_collection, mock_vector_indices), 107 | open(search_index_pickle, "wb"), 108 | ) 109 | results = retrieve_objects( 110 | mock_query_vector, search_index_pickle, document_names, depth=3 111 | ) 112 | assert ( 113 | len(results) == 3 114 | ), "The number of results should match the provided depth." 115 | 116 | # Verify that the index of the first retrieved document matches the document 117 | # that we known matches the query vector. 118 | first_retrieved_document, _ = results[0] 119 | assert first_retrieved_document == "c" 120 | 121 | # Verify that the first retrieved document has the greatest retrieval score. 122 | sorted_results = sorted(results, key=lambda x: x[1], reverse=True) 123 | assert sorted_results[0][0] == first_retrieved_document 124 | gc.collect() 125 | --------------------------------------------------------------------------------