├── .flake8
├── .github
│   ├── pull_request_template.md
│   └── workflows
│       ├── base_ci.yml
│       ├── ci.yml
│       └── release.yml
├── .gitignore
├── .gitmodules
├── .markdownlint.yaml
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── examples
│   ├── create_synthetic_data_example.py
│   ├── create_transform_data_example.py
│   └── mistral_qlora_finetune_example.py
├── mypy.ini
├── prompt2model
│   ├── __init__.py
│   ├── dataset_generator
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── mock.py
│   │   ├── prompt_based.py
│   │   ├── prompt_template.py
│   │   └── readme.md
│   ├── dataset_processor
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── mock.py
│   │   ├── readme.md
│   │   └── textualize.py
│   ├── dataset_retriever
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── column_selection_prompt.py
│   │   ├── description_dataset_retriever.py
│   │   ├── mock.py
│   │   ├── readme.md
│   │   ├── reranking_prompt.py
│   │   ├── run_dataset_retriever.py
│   │   └── task_expansion_prompt.py
│   ├── dataset_transformer
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── prompt_based.py
│   │   └── prompt_template.py
│   ├── demo_creator
│   │   ├── __init__.py
│   │   ├── create.py
│   │   ├── mock.py
│   │   └── readme.md
│   ├── model_evaluator
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── mock.py
│   │   ├── readme.md
│   │   └── seq2seq.py
│   ├── model_executor
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── generate.py
│   │   ├── mock.py
│   │   └── readme.md
│   ├── model_retriever
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── description_based_retriever.py
│   │   ├── generate_hypothetical_document.py
│   │   ├── mock.py
│   │   ├── readme.md
│   │   └── run_model_retriever.py
│   ├── model_trainer
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── callback.py
│   │   ├── generate.py
│   │   ├── mock.py
│   │   ├── qlora_trainer.py
│   │   └── readme.md
│   ├── param_selector
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── mock.py
│   │   └── search_with_optuna.py
│   ├── prompt_parser
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── instr_parser.py
│   │   ├── instr_parser_prompt.py
│   │   ├── mock.py
│   │   └── readme.md
│   ├── run_locally.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── api_tools.py
│   │   ├── config.py
│   │   ├── dataset_utils.py
│   │   ├── dataset_utils_test.py
│   │   ├── logging_utils.py
│   │   ├── parse_responses.py
│   │   ├── retrieve_model_info.py
│   │   ├── rng.py
│   │   └── tevatron_utils
│   │       ├── __init__.py
│   │       ├── encode.py
│   │       └── retrieve.py
│   └── version.py
├── prompt2model_demo.ipynb
├── prompt2model_demo.py
├── prompt_examples.md
├── pyproject.toml
├── scripts
│   └── dataset_index
│       ├── preprocessing.py
│       └── retrieve_dataset_info.py
├── test_helpers
│   ├── __init__.py
│   ├── dataset_index_tiny.json
│   ├── mock_api.py
│   ├── mock_retrieval.py
│   ├── model_and_tokenizer.py
│   ├── model_info_tiny
│   │   ├── gpt2.json
│   │   ├── t5-base.json
│   │   └── xlnet-base-cased.json
│   ├── reranking_dataset_index_tiny.json
│   └── test_utils.py
└── tests
    ├── __init__.py
    ├── dataset_generator_test.py
    ├── dataset_processor_test.py
    ├── dataset_retriever_test.py
    ├── dataset_transformer_test.py
    ├── demo_creator_test.py
    ├── model_evaluator_test.py
    ├── model_executor_for_gpt_test.py
    ├── model_executor_for_t5_test.py
    ├── model_retriever_test.py
    ├── model_trainer_for_gpt_test.py
    ├── model_trainer_for_t5_test.py
    ├── param_selector_test.py
    ├── prompt_parser_test.py
    ├── run_locally_test.py
    └── tevatron_utils_test.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18,BLK100,W503
4 | per-file-ignores = prompt2model/dataset_transformer/prompt_template.py:E501
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Description
4 |
5 |
9 |
10 | # References
11 |
12 |
13 |
14 | - Foo
15 | - Bar
16 | - Baz
17 |
18 | # Blocked by
19 |
20 |
21 |
22 | - NA
23 | - (or link to PRs)
24 |
--------------------------------------------------------------------------------
/.github/workflows/base_ci.yml:
--------------------------------------------------------------------------------
1 | name: Reusable Continuous Integration Workflow
2 | on:
3 | workflow_call:
4 |
5 | jobs:
6 | unit-tests:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | matrix:
10 | python-version: [ '3.9', '3.10', '3.11' ]
11 | steps:
12 | - uses: actions/checkout@v3
13 | - name: Install Python ${{ matrix.python-version }}
14 | uses: actions/setup-python@v3
15 | with:
16 | python-version: ${{ matrix.python-version }}
17 | cache: 'pip'
18 | cache-dependency-path: 'pyproject.toml'
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install .
23 | - name: test
24 | run: pytest
25 | format:
26 | runs-on: ubuntu-latest
27 | steps:
28 | - uses: actions/checkout@v3
29 | - name: Install Python 3
30 | uses: actions/setup-python@v3
31 | with:
32 | python-version: 3.9
33 | - name: format
34 | uses: pre-commit/action@v3.0.0
35 | with:
36 | extra_args: black --all-files
37 | isort:
38 | runs-on: ubuntu-latest
39 | steps:
40 | - uses: actions/checkout@v3
41 | - name: Install Python 3
42 | uses: actions/setup-python@v3
43 | with:
44 | python-version: 3.9
45 | - name: isort
46 | uses: pre-commit/action@v3.0.0
47 | with:
48 | extra_args: isort --all-files
49 | lint:
50 | runs-on: ubuntu-latest
51 | steps:
52 | - uses: actions/checkout@v3
53 | - name: Install Python 3
54 | uses: actions/setup-python@v3
55 | with:
56 | python-version: 3.9
57 | - name: lint
58 | uses: pre-commit/action@v3.0.0
59 | with:
60 | extra_args: flake8 --all-files
61 | - name: lint-pep585-compliant
62 | uses: pre-commit/action@v3.0.0
63 | with:
64 | extra_args: upgrade-type-hints --all-files
65 | typecheck:
66 | runs-on: ubuntu-latest
67 | steps:
68 | - uses: actions/checkout@v3
69 | - name: Install Python 3
70 | uses: actions/setup-python@v3
71 | with:
72 | python-version: 3.9
73 | - name: mypy
74 | uses: pre-commit/action@v3.0.0
75 | with:
76 | extra_args: mypy --all-files
77 | markdownlint:
78 | runs-on: ubuntu-latest
79 | steps:
80 | - uses: actions/checkout@v3
81 | - name: Install Node
82 | uses: actions/setup-node@v3
83 | with:
84 | node-version: 18.x
85 | - name: lint
86 | uses: pre-commit/action@v3.0.0
87 | with:
88 | extra_args: markdownlint-cli2 --all-files
89 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: Continuous Integration
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches: ["**"]
8 |
9 | jobs:
10 | unit-tests:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: [ '3.9', '3.10', '3.11' ]
15 | steps:
16 | - uses: actions/checkout@v3
17 | - name: Install Python ${{ matrix.python-version }}
18 | uses: actions/setup-python@v3
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 | cache: 'pip'
22 | cache-dependency-path: 'pyproject.toml'
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | pip install .
27 | - name: test
28 | run: pytest
29 | format:
30 | runs-on: ubuntu-latest
31 | steps:
32 | - uses: actions/checkout@v3
33 | - name: Install Python 3
34 | uses: actions/setup-python@v3
35 | with:
36 | python-version: 3.9
37 | - name: format
38 | uses: pre-commit/action@v3.0.0
39 | with:
40 | extra_args: black --all-files
41 | isort:
42 | runs-on: ubuntu-latest
43 | steps:
44 | - uses: actions/checkout@v3
45 | - name: Install Python 3
46 | uses: actions/setup-python@v3
47 | with:
48 | python-version: 3.9
49 | - name: isort
50 | uses: pre-commit/action@v3.0.0
51 | with:
52 | extra_args: isort --all-files
53 | lint:
54 | runs-on: ubuntu-latest
55 | steps:
56 | - uses: actions/checkout@v3
57 | - name: Install Python 3
58 | uses: actions/setup-python@v3
59 | with:
60 | python-version: 3.9
61 | - name: lint
62 | uses: pre-commit/action@v3.0.0
63 | with:
64 | extra_args: flake8 --all-files
65 | - name: lint-pep585-compliant
66 | uses: pre-commit/action@v3.0.0
67 | with:
68 | extra_args: upgrade-type-hints --all-files
69 | typecheck:
70 | runs-on: ubuntu-latest
71 | steps:
72 | - uses: actions/checkout@v3
73 | - name: Install Python 3
74 | uses: actions/setup-python@v3
75 | with:
76 | python-version: 3.9
77 | - name: mypy
78 | uses: pre-commit/action@v3.0.0
79 | with:
80 | extra_args: mypy --all-files
81 | markdownlint:
82 | runs-on: ubuntu-latest
83 | steps:
84 | - uses: actions/checkout@v3
85 | - name: Install Node
86 | uses: actions/setup-node@v3
87 | with:
88 | node-version: 18.x
89 | - name: lint
90 | uses: pre-commit/action@v3.0.0
91 | with:
92 | extra_args: markdownlint-cli2 --all-files
93 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release workflow
2 |
3 | on:
4 | push:
5 | tags:
6 | - "v[0123456789].*"
7 |
8 | jobs:
9 | release:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: checkout
13 | uses: actions/checkout@v3
14 | - name: setup python
15 | uses: actions/setup-python@v2
16 | with:
17 | python-version: "3.x"
18 | - name: build
19 | run: |
20 | python -m pip install --upgrade build hatch
21 | python -m hatch version "${GITHUB_REF_NAME}"
22 | python -m build
23 | - name: publish
24 | uses: pypa/gh-action-pypi-publish@release/v1
25 | with:
26 | password: ${{ secrets.PYPI_API_TOKEN }}
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | build
3 | dist
4 | prompt2model.egg-info
5 | .env
6 | .vscode
7 | .mypy_cache
8 | .pytest_cache
9 | *.pyc
10 | flagged
11 | wandb
12 | logs
13 | temp
14 | tests/logs
15 | tests/wandb
16 |
17 | # Outputs generated by the cli demo
18 | cached_generated_dataset/
19 | generated_dataset/
20 | huggingface_data/huggingface_datasets/dataset_index.json
21 | huggingface_data/huggingface_datasets/huggingface_datasets_datafinder_index
22 | huggingface_data/huggingface_datasets/reranking_dataset_index.json
23 | huggingface_data/huggingface_models/
24 | retrieved_dataset_dict/
25 | result/
26 | checkpoint/
27 | status.yaml
28 | # Outputs generated by the colab demo
29 | trained_model/
30 | trained_tokenizer/
31 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neulab/prompt2model/d4cd23d50c8f5edc5755ab487bcee18121d42a4f/.gitmodules
--------------------------------------------------------------------------------
/.markdownlint.yaml:
--------------------------------------------------------------------------------
1 | MD013:
2 | line_length: 120
3 | code_blocks: false
4 | MD033: false
5 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/python/black.git
3 | rev: 22.3.0
4 | hooks:
5 | - id: black
6 | files: '\.py$'
7 | - repo: https://github.com/PyCQA/flake8
8 | rev: 5.0.4
9 | hooks:
10 | - id: flake8
11 | name: flake8
12 | additional_dependencies:
13 | - flake8-absolute-import
14 | - flake8-black>=0.1.1
15 | - flake8-pep585>=0.1.6
16 | entry: flake8
17 | files: '\.py$'
18 | - id: flake8
19 | name: docstring
20 | additional_dependencies:
21 | - flake8-docstrings>=1.6
22 | args:
23 | - --docstring-convention=google
24 | - --select=D
25 | entry: flake8
26 | files: '\.py$'
27 | - id: flake8
28 | name: future-import
29 | additional_dependencies:
30 | - flake8-future-import
31 | args:
32 | - --select=
33 | - --ignore FI58
34 | entry: flake8
35 | files: '\.py$'
36 | - repo: https://github.com/pycqa/isort.git
37 | rev: 5.12.0
38 | hooks:
39 | - id: isort
40 | args: ["--profile", "black"]
41 |
42 | files: '\.py$'
43 | - repo: https://github.com/sondrelg/pep585-upgrade
44 | rev: v1.0.1
45 | hooks:
46 | - id: upgrade-type-hints
47 | files: '\.py$'
48 | - repo: https://github.com/pre-commit/mirrors-mypy
49 | rev: 'v0.981'
50 | hooks:
51 | - id: mypy
52 | additional_dependencies:
53 | - types-requests
54 | files: '\.py$'
55 | - repo: https://github.com/DavidAnson/markdownlint-cli2
56 | rev: v0.5.1
57 | hooks:
58 | - id: markdownlint-cli2
59 | - id: markdownlint-cli2-fix
60 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to prompt2model
2 |
3 | Thanks for your interest in contributing to prompt2model!
4 | We appreciate your support and welcome your
5 | contributions. Here's a guide to help you get started:
6 |
7 | ## Developer Installation
8 |
9 | Installing dependencies:
10 | Run the following command to install all of the dependencies after you have forked and cloned the repository.
11 |
12 | ```bash
13 | pip install .[dev]
14 | ```
15 |
16 | If you're a developer, it's recommended to install
17 | pre-commit hooks before starting your development
18 | work. These hooks ensure code formatting, linting,
19 | and type-checking. To install the pre-commit hooks,
20 | run the following command:
21 |
22 | ```bash
23 | pre-commit install
24 | ```
25 |
26 | Additionally, it's essential to run tests to verify the
27 | functionality of your code. Execute the following
28 | command to run the tests:
29 |
30 | ```bash
31 | pytest
32 | ```
33 |
34 | ## Contribution Guide
35 |
36 | To contribute to prompt2model, or if you have any questions,
37 | please reach out to us!
38 |
39 | - open an [issue](https://github.com/neulab/prompt2model/issues) or submit a PR
40 | - join us on [discord](https://discord.gg/UCy9csEmFc)
41 | - or reach out to [@vijaytarian](https://twitter.com/vijaytarian)
42 | and [@Chenan3_Zhao](https://twitter.com/Chenan3_Zhao) on Twitter.
43 |
44 | ## Making a Release
45 |
46 | If you have admin privileges for the repository,
47 | you can create a new release of the prompt2model
48 | library. We utilize the
49 | [hatchling](https://github.com/pypa/hatch) build
50 | system, which simplifies the process of making
51 | new releases.
52 |
53 | To create a new release, follow these steps:
54 |
55 | 1. Create a new version tag on GitHub, adhering to
56 | the [semantic versioning](https://semver.org/) guidelines.
57 | 2. Once the tag is created, the continuous integration
58 | (CI) system will automatically build and publish the
59 | new version to PyPI.
60 |
61 | By following these steps, you can effectively make
62 | new releases of the library and contribute to its
63 | ongoing development.
64 |
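65 | For example, once the changes for a release are merged into `main`, a tag
66 | can be created and pushed from the command line (the version number below
67 | is illustrative):
68 |
69 | ```bash
70 | # Illustrative version number; follow semantic versioning.
71 | git tag v0.1.0
72 | git push origin v0.1.0
73 | ```
74 |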
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Prompt2Model - Generate Deployable Models from Instructions
2 |
3 | [](https://badge.fury.io/py/prompt2model)
4 | 
5 | [](https://lbesson.mit-license.org/)
6 | [](https://discord.gg/UCy9csEmFc)
7 | [](https://colab.research.google.com/github/neulab/prompt2model/blob/main/prompt2model_demo.ipynb)
8 |
9 | `Prompt2Model` is a system that takes a natural
10 | language task description (like the prompts used for
11 | LLMs such as ChatGPT) and trains a small
12 | special-purpose model that is conducive to deployment.
13 |
14 |
15 |
16 | ## Quick Start
17 |
18 | ### Notebook
19 |
20 | You can run our demo of `Prompt2Model` through a notebook:
21 |
22 | - [Open Locally](./prompt2model_demo.ipynb)
23 | - [Open in Colab](https://colab.research.google.com/github/neulab/prompt2model/blob/main/prompt2model_demo.ipynb)
24 |
25 | ### Command Line
26 |
27 | You can also run it from the command line.
28 |
29 | ```bash
30 | pip install prompt2model
31 | ```
32 |
33 | `Prompt2Model` supports various platforms such as OpenAI, Anthropic, Huggingface, etc. using [LiteLLM](https://github.com/BerriAI/litellm).
34 |
35 | If you are using OpenAI models (such as the default `gpt-3.5-turbo`), please obtain an
36 | OpenAI API key on their [website](https://platform.openai.com/) then set
37 | the environment variable `OPENAI_API_KEY` to your API key by running
38 | the following command in your terminal:
39 |
40 | ```bash
41 | export OPENAI_API_KEY=<your key>
42 | ```
43 |
44 | [List of all supported providers](https://docs.litellm.ai/docs/providers)
45 |
46 | You can then run
47 |
48 | ```bash
49 | python prompt2model_demo.py
50 | ```
51 |
52 | to create a small model from a prompt, as shown in
53 | the demo video below. This script must be run on a
54 | device with an internet connection to access the OpenAI
55 | API. For best results, run
56 | this script on a device with a GPU for training
57 | your model.
58 |
59 | ## Demo
60 |
61 |
62 |
63 | ## Tips and Examples to Write a Good Prompt
64 |
65 | You can see the tips and examples to write
66 | a good prompt in [prompt_examples](./prompt_examples.md).
67 |
68 | ## Components
69 |
70 | The `prompt2model` package is composed
71 | of several components, each designed
72 | to fulfill a specific purpose. To gain
73 | a comprehensive understanding of how to
74 | utilize each component effectively,
75 | please consult the `readme.md` file
76 | situated in the directory of the respective
77 | component. These files can be found at
78 | `./prompt2model/<component>/readme.md`.
79 | They provide detailed information and
80 | instructions on customizing and maximizing
81 | the functionality of each
82 | component within the package.
83 |
84 | ## Contribution
85 |
86 | If you're interested in contributing to the `prompt2model` project, please
87 |
88 | - refer to [CONTRIBUTING.md](CONTRIBUTING.md)
89 | - open an [issue](https://github.com/neulab/prompt2model/issues) or submit a PR
90 | - join us on [discord](https://discord.gg/UCy9csEmFc)
91 | - or reach out to [@vijaytarian](https://twitter.com/vijaytarian)
92 | and [@Chenan3_Zhao](https://twitter.com/Chenan3_Zhao) on Twitter
93 |
94 | ## Cite
95 |
96 | We have [written a paper describing Prompt2Model in detail](https://arxiv.org/abs/2308.12261).
97 |
98 | If you use Prompt2Model in your research, please cite us!
99 |
100 | If you discuss or use the overall prompt2model framework, please reference
101 |
102 | ```bibtex
103 | @misc{prompt2model,
104 | title={Prompt2Model: Generating Deployable Models from Natural Language Instructions},
105 | author={Vijay Viswanathan and Chenyang Zhao and Amanda Bertsch and Tongshuang Wu and Graham Neubig},
106 | year={2023},
107 | eprint={2308.12261},
108 | archivePrefix={arXiv},
109 | primaryClass={cs.CL}
110 | }
111 | ```
112 |
113 | If you discuss or use our dataset retrieval and transformation tools, please reference
114 |
115 | ```bibtex
116 | @misc{prompt2modeldatatune,
117 | title={Better Synthetic Data by Retrieving and Transforming Existing Datasets},
118 | author={Saumya Gandhi and Ritu Gala and Vijay Viswanathan and Tongshuang Wu and Graham Neubig},
119 | year={2024},
120 | eprint={2404.14361},
121 | archivePrefix={arXiv},
122 | primaryClass={cs.CL}
123 | }
124 | ```
125 |
--------------------------------------------------------------------------------
/examples/create_synthetic_data_example.py:
--------------------------------------------------------------------------------
1 | """Example to demonstrate how to create synthetic data based on prompt."""
2 |
3 | import prompt2model.utils.api_tools as api_tools
4 | from prompt2model.dataset_generator.base import DatasetSplit
5 | from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator
6 | from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
7 | from prompt2model.utils.api_tools import APIAgent
8 |
9 | if __name__ == "__main__":
10 | # set API keys and create default API agent.
11 | api_tools.default_api_agent = APIAgent(
12 | model_name="gpt-3.5-turbo-16k", max_tokens=8000
13 | )
14 |
15 | # create prompt based on which synthetic data will be created
16 | prompt = """
17 | Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.
18 |
19 | Here are examples with input questions and context passages, along with their expected outputs:
20 |
21 | input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50."
22 | output="Santa Clara"
23 |
24 | input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)."
25 | output="Vistula River"
26 |
27 | input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries."
28 | output="Europe"
29 | """ # noqa: E501
30 | # parse the prompt to get the instruction and examples
31 | prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
32 | prompt_spec.parse_from_prompt(prompt)
33 | print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}")
34 |
35 | # set hyperparams
36 | initial_temperature = 0.4
37 | max_temperature = 1.4
38 | num_samples_total = 20
39 |
40 | # run this pipeline to generate data synthetically based on prompt
41 | unlimited_dataset_generator = PromptBasedDatasetGenerator(
42 | initial_temperature=initial_temperature,
43 | max_temperature=max_temperature,
44 | responses_per_request=3,
45 | )
46 | generated_dataset = unlimited_dataset_generator.generate_dataset_split(
47 | prompt_spec, num_samples_total, split=DatasetSplit.TRAIN
48 | )
49 |
50 | # save the final generated dataset to disk
51 | generated_dataset.save_to_disk("demo_generated_dataset")
52 |
--------------------------------------------------------------------------------
/examples/create_transform_data_example.py:
--------------------------------------------------------------------------------
1 | """Example of how to create transform data based on a prompt."""
2 |
3 | import prompt2model.utils.api_tools as api_tools
4 | from prompt2model.dataset_retriever import DescriptionDatasetRetriever
5 | from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
6 | from prompt2model.utils.api_tools import APIAgent
7 |
8 | if __name__ == "__main__":
9 | # set API keys and create default API agent.
10 | api_tools.default_api_agent = APIAgent(
11 | model_name="gpt-3.5-turbo-16k", max_tokens=8000
12 | )
13 |
14 | # create prompt based on which transform data will be created
15 | prompt = """
16 | Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.
17 |
18 | Here are examples with input questions and context passages, along with their expected outputs:
19 |
20 | input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50."
21 | output="Santa Clara"
22 |
23 | input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)."
24 | output="Vistula River"
25 |
26 | input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries."
27 | output="Europe"
28 | """ # noqa: E501
29 | # parse the prompt to get the instruction and examples
30 | prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
31 | prompt_spec.parse_from_prompt(prompt)
32 | print(f"Instruction: {prompt_spec.instruction}\nExamples: {prompt_spec.examples}")
33 |
34 | # run this pipeline to retrieve relevant datasets, rerank them,
35 | # and transform them based on the prompt
36 | total_num_points_to_transform = 20
37 | retriever = DescriptionDatasetRetriever(
38 | auto_transform_data=True,
39 | total_num_points_to_transform=total_num_points_to_transform,
40 | )
41 | retrieved_dataset_dict = retriever.retrieve_dataset_dict(
42 | prompt_spec,
43 | )
44 |
45 | # save the final dataset to disk
46 | if retrieved_dataset_dict is not None:
47 | retrieved_dataset_dict.save_to_disk("demo_retrieved_dataset_dict")
48 |
--------------------------------------------------------------------------------
/examples/mistral_qlora_finetune_example.py:
--------------------------------------------------------------------------------
1 | """Example of how to fine-tune a model using the QLoRATrainer class."""
2 |
3 | import os
4 |
5 | from datasets import load_from_disk
6 |
7 | from prompt2model.model_trainer.qlora_trainer import QLoRATrainer
8 | from prompt2model.utils.dataset_utils import format_train_data, make_combined_datasets
9 |
10 | if __name__ == "__main__":
11 | # First, we load in the datasets we want to fine-tune on.
12 | retrieved_dataset_dict = load_from_disk("demo_retrieved_dataset_dict")
13 | retrieved_dataset = retrieved_dataset_dict["train"]
14 | generated_dataset = load_from_disk("demo_generated_dataset")
15 | dataset_list = [retrieved_dataset, generated_dataset]
16 |
17 | # Next, we combine datasets and create train and eval splits.
18 | train_dataset = make_combined_datasets(dataset_list)
19 | splits = train_dataset.train_test_split(test_size=0.1)
20 | train_dataset = splits["train"]
21 | eval_dataset = splits["test"]
22 |
23 | # At this point, both train_dataset and eval_dataset are datasets with two
24 | # columns: "input_col" and "output_col".
25 | # We need to format them into a single column, "text", for the QLoRATrainer to use.
26 | formatted_train_dataset = format_train_data(train_dataset)
27 | formatted_eval_dataset = format_train_data(eval_dataset)
28 |
29 | # Next, we define the hyperparameters for the QLoRATrainer.
30 | num_epochs = 1
31 | qlora_alpha = 8
32 | qlora_r = 16
33 | qlora_lr = 1e-5
34 | save_folder_path = "qlora_finetuned_model"
35 | load_best_model_at_end = False
36 |
37 | # Next, we create a QLoRATrainer object and call the train_model method.
38 | trainer = QLoRATrainer(model_name="mistralai/Mistral-7B-v0.1", model_max_length=512)
39 |
40 | # `formatted_eval_dataset` contains just one column: "text",
41 | # and is used to calculate eval loss, by checking loss for each next token.
42 | # `eval_dataset` contains two columns: "input_col" and "output_col",
43 | # and is used to calculate eval accuracy, by checking whether the generated output
44 | # exactly matches the expected output.
45 | trained_model, trained_tokenizer = trainer.train_model(
46 | formatted_train_dataset,
47 | formatted_eval_dataset,
48 | eval_dataset,
49 | num_epochs=num_epochs,
50 | alpha=qlora_alpha,
51 | r=qlora_r,
52 | lr=qlora_lr,
53 | save_folder_path=save_folder_path,
54 | load_best_model_at_end=load_best_model_at_end,
55 | )
56 |
57 | # Finally, we save the trained model and tokenizer to disk.
58 | trained_model.save_pretrained(os.path.join(save_folder_path, "demo_final_model"))
59 | trained_tokenizer.save_pretrained(
60 | os.path.join(save_folder_path, "demo_final_tokenizer")
61 | )
62 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | ignore_missing_imports = True
3 |
4 | [mypy-yaml.*]
5 | ignore_missing_imports = True
6 |
7 | [mypy-termcolor.*]
8 | ignore_missing_imports = True
9 |
--------------------------------------------------------------------------------
/prompt2model/__init__.py:
--------------------------------------------------------------------------------
1 | """prompt2model top level package."""
2 |
--------------------------------------------------------------------------------
/prompt2model/dataset_generator/__init__.py:
--------------------------------------------------------------------------------
1 | """Import DatasetGenerator classes."""
2 | from prompt2model.dataset_generator.base import DatasetGenerator, DatasetSplit
3 | from prompt2model.dataset_generator.mock import MockDatasetGenerator
4 | from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator
5 |
6 | __all__ = (
7 | "PromptBasedDatasetGenerator",
8 | "MockDatasetGenerator",
9 | "DatasetGenerator",
10 | "DatasetSplit",
11 | )
12 |
--------------------------------------------------------------------------------
/prompt2model/dataset_generator/base.py:
--------------------------------------------------------------------------------
1 | """An interface for dataset generation."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | from abc import ABC, abstractmethod
6 | from enum import Enum
7 |
8 | import datasets
9 |
10 | from prompt2model.prompt_parser import PromptSpec
11 |
12 |
13 | class DatasetSplit(Enum):
14 | """The split of a dataset."""
15 |
16 | TRAIN = "train"
17 | VAL = "val"
18 | TEST = "test"
19 |
20 |
21 | class DatasetGenerator(ABC):
22 | """A class for generating datasets from a prompt specification."""
23 |
24 | @abstractmethod
25 | def generate_dataset_split(
26 | self,
27 | prompt_spec: PromptSpec,
28 | num_examples: int,
29 | split: DatasetSplit,
30 | ) -> datasets.Dataset:
31 | """Generate data for a single named split of data.
32 |
33 | Args:
34 | prompt_spec: A prompt spec (containing a system description).
35 | num_examples: Expected number of examples in split.
36 | split: Name of dataset split to generate.
37 |
38 | Returns:
39 | A single dataset split.
40 | """
41 |
42 | def generate_dataset_dict(
43 | self,
44 | prompt_spec: PromptSpec,
45 | num_examples: dict[DatasetSplit, int],
46 | ) -> datasets.DatasetDict:
47 | """Generate full dataset splits (e.g. train/dev/test) from a prompt.
48 |
49 | Args:
50 | prompt_spec: A prompt specification.
51 | num_examples: Expected number of
52 | examples per split (train/val/test).
53 |
54 | Returns:
55 | A DatasetDict containing train, val, and test splits.
56 | """
57 | dataset_dict = datasets.DatasetDict(
58 | {
59 | split.value: self.generate_dataset_split(prompt_spec, num, split=split)
60 | for split, num in num_examples.items()
61 | }
62 | )
63 |
64 | return dataset_dict
65 |
--------------------------------------------------------------------------------
/prompt2model/dataset_generator/mock.py:
--------------------------------------------------------------------------------
1 | """A class for generating empty datasets (for testing purposes)."""
2 |
3 | import datasets
4 |
5 | from prompt2model.dataset_generator.base import DatasetGenerator, DatasetSplit
6 | from prompt2model.prompt_parser import PromptSpec
7 |
8 |
9 | class MockDatasetGenerator(DatasetGenerator):
10 | """A class for generating empty datasets (for testing purposes)."""
11 |
12 | def generate_dataset_split(
13 | self,
14 | prompt_spec: PromptSpec,
15 | expected_num_examples: int,
16 | split: DatasetSplit,
17 | ) -> datasets.Dataset:
18 | """Create empty versions of the datasets, for testing.
19 |
20 | Args:
21 | prompt_spec: A prompt specification.
22 | expected_num_examples: Number of examples in split.
23 | split: Name of dataset split to generate.
24 |
25 | Returns:
26 | A single dataset split.
27 | """
28 | _ = prompt_spec, split # suppress unused variable warnings
29 | col_values = [""] * expected_num_examples
30 | # Construct empty-valued dataframe with length matching expected_num_examples.
31 | return datasets.Dataset.from_dict(
32 | {"input_col": col_values, "output_col": col_values}
33 | )
34 |
--------------------------------------------------------------------------------
/prompt2model/dataset_generator/readme.md:
--------------------------------------------------------------------------------
1 | # Dataset Generator
2 |
3 | ## Overview
4 |
5 | - `DatasetGenerator`: An abstract class to generate datasets.
6 | - `DatasetSplit`: An enumeration class defining dataset splits (`TRAIN`,
7 | `VAL`, `TEST`).
8 | - `PromptBasedDatasetGenerator`: A concrete class
9 | for dataset generation using the GPT-3.5 API.
10 |
11 | ## Getting Started
12 |
13 | - **Import the Modules**:
14 |
15 | ```python
16 | from prompt2model.dataset_generator import PromptBasedDatasetGenerator, DatasetSplit
17 | from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
18 | ```
19 |
20 | - **Setup API Key**:
21 |
22 | Set an API key as an environment variable. For instance, if using OpenAI:
23 |
24 | ```bash
25 | export OPENAI_API_KEY=""
26 | ```
27 |
28 | - **Parse the Prompt**:
29 |
30 | ```python
31 | prompt_spec = PromptBasedInstructionParser(task_type=TaskType.<task_type>)
32 | # Refer to the docstring of TaskType for more details.
33 | prompt = "<your prompt>"
34 | prompt_spec.parse_from_prompt(prompt)
35 | ```
36 |
37 | Or you can mock a `PromptSpec` object, as
38 | shown in [PromptParser](./../prompt_parser/readme.md).
39 |
40 | - **Generate Dataset**:
41 |
42 | For a specific split:
43 |
44 | ```python
45 | expected_num_examples = 100
46 | split = DatasetSplit.TRAIN
47 | dataset = dataset_generator.generate_dataset_split(
48 |     prompt_spec,
49 |     expected_num_examples,
50 |     split,
51 | )
52 | ```
53 |
54 | Or, for multiple splits:
55 |
56 | ```python
57 | expected_num_examples = {
58 |     DatasetSplit.TRAIN: 1000,
59 |     DatasetSplit.VAL: 100,
60 |     DatasetSplit.TEST: 200,
61 | }
62 | dataset_dict = dataset_generator.generate_dataset_dict(prompt_spec, expected_num_examples)
63 | ```
64 |
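65 | The generated splits are standard Hugging Face `datasets` objects, so they
66 | can be saved to disk and reloaded later (the path below is only an example):
67 |
68 | ```python
69 | from datasets import load_from_disk
70 |
71 | # The path is arbitrary; any writable directory works.
72 | dataset.save_to_disk("generated_dataset")
73 | reloaded_dataset = load_from_disk("generated_dataset")
74 | ```
75 |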
--------------------------------------------------------------------------------
/prompt2model/dataset_processor/__init__.py:
--------------------------------------------------------------------------------
1 | """Import DatasetProcessor classes."""
2 | from prompt2model.dataset_processor.base import BaseProcessor
3 | from prompt2model.dataset_processor.mock import MockProcessor
4 | from prompt2model.dataset_processor.textualize import TextualizeProcessor
5 |
6 | __all__ = ("BaseProcessor", "TextualizeProcessor", "MockProcessor")
7 |
--------------------------------------------------------------------------------
/prompt2model/dataset_processor/mock.py:
--------------------------------------------------------------------------------
1 | """A mock dataset processor for testing purposes."""
2 | from __future__ import annotations
3 |
4 | import datasets
5 |
6 | from prompt2model.dataset_processor.base import BaseProcessor
7 |
8 |
9 | class MockProcessor(BaseProcessor):
10 | """A class for retrieving datasets."""
11 |
12 | def process_dataset_dict(
13 | self, instruction: str, dataset_dicts: list[datasets.DatasetDict]
14 | ) -> list[datasets.DatasetDict]:
15 | """A mock function to post-process a list of DatasetDicts.
16 |
17 | Args:
18 | instruction: The instruction used as a prefix to explain the task.
19 | dataset_dicts: A list of DatasetDicts (generated or retrieved).
20 |
21 | Returns:
22 | A list of DatasetDicts, where all examples are converted into text2text fashion.
23 | """
24 | _ = instruction
25 | return dataset_dicts
26 |
27 | @staticmethod
28 | def _post_process_example(
29 | example: dict,
30 | instruction: str,
31 | task_id: int,
32 | has_encoder: bool,
33 | dataset_split: str,
34 | eos_token: str,
35 | ) -> dict:
36 | """A mock function that modifies a given example dictionary.
37 |
38 | Args:
39 | example: A dictionary representing an example.
40 | instruction: The instruction used as a prefix to explain the task.
41 | task_id: A tag marking which dataset (from dataset_dicts) this example
42 | comes from. Used for multi-task training.
43 | has_encoder: Whether the retrieved model has an encoder.
44 | dataset_split: The split of the example, i.e. train/val/test.
45 | eos_token: The end-of-sentence token of the tokenizer.
46 |
47 | Returns:
48 | A dictionary with `model_input` as the input to models
49 | and `model_output` as the expected output of models.
50 | """
51 | _ = instruction, task_id
52 | example["model_input"] = example["input_col"]
53 | example["model_output"] = example["output_col"]
54 | return example
55 |
--------------------------------------------------------------------------------
/prompt2model/dataset_processor/readme.md:
--------------------------------------------------------------------------------
1 | # Dataset Processor
2 |
3 | ## Overview
4 |
5 | - `BaseProcessor`: The base class for dataset post-processing.
6 | - `TextualizeProcessor`: Transforms all datasets into a consistent
7 | text-to-text generation format.
8 |
9 | ## Getting Started
10 |
11 | - **Import the Module**:
12 |
13 | ```python
14 | from prompt2model.dataset_processor.textualize import TextualizeProcessor
15 | ```
16 |
17 | - **Initialize TextualizeProcessor**:
18 |
19 | ```python
20 | processor = TextualizeProcessor(has_encoder=<True or False>)
21 | # has_encoder: Whether the model you want to finetune has an encoder.
22 | ```
23 |
24 | Choose encoder type:
25 |
26 | - `has_encoder=True` for encoder-decoder models (e.g., T5).
27 | - `has_encoder=False` for decoder-only/autoregressive models (e.g., GPT2).
28 |
29 | - **Process Datasets**:
30 |
31 | ```python
32 | instruction = ""
33 | dataset_dicts = [...] # List of DatasetDict
34 | modified_dataset_dicts = processor.process_dataset_dict(instruction, dataset_dicts)
35 | ```
36 |
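37 | For a decoder-only model, the tokenizer's end-of-sequence token must be
38 | passed in explicitly (a minimal sketch, assuming a GPT-2 tokenizer from
39 | `transformers`):
40 |
41 | ```python
42 | from transformers import AutoTokenizer
43 |
44 | from prompt2model.dataset_processor.textualize import TextualizeProcessor
45 |
46 | # "gpt2" is only an example checkpoint; use the tokenizer of the model
47 | # you plan to fine-tune. Decoder-only tokenizers do not append an EOS
48 | # token automatically, so the processor needs it explicitly.
49 | tokenizer = AutoTokenizer.from_pretrained("gpt2")
50 | processor = TextualizeProcessor(has_encoder=False, eos_token=tokenizer.eos_token)
51 | ```
52 |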
--------------------------------------------------------------------------------
/prompt2model/dataset_processor/textualize.py:
--------------------------------------------------------------------------------
1 | """A dataset processor to convert datasets into Text2Text fashion."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | from prompt2model.dataset_processor.base import BaseProcessor
6 | from prompt2model.utils.logging_utils import get_formatted_logger
7 |
8 | logger = get_formatted_logger("DatasetProcessor")
9 |
10 |
11 | class TextualizeProcessor(BaseProcessor):
12 | """A class for pre-processing datasets before training."""
13 |
14 | def __init__(self, has_encoder: bool, eos_token: str | None = None) -> None:
15 | """Initialize the `TextualizeProcessor`.
16 |
17 | Args:
18 | has_encoder: Whether the retrieved model has an encoder.
19 | Encoder-decoder model like T5 has two model inputs.
20 | Decoder-only model like GPT only has one model input, thus
21 | `model_input` should be added with the `output_col`.
22 | eos_token: The end-of-sentence token of the tokenizer.
23 | The T5 tokenizer automatically adds eos token in the end of
24 | sequence. So only TextualizeProcessor for GPT model
25 | requires eos_token.
26 | """
27 | super().__init__(has_encoder, eos_token)
28 | if has_encoder and eos_token is not None:
29 | logger.info(
30 | (
31 | "The T5 tokenizer automatically adds eos token in the end of sequence when tokenizing." # noqa E501
32 | " So the eos_token of encoder-decoder model tokenizer is unnecessary." # noqa E501
33 | )
34 | )
35 | elif not has_encoder and eos_token is None:
36 | logger.warning(
37 | (
38 | "The autoregressive model tokenizer does not automatically add eos token in the end of the sequence." # noqa E501
39 | " So the `eos_token` of the autoregressive model is required." # noqa E501
40 | )
41 | )
42 |
43 | @staticmethod
44 | def _post_process_example(
45 | example: dict,
46 | instruction: str,
47 | task_id: int,
48 | has_encoder: bool,
49 | dataset_split: str,
50 | eos_token: str | None = None,
51 | ) -> dict:
52 | """Modifies the input column of a given example dictionary.
53 |
54 | Args:
55 | example: A dictionary representing an example.
56 | instruction: The instruction used as a prefix to explain the task.
57 | task_id: A tag marking which dataset (from dataset_dicts) this example
58 | comes from. Used for multi-task training.
59 | has_encoder: Whether the retrieved model has an encoder.
60 | dataset_split: The split of the example, i.e. train/val/test.
61 | eos_token: The end-of-sentence token of the tokenizer.
62 |
63 | Returns:
64 | A dictionary with `model_input` as the input to models
65 | and `model_output` as the expected output of models.
66 | """
67 | if dataset_split not in (
68 | "train",
69 | "val",
70 | "test",
71 | ):
72 | raise ValueError("Datset split must be in train/val/test.")
73 | example["output_col"] = str(example["output_col"])
74 | if has_encoder:
75 | model_input = f"{instruction}\nExample:\n{example['input_col']}\nLabel:\n" # noqa E501
76 | model_output = example["output_col"]
77 | else:
78 | # The T5 tokenizer automatically adds eos token in `add_eos_if_not_present`.
79 | # On the contrary, model_output of GPT model need eos token in the end.
80 | if dataset_split == "train":
81 | model_output = example["output_col"] + eos_token
82 | model_input = f"{instruction}\nExample:\n{example['input_col']}\nLabel:\n{model_output}" # noqa E501
83 | else:
84 | # The val/test split is only used for evaluation. Since our decode
85 | # method in the ModelExecutor set `skip_special_tokens=True`,
86 | # we do not need to add eos token in the end.
87 | model_output = example["output_col"]
88 | model_input = f"{instruction}\nExample:\n{example['input_col']}\nLabel:\n" # noqa E501
89 | example["model_input"] = model_input
90 | example["model_output"] = model_output
91 | return example
92 |
--------------------------------------------------------------------------------
/prompt2model/dataset_retriever/__init__.py:
--------------------------------------------------------------------------------
1 | """Import DatasetRetriever classes."""
2 | from prompt2model.dataset_retriever.base import DatasetRetriever
3 | from prompt2model.dataset_retriever.description_dataset_retriever import (
4 | DatasetInfo,
5 | DescriptionDatasetRetriever,
6 | )
7 | from prompt2model.dataset_retriever.mock import MockRetriever
8 |
9 | __all__ = (
10 | "DatasetRetriever",
11 | "MockRetriever",
12 | "DescriptionDatasetRetriever",
13 | "DatasetInfo",
14 | )
15 |
--------------------------------------------------------------------------------
/prompt2model/dataset_retriever/base.py:
--------------------------------------------------------------------------------
1 | """An interface for dataset retrieval."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | import dataclasses
6 | from abc import ABC, abstractmethod
7 |
8 | import datasets
9 |
10 | from prompt2model.prompt_parser import PromptSpec
11 |
12 |
13 | @dataclasses.dataclass
14 | class DatasetInfo:
15 | """Store the dataset name, description, and query-dataset score for each dataset.
16 |
17 | Args:
18 | name: The name of the dataset.
19 | description: The description of the dataset.
20 | score: The retrieval score of the dataset.
21 | """
22 |
23 | name: str
24 | description: str
25 | score: float
26 |
27 |
28 | # pylint: disable=too-few-public-methods
29 | class DatasetRetriever(ABC):
30 | """A class for retrieving datasets."""
31 |
32 | @abstractmethod
33 | def retrieve_dataset_dict(
34 | self, prompt_spec: PromptSpec
35 | ) -> datasets.DatasetDict | None:
36 | """Retrieve full dataset splits (e.g. train/dev/test) from a prompt.
37 |
38 | Args:
39 | prompt_spec: A prompt spec (containing a system description).
40 |
41 | Returns:
42 | A retrieved DatasetDict containing train/val/test splits.
43 | """
44 |
--------------------------------------------------------------------------------
/prompt2model/dataset_retriever/column_selection_prompt.py:
--------------------------------------------------------------------------------
1 | """Utilities to construct an LLM "metaprompt" for our column selection."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | METAPROMPT_BASE = """Your objective is to carefully analyze the task and the dataset mentioned, and decide whether the columns are relevant input, relevant output, irrelevant for the given task, or if it is ambiguous. There should be at most one output column. It is possible to have no relevant columns, in which case return the input and output column as empty lists. Answer in a json format, with the following keys: input, output, irrelevant, ambiguous""" # noqa: E501
6 | METAPROMPT_EXAMPLES = [
7 | (
8 | """You are tasked with the following process. In this task, you will generate summaries for given texts. For this task, you will use the Scientific Papers dataset from HuggingFace. Dataset_description: Scientific papers datasets contains two sets of long and structured documents. The datasets are obtained from ArXiv and PubMed OpenAccess repositories.
9 | A sample data instance from this dataset is as follows.
10 | {
11 | "abstract": "\" we have studied the leptonic decay @xmath0 , via the decay channel @xmath1 , using a sample of tagged @xmath2 decays collected...",
12 | "article": "\"the leptonic decays of a charged pseudoscalar meson @xmath7 are processes of the type @xmath8 , where @xmath9 , @xmath10 , or @...",
13 | "section_names": "[sec:introduction]introduction\n[sec:detector]data and the cleo- detector\n[sec:analysys]analysis method\n[sec:conclusion]summary"
14 | }
15 | This dataset has the following columns: [abstract, article, section_names].
16 | """, # noqa: E501
17 | """
18 | {
19 | "input": ["article"],
20 | "output": ["abstract"],
21 | "irrelevant": ["section_names"],
22 | "ambiguous": []
23 | }""",
24 | ),
25 | (
26 | """
27 | You are tasked with the following process. In this task, you will detect whether some given text uses hateful speech or not. For this task you will use the hate_speech_offensive dataset from HuggingFace. Dataset_description: An annotated dataset for hate speech and offensive language detection on tweets.
28 | A sample data instance from this is as follows:
29 | {
30 | "count": 3,
31 | "hate_speech_count": 0,
32 | "offensive_language_count": 0,
33 | "neither_count": 3,
34 | "label": 2, # "neither"
35 | "tweet": "!!! RT @mayasolovely: As a woman you shouldn't complain about cleaning up your house. & as a man you should always take the trash out...")
36 | }.
37 | This dataset has the following columns: [count, hate_speech_count, offensive_language_count, neither_count, class, tweet]""", # noqa: E501
38 | """
39 | {
40 | "input": ["tweet"],
41 | "output": ["label"],
42 | "irrelevant": [],
43 | "ambiguous": ["hate_speech_count", "offensive_language_count", "neither_count", "count"]
44 | }""", # noqa: E501
45 | ),
46 | (
47 | """You are tasked with the following process. Your job is to be able to translate between languages. For this task, you will use a custom dataset. Dataset_description: This dataset is meant to translate between languages.
48 | A sample data instance from this is as follows:
49 | {
50 | "translation": ["ca: "El department de bombers té el seu propi equip d'investigació.", "en": "Well, the fire department has its own investigative unit."]
51 | }
52 | This dataset has the following columns: [translation]. """, # noqa: E501
53 | """
54 | {
55 | "input": [],
56 | "output": [],
57 | "irrelevant": []
58 | "ambiguous": ["translation"]
59 | }""",
60 | ),
61 | (
62 | """You are tasked with the following process. Your job is to be able to summarize a given text. For this task, you will use the math_qa dataset from HuggingFace. Dataset_description: Our dataset is gathered by using a new representation language to annotate over the AQuA-RAT dataset with fully-specified operational programs.
63 | A sample data instance from this is as follows:
64 | {
65 | "Problem": "a multiple choice test consists of 4 questions , and each question has 5 answer choices . in how many r ways can the test be completed if every question is unanswered ?",
66 | "Rationale": "\"5 choices for each of the 4 questions , thus total r of 5 * 5 * 5 * 5 = 5 ^ 4 = 625 ways to answer all of them . answer : c .\"",
67 | "annotated_formula": "power(5, 4)",
68 | "category": "general",
69 | "correct": "c",
70 | "linear_formula": "power(n1,n0)|",
71 | "options": "a ) 24 , b ) 120 , c ) 625 , d ) 720 , e ) 1024"
72 | }
73 | This dataset has the following columns: [problem, rationale, options, correct, annotated_formula]. """, # noqa: E501
74 | """
75 | {
76 | "input": [],
77 | "output": [],
78 | "irrelevant": ["problem", "rationale", "options", "correct", "annotated_formula"],
79 | "ambiguous": []
80 | }""", # noqa: E501
81 | ),
82 | ]
83 |
84 | INPUT_PROMPT_TEMPLATE = """You are tasked with the following process. {instruction} For this task, you will use the {dataset_name} dataset from HuggingFace. Dataset Description: {dataset_description} \nA sample data instance from this is as follows. {sample_row}.\nThis dataset has the following columns: [{dataset_columns} ].""" # noqa: E501
85 | SINGLE_DEMONSTRATION_TEMPLATE = (
86 | 'Task: """\n{prompt}\n"""\n\nRequired Columns :\n{columns}'
87 | )
88 | ENDING_LINE = "After seeing these examples with the required columns, please provide the relevant columns for this context:" # noqa: E501
89 |
90 |
91 | def build_input(
92 | instruction: str,
93 | dataset_name: str,
94 | dataset_description: str,
95 | dataset_columns: str,
96 | sample_row: dict,
97 | ) -> str:
98 | """Template function to build input based on arguments."""
99 | input_prompt = INPUT_PROMPT_TEMPLATE.format(
100 | instruction=instruction,
101 | dataset_name=dataset_name,
102 | dataset_description=dataset_description,
103 | dataset_columns=dataset_columns,
104 | sample_row=sample_row,
105 | )
106 | input_prompt = SINGLE_DEMONSTRATION_TEMPLATE.format(
107 | prompt=input_prompt, columns=""
108 | ) # columns="" because that is what we are trying to predict
109 | return input_prompt
110 |
111 |
112 | def construct_prompt_for_column_selection(
113 | instruction: str,
114 | dataset_name: str,
115 | dataset_description: str,
116 | dataset_columns: str,
117 | sample_row: dict,
118 | ) -> str:
119 | """Generate prompt for column selection."""
120 | prompt_sections = [METAPROMPT_BASE]
121 | for prompt, columns in METAPROMPT_EXAMPLES:
122 | prompt_sections.append(
123 | SINGLE_DEMONSTRATION_TEMPLATE.format(prompt=prompt, columns=columns)
124 | )
125 | all_prompts = "\n\n------\n\n".join(prompt_sections) + "\n\n------\n\n"
126 | input_prompt = build_input(
127 | instruction, dataset_name, dataset_description, dataset_columns, sample_row
128 | )
129 | all_prompts += ENDING_LINE + input_prompt
130 |
131 | return all_prompts
132 |
--------------------------------------------------------------------------------
/prompt2model/dataset_retriever/mock.py:
--------------------------------------------------------------------------------
1 | """A mock dataset retriever for testing purposes."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | import datasets
6 |
7 | from prompt2model.dataset_retriever.base import DatasetRetriever
8 | from prompt2model.prompt_parser import PromptSpec
9 |
10 |
11 | class MockRetriever(DatasetRetriever):
12 | """A class for retrieving datasets."""
13 |
14 | def __init__(self):
15 | """Construct a mock dataset retriever."""
16 |
17 | def retrieve_dataset_dict(
18 | self, prompt_spec: PromptSpec
19 | ) -> datasets.DatasetDict | None:
20 | """Return a single empty DatasetDict for testing purposes."""
21 | _ = prompt_spec # suppress unused variable warning
22 | mock_dataset = datasets.Dataset.from_dict(
23 | {"input_col": [""], "output_col": [""]}
24 | )
25 | return datasets.DatasetDict(
26 | {"train": mock_dataset,
27 | "val": mock_dataset,
28 | "test": mock_dataset}
29 | )
30 |
--------------------------------------------------------------------------------
/prompt2model/dataset_retriever/readme.md:
--------------------------------------------------------------------------------
1 | # Dataset Retriever
2 |
3 | ## Overview
4 |
5 | - `DatasetRetriever`: Interface for retrieving datasets based on a
6 | prompt.
7 | - `DescriptionDatasetRetriever`: Retrieves HuggingFace datasets using
8 | similarity to a given prompt.
9 |
10 | ## Getting Started
11 |
12 | - Import Modules
13 |
14 | ```python
15 | from prompt2model.dataset_retriever import DescriptionDatasetRetriever
16 | from prompt2model.prompt_parser import MockPromptSpec, TaskType
17 | ```
18 |
19 | - Initialize Retriever
20 |
21 | ```python
22 | retriever = DescriptionDatasetRetriever()
23 | ```
24 |
25 | Various parameters like search index path, model name, and search
26 | depth can be customized during initialization.
27 |
28 | - Prepare the Prompt
29 |
30 | ```python
31 | task_type = TaskType.TEXT_GENERATION
32 | prompt_text = "..."
33 | prompt_spec = MockPromptSpec(task_type)
34 | prompt_spec._instruction = prompt_text
35 | ```
36 |
37 | - Retrieve Dataset
38 |
39 | ```python
40 | dataset_dict = retriever.retrieve_dataset_dict(
41 | prompt_spec, blocklist=[]
42 | )
43 | ```
44 |
45 | `dataset_dict` will contain the dataset splits (train/val/test) most
46 | relevant to the given prompt.
47 |
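48 | The retriever can also transform the retrieved rows into the prompt's
49 | input/output format, mirroring `examples/create_transform_data_example.py`
50 | (the sample size and save path below are illustrative):
51 |
52 | ```python
53 | # Both arguments shown here are illustrative values.
54 | retriever = DescriptionDatasetRetriever(
55 |     auto_transform_data=True,
56 |     total_num_points_to_transform=20,
57 | )
58 | dataset_dict = retriever.retrieve_dataset_dict(prompt_spec)
59 | if dataset_dict is not None:
60 |     dataset_dict.save_to_disk("retrieved_dataset_dict")
61 | ```
62 |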
--------------------------------------------------------------------------------
/prompt2model/dataset_retriever/reranking_prompt.py:
--------------------------------------------------------------------------------
1 | """This module contains the functions to generate the prompt for dataset reranking."""
2 | from __future__ import annotations # noqa FI58
3 |
4 | METAPROMPT_BASE_DATASET = """Your objective is to choose the most relevant dataset for a given task (and a few examples of the task). For each dataset, you will be provided with the dataset description and tags related to the dataset, which provide meta-information about the dataset. Please return the most relevant dataset, e.g. squad """  # noqa: E501
5 |
6 | METAPROMPT_BASE_CONFIG = """Your objective is to choose the most relevant config of a dataset for a given task (and a few examples of the task). A config of a dataset is a version of that dataset. You will be provided information about this dataset, followed by information about its configs. For each config, you will be provided with the config name, and the columns and rows of that config. The columns of the config can help in understanding whether this config is relevant to the given task. Another relevant factor is the config_name, which gives high-level information about what each config represents. Please return the most relevant config"""  # noqa: E501
7 |
8 | INPUT_PROMPT_DATASET_TEMPLATE = """The following is the task \n {instruction} \n and these are some examples of the same: \n{examples} \n
9 | There are {num} datasets available for this task. \n
10 | {datasets}.
11 | The name of the most relevant dataset for this task is:""" # noqa: E501
12 |
13 | INPUT_PROMPT_CONFIG_TEMPLATE = """ The following is the task: \n {instruction} \n
14 | The following is the dataset selected: {dataset_name}: {dataset_description}
15 | There are {num} configs available in this dataset for this task. \n
16 | {configs}
17 | The name of the most relevant config from these for this task is:
18 | """
19 |
20 | DATASET_TEMPLATE = """[{counter}] **{dataset_name}**:\nDescription-{dataset_description}.\nThis dataset has the following tags:\n {tags} """ # noqa: E501
21 | CONFIG_TEMPLATE = """\t[{counter}] **{config_name}**\n: The columns in this config are {dataset_columns}.\n An example row from this config is {sample_row}.\n """ # noqa: E501
22 |
23 |
24 | def build_datasets_prompt(instruction: str, examples: str, datasets_infos: dict):
25 | """Builds the prompt for dataset reranking.
26 |
27 | Args:
28 | instruction (str): Task instructions
29 | examples (str): Task Examples
30 | datasets_infos (dict): A dictionary containing information about all datasets.
31 |
32 | Returns:
33 |         str: The input prompt for dataset reranking.
34 | """
35 | dataset_string = ""
36 | for i, (dataset_name, dataset_info) in enumerate(datasets_infos.items(), start=1):
37 | dataset_string += f"""{DATASET_TEMPLATE.format(
38 | counter = i,
39 | dataset_name=dataset_name,
40 | dataset_description=dataset_info["description"],
41 | tags = dataset_info["tags"]
42 | )}\n\n"""
43 |
44 | input_prompt = INPUT_PROMPT_DATASET_TEMPLATE.format(
45 | instruction=instruction,
46 | examples=examples,
47 | datasets=dataset_string,
48 | num=len(datasets_infos),
49 | )
50 | return input_prompt
51 |
52 |
53 | def build_configs_prompt(instruction: str, examples: str, dataset_info: dict):
54 | """Builds the prompt for config reranking.
55 |
56 | Args:
57 | instruction (str): Task instructions
58 | examples (str): Task Examples
59 |         dataset_info (dict): A dictionary containing information about
60 |             the specific dataset, which includes config information.
61 |
62 |     Returns:
63 |         str: The input prompt for config reranking.
64 | """
65 | configs_string = ""
66 | for j, (config_name, config_info) in enumerate(dataset_info["configs"].items()):
67 | configs_string += f"""{CONFIG_TEMPLATE.format(
68 | counter = chr(ord('a')+j),
69 | config_name = config_name,
70 | dataset_columns = config_info["columns"],
71 | sample_row = config_info["sample_row"]
72 | )}\n""" # noqa: E501
73 |
74 | input_prompt = INPUT_PROMPT_CONFIG_TEMPLATE.format(
75 | instruction=instruction,
76 | examples=examples,
77 | dataset_name=dataset_info["dataset_name"],
78 | dataset_description=dataset_info["description"],
79 | configs=configs_string,
80 | num=len(dataset_info["configs"]),
81 | )
82 | return input_prompt
83 |
84 |
85 | def construct_prompt_for_dataset_reranking(
86 | instruction: str,
87 | examples: str,
88 | datasets_infos: dict,
89 | is_config: bool = False,
90 | ):
91 | """Generate the full prompt for dataset reranking based on the given parameters.
92 |
93 | Args:
94 | instruction (str): Instruction of the task.
95 | examples (str): Examples of the task.
96 | datasets_infos (dict): Dictionary with dataset/config information. Each
97 | dataset_info object also has a configs object
98 | representing the various configs of that dataset
99 |         is_config (bool): Whether the prompt is for dataset
100 |             reranking or config reranking.
101 |
102 |     Returns:
103 |         str: A comprehensive prompt for dataset (or config) reranking, consisting
104 |              of the base instructions followed by the input prompt returned by
105 |              build_datasets_prompt or build_configs_prompt.
106 | """
107 | if is_config:
108 | metaprompt_base = METAPROMPT_BASE_CONFIG
109 | input_prompt = build_configs_prompt(instruction, examples, datasets_infos)
110 | else:
111 | metaprompt_base = METAPROMPT_BASE_DATASET
112 | input_prompt = build_datasets_prompt(instruction, examples, datasets_infos)
113 |
114 | prompt_sections = [metaprompt_base]
115 | all_prompts = "\n\n------\n\n".join(prompt_sections) + "\n\n------\n\n"
116 | all_prompts += input_prompt
117 |
118 | return all_prompts
119 |
--------------------------------------------------------------------------------
/prompt2model/dataset_retriever/run_dataset_retriever.py:
--------------------------------------------------------------------------------
1 | """Script to run the dataset retriever in isolation."""
2 | from prompt2model.dataset_retriever import DescriptionDatasetRetriever
3 | from prompt2model.prompt_parser import MockPromptSpec, TaskType
4 |
5 | if __name__ == "__main__":
6 | prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION)
7 | prompt = """Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.""" # noqa E501
8 | prompt_spec._instruction = prompt
9 |
10 | retriever = DescriptionDatasetRetriever()
11 | retriever.retrieve_dataset_dict(prompt_spec)
12 |
--------------------------------------------------------------------------------
/prompt2model/dataset_retriever/task_expansion_prompt.py:
--------------------------------------------------------------------------------
1 | """This module contains the functions to construct the prompt for task expansion."""
2 | METAPROMPT_BASE = "Carefully analyse the task description and examples of the task, and explain the task to give a clearer description. Do not explain each example, but rather capture the general trends. Also place special focus on the format of the input/output examples." # noqa: E501
3 |
4 | TASK = """
5 | Task Description: {task_description}
6 |
7 | Task Examples: {examples}
8 | """
9 |
10 |
11 | def construct_prompt_for_task_explanation(instruction: str, demonstrations: str):
12 | """Constructs prompt for task explanation.
13 |
14 | This is useful for clarifying the requirements of a task,
15 | and providing a clearer description of the task.
16 |
17 | Args:
18 | instruction (str): The task instruction.
19 | demonstrations (str): The task demonstrations.
20 |
21 | Returns:
22 | str: The constructed prompt.
23 | """
24 | task = TASK.format(task_description=instruction, examples=demonstrations)
25 | prompt = "\n--------\n".join([METAPROMPT_BASE, task])
26 | return prompt
27 |
--------------------------------------------------------------------------------
/prompt2model/dataset_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | """Import DatasetGenerator classes."""
2 | from prompt2model.dataset_transformer.base import DatasetTransformer
3 | from prompt2model.dataset_transformer.prompt_based import PromptBasedDatasetTransformer
4 |
5 | __all__ = (
6 | "PromptBasedDatasetTransformer",
7 | "DatasetTransformer",
8 | )
9 |
--------------------------------------------------------------------------------
/prompt2model/dataset_transformer/base.py:
--------------------------------------------------------------------------------
1 | """An interface for dataset transformation."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | from abc import ABC, abstractmethod
6 |
7 | import datasets
8 |
9 | from prompt2model.prompt_parser import PromptSpec
10 |
11 |
12 | class DatasetTransformer(ABC):
13 | """A class for transforming a given dataset to a desired format."""
14 |
15 | @abstractmethod
16 |     def transform_data(
17 |         self,
18 |         prompt_spec: PromptSpec,
19 |         dataset: datasets.Dataset,
20 |         num_points_to_transform: int,
21 |     ) -> datasets.Dataset:
21 | """Transform a split of data.
22 |
23 | Args:
24 | prompt_spec: A prompt spec (containing a system description).
25 | dataset: A dataset split.
26 |             num_points_to_transform: Number of data points you wish to
27 |                 transform. Must be greater than zero. If the number is greater
28 |                 than the size of the dataset, the whole dataset will be
29 |                 transformed. Ignored if data_transform is False.
30 |
31 | Returns:
32 | A single dataset split.
33 | """
34 |
--------------------------------------------------------------------------------
/prompt2model/demo_creator/__init__.py:
--------------------------------------------------------------------------------
1 | """Import DemoCreator functions."""
2 |
3 | from prompt2model.demo_creator.create import create_gradio
4 | from prompt2model.demo_creator.mock import mock_gradio_create
5 |
6 | __all__ = (
7 | "mock_gradio_create",
8 | "create_gradio",
9 | )
10 |
--------------------------------------------------------------------------------
/prompt2model/demo_creator/create.py:
--------------------------------------------------------------------------------
1 | """Create a Gradio interface automatically."""
2 |
3 | import gradio as gr
4 | import mdtex2html
5 |
6 | from prompt2model.dataset_processor import TextualizeProcessor
7 | from prompt2model.model_executor import GenerationModelExecutor
8 | from prompt2model.prompt_parser import PromptBasedInstructionParser
9 |
10 |
11 | def create_gradio(
12 | model_executor: GenerationModelExecutor, prompt_parser: PromptBasedInstructionParser
13 | ) -> gr.Blocks:
14 | """Create a Gradio interface automatically.
15 |
16 | Args:
17 | model_executor: A GenerationModelExecutor to expose via a Gradio interface.
18 | prompt_parser: An instance of PromptBasedInstructionParser to parse the prompt.
19 |
20 | Returns:
21 | A Gradio interface for interacting with the model.
22 |
23 | """
24 | description = prompt_parser.instruction
25 | examples = prompt_parser.examples
26 |
27 | def postprocess(self, y):
28 | if y is None:
29 | return []
30 | for i, (message, response) in enumerate(y):
31 | y[i] = (
32 |                 None if message is None else mdtex2html.convert(message),
33 | None if response is None else mdtex2html.convert(response),
34 | )
35 | return y
36 |
37 | gr.Chatbot.postprocess = postprocess
38 |
39 | def response(message: str):
40 | if not message.startswith(""):
41 | dataset_processor = TextualizeProcessor(has_encoder=True)
42 | message = dataset_processor.wrap_single_input(
43 | prompt_parser.instruction, message
44 | )
45 | response = model_executor.make_single_prediction(message)
46 | prediction = response.prediction
47 | return prediction
48 |
49 | def chat(message, history):
50 | history = history or []
51 | model_output = (
52 | response(message) if message != "" else "Please give valid input."
53 | )
54 | history.append((message, model_output))
55 | return history, history
56 |
57 | def reset_user_input():
58 | return gr.update(value="")
59 |
60 | def reset_state():
61 | return [], []
62 |
63 | with gr.Blocks() as demo:
64 |         gr.HTML("""Prompt2Model""")
65 |         gr.HTML(f"""Task Description: {description}""")
66 |         gr.HTML(f"""Few-shot Examples: {examples}""")
67 |
68 | chatbot = gr.Chatbot()
69 | with gr.Row():
70 | with gr.Column(scale=4):
71 | with gr.Column(scale=12):
72 | user_input = gr.Textbox(
73 | show_label=False, placeholder="Input...", lines=10
74 | ).style(container=False)
75 | with gr.Column(min_width=32, scale=2):
76 | submitBtn = gr.Button("Submit", variant="primary")
77 | emptyBtn = gr.Button("Clear History")
78 |
79 | history = gr.State([])
80 |
81 | submitBtn.click(
82 | chat, [user_input, history], [chatbot, history], show_progress=True
83 | )
84 | submitBtn.click(reset_user_input, [], [user_input])
85 | emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
86 |
87 | return demo
88 |
--------------------------------------------------------------------------------
/prompt2model/demo_creator/mock.py:
--------------------------------------------------------------------------------
1 | """An interface for creating Gradio demos automatically."""
2 |
3 | import gradio as gr
4 | import transformers
5 |
6 | from prompt2model.prompt_parser.base import PromptSpec
7 |
8 |
9 | def mock_gradio_create(
10 | model: transformers.PreTrainedModel, prompt_spec: PromptSpec
11 | ) -> gr.Interface:
12 | """Create a Gradio interface automatically.
13 |
14 | Args:
15 | model: A trained model to expose via a Gradio interface.
16 | prompt_spec: A PromptSpec to help choose the visual interface.
17 |
18 | Returns:
19 | A Gradio interface for interacting with the model.
20 |
21 | """
22 | _ = model, prompt_spec # suppress unused variable warnings
23 | dummy_interface = gr.Interface(lambda input: None, "textbox", "label")
24 | return dummy_interface
25 |
--------------------------------------------------------------------------------
/prompt2model/demo_creator/readme.md:
--------------------------------------------------------------------------------
1 | # Demo Creator
2 |
3 | ## Overview
4 |
5 | - **create_gradio**: A function to set up and return a Gradio
6 | interface for model interactions.
7 |
8 | ## Getting Started
9 |
10 | - **Import the Necessary Modules**:
11 |
12 | ```python
13 | from prompt2model.model_executor import GenerationModelExecutor
14 | from prompt2model.prompt_parser import PromptBasedInstructionParser
15 | from prompt2model.demo_creator import create_gradio
16 | ```
17 |
18 | - **Initialize Components**:
19 |
20 | ```python
21 | model_executor = GenerationModelExecutor(...)
22 | prompt_parser = PromptBasedInstructionParser(...)
23 | # Refer to the documentation of ModelExecutor and PromptParser for details.
24 | ```
25 |
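As a minimal sketch of wiring up the executor with a small seq2seq checkpoint
(the choice of `t5-small` is only an example; the parser construction is
omitted here because its arguments depend on your prompt):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model_executor = GenerationModelExecutor(model, tokenizer)
```
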
26 | - **Create and Run the Gradio Interface**:
27 |
28 | ```python
29 | interface = create_gradio(model_executor, prompt_parser)
30 | interface.launch(share=True)
31 | ```
32 |
--------------------------------------------------------------------------------
/prompt2model/model_evaluator/__init__.py:
--------------------------------------------------------------------------------
1 | """Import evaluator classes."""
2 | from prompt2model.model_evaluator.base import ModelEvaluator
3 | from prompt2model.model_evaluator.mock import MockEvaluator
4 | from prompt2model.model_evaluator.seq2seq import Seq2SeqEvaluator
5 |
6 | __all__ = ("MockEvaluator", "ModelEvaluator", "Seq2SeqEvaluator")
7 |
--------------------------------------------------------------------------------
/prompt2model/model_evaluator/base.py:
--------------------------------------------------------------------------------
1 | """An interface for automatic model evaluation."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | import json
6 | from abc import ABC, abstractmethod
7 | from typing import Any
8 |
9 | import datasets
10 | import evaluate
11 |
12 | from prompt2model.model_executor import ModelOutput
13 |
14 |
15 | class ModelEvaluator(ABC):
16 | """An interface for automatic model evaluation."""
17 |
18 | @abstractmethod
19 | def evaluate_model(
20 | self,
21 | dataset: datasets.Dataset,
22 | gt_column: str,
23 | predictions: list[ModelOutput],
24 | model_input_column: str | None = None,
25 | metrics: list[evaluate.Metric] | None = None,
26 | encoder_model_name: str = "xlm-roberta-base",
27 | ) -> dict[str, Any]:
28 |         """Evaluate a model on a test set.
29 |
30 | Args:
31 | dataset: The dataset to evaluate metrics on.
32 | gt_column: The dataset column to use as ground truth.
33 | predictions: Model outputs to evaluate.
34 |             model_input_column: (Optional) The dataset column containing the
35 |                 model input, used to strip the input from predictions if needed.
36 |             metrics: (Optional) The metrics to use, defaults to chr_f,
37 |                 exact_match, and bert_score.
38 |             encoder_model_name: The encoder model used to compute bert_score.
37 |
38 | Returns:
39 | A dictionary of metric values to return.
40 | """
41 |
42 | def write_metrics(self, metrics_dict: dict[str, Any], metrics_path: str) -> None:
43 | """This function writes metrics to a file.
44 |
45 | Args:
46 | metrics_dict: A dictionary of metrics to write.
47 | metrics_path: The file path to write metrics to.
48 |
49 | """
50 | with open(metrics_path, "w") as f:
51 | json.dump(metrics_dict, f)
52 |
--------------------------------------------------------------------------------
/prompt2model/model_evaluator/mock.py:
--------------------------------------------------------------------------------
1 | """A dummy evaluator for testing purposes."""
2 | from __future__ import annotations # noqa FI58
3 |
4 | from typing import Any
5 |
6 | import datasets
7 | import evaluate
8 |
9 | from prompt2model.model_evaluator.base import ModelEvaluator
10 | from prompt2model.model_executor import ModelOutput
11 |
12 |
13 | class MockEvaluator(ModelEvaluator):
14 | """A dummy evaluator that always returns the same metric value."""
15 |
16 | def __init__(self) -> None:
17 | """Initialize the evaluation setting."""
18 |
19 | def evaluate_model(
20 | self,
21 | dataset: datasets.Dataset,
22 | gt_column: str,
23 | predictions: list[ModelOutput],
24 | model_input_column: str | None = None,
25 | metrics: list[evaluate.Metric] | None = None,
26 | encoder_model_name: str = "xlm-roberta-base",
27 | ) -> dict[str, Any]:
28 | """Return empty metrics dictionary.
29 |
30 | Args:
31 | dataset: The dataset to evaluate metrics on.
32 | gt_column: The dataset column to use as ground truth.
33 | predictions: Corresponding model outputs to evaluate.
34 |             model_input_column: (Optional) The dataset column containing the model input.
35 |             metrics: (Optional) The metrics to use, defaults to
36 |                 chr_f, exact_match, and bert_score.
37 |             encoder_model_name: The encoder model used to compute bert_score.
37 |
38 | Returns:
39 | An empty dictionary (for testing purposes).
40 | """
41 | return {}
42 |
--------------------------------------------------------------------------------
/prompt2model/model_evaluator/readme.md:
--------------------------------------------------------------------------------
1 | # Model Evaluator
2 |
3 | ## Overview
4 |
5 | - **ModelEvaluator**: Interface for evaluating models' outputs.
6 | - **Seq2SeqEvaluator**: Offers metrics (`ChrF++`, `Exact Match`,
7 | `BERTScore`) for evaluating conditional generation models on specific
8 | datasets.
9 |
10 | ## Getting Started
11 |
12 | - **Import Required Modules**:
13 |
14 | ```python
15 | from prompt2model.model_evaluator import Seq2SeqEvaluator
16 | from prompt2model.model_executor import ModelOutput
17 | ```
18 |
19 | - **Instantiate Seq2SeqEvaluator**:
20 |
21 | ```python
22 | evaluator = Seq2SeqEvaluator()
23 | ```
24 |
25 | - **Prepare Dataset & Predictions**:
26 |
27 | 1. Autoregressive models might include the input in their outputs. For
28 |    such evaluations, refer to the docstring of
29 | `model_input_column` in `Seq2SeqEvaluator.evaluate_model`.
30 | 2. Default metrics: `ChrF++`, `Exact Match`, `BERTScore`.
31 |
32 | ```python
33 | PREDICTIONS = [...]
34 | VALIDATION_DATASET = ...
35 | # Dataset with ground truth column `gt_column` and optionally input column `model_input_column`.
36 | ```
37 |
38 | - **Evaluate**:
39 |
40 | ```python
41 | metric_values = evaluator.evaluate_model(
42 | dataset=VALIDATION_DATASET,
43 |     gt_column="model_output",
44 | predictions=PREDICTIONS,
45 | )
46 | ```
47 |
--------------------------------------------------------------------------------
/prompt2model/model_evaluator/seq2seq.py:
--------------------------------------------------------------------------------
1 | """An interface for automatically evaluate Seq2Seq generation model."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | from typing import Any
6 |
7 | import datasets
8 | import evaluate
9 | import numpy as np
10 |
11 | from prompt2model.model_evaluator.base import ModelEvaluator
12 | from prompt2model.model_executor import ModelOutput
13 | from prompt2model.utils import get_formatted_logger
14 |
15 | logger = get_formatted_logger("ModelEvaluator")
16 |
17 |
18 | class Seq2SeqEvaluator(ModelEvaluator):
19 |     """An evaluator computing `chr_f++`, `Exact Match` and `BERTScore`."""
20 |
21 | def evaluate_model(
22 | self,
23 | dataset: datasets.Dataset,
24 | gt_column: str,
25 | predictions: list[ModelOutput],
26 | model_input_column: str | None = None,
27 | metrics: list[evaluate.Metric] | None = None,
28 | encoder_model_name: str = "xlm-roberta-base",
29 | ) -> dict[str, Any]:
30 | """Evaluate a model on a test set.
31 |
32 | Args:
33 | dataset: The dataset to evaluate metrics on.
34 | gt_column: The dataset column to use as ground truth.
35 | predictions: Model outputs to evaluate.
36 |             model_input_column: (Optional) For autoregressive models,
37 |                 the prediction sometimes contains the model input.
38 |                 So we need to delete the model input if it's in the predictions.
39 |             metrics: (Optional) The metrics to use, defaults to
40 |                 chr_f, exact_match, and bert_score.
41 |             encoder_model_name: The encoder model used to compute bert_score.
41 |
42 | Returns:
43 | A dictionary of metric values to return.
44 | """
45 | if metrics is not None:
46 | metric_names = [each.name for each in metrics]
47 | metric_names = sorted(metric_names, key=lambda name: name.lower())
48 | if not (
49 | set(metric_names)
50 |                 <= {
51 | "chr_f",
52 | "exact_match",
53 | "bert_score",
54 | }
55 | ):
56 | raise ValueError(
57 | "Metrics must be within chr_f, exact_match, and bert_score."
58 | )
59 | logger.info(f"Using selected metrics: {', '.join(metric_names)}.")
60 | else:
61 | logger.info("Using default metrics of chr_f, exact_match and bert_score.")
62 | metrics = [
63 | evaluate.load("chrf"),
64 | evaluate.load("exact_match"),
65 | evaluate.load("bertscore"),
66 | ]
67 | # Get the ground truth from the dataset
68 | ground_truths = dataset[gt_column]
69 | # Extract the predicted strings from ModelOutput
70 | predicted_strings = [each.prediction for each in predictions]
71 | if len(ground_truths) != len(predicted_strings):
72 | raise ValueError(
73 | "The length of input dataset and predictions are not equal."
74 | )
75 | # Initialize the metric values dictionary
76 | metric_values = {}
77 |
78 | if model_input_column is not None:
79 |             # Some autoregressive models' outputs always contain
80 |             # the input. So we need to delete the model input if it's in the
81 |             # predictions when necessary.
82 | logger.info(
83 | "The model_input_column is not None. The model input will be detached from predictions if necessary." # noqa E501
84 | )
85 | model_inputs = dataset[model_input_column]
86 | for idx, model_input in enumerate(model_inputs):
87 | ground_truth = ground_truths[idx]
88 | predicted_string = predicted_strings[idx]
89 | if (model_input in predicted_string) and (
90 | model_input not in ground_truth
91 | ):
92 | predicted_string = predicted_string.replace(model_input, "")
93 | predicted_strings[idx] = predicted_string
94 |
95 | # Compute and store metric values
96 | for metric in metrics:
97 | metric_name = metric.name
98 | metric.add_batch(predictions=predicted_strings, references=ground_truths)
99 | if metric_name == "chr_f":
100 | metric_values["chr_f++"] = metric.compute(word_order=2)["score"]
101 | elif metric_name == "exact_match":
102 | metric_values[metric_name] = metric.compute()["exact_match"]
103 | elif metric_name == "bert_score":
104 | metric_values["average_bert_score"] = np.average(
105 | metric.compute(model_type=encoder_model_name)["f1"]
106 | )
107 |
108 | return metric_values
109 |
--------------------------------------------------------------------------------
/prompt2model/model_executor/__init__.py:
--------------------------------------------------------------------------------
1 | """Import all the model executor classes."""
2 |
3 | from prompt2model.model_executor.base import ModelExecutor, ModelOutput
4 | from prompt2model.model_executor.generate import GenerationModelExecutor
5 | from prompt2model.model_executor.mock import MockModelExecutor
6 |
7 | __all__ = (
8 | "ModelExecutor",
9 | "ModelOutput",
10 | "MockModelExecutor",
11 | "GenerationModelExecutor",
12 | )
13 |
--------------------------------------------------------------------------------
/prompt2model/model_executor/base.py:
--------------------------------------------------------------------------------
1 | """An interface for generating model outputs."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | from abc import ABC, abstractmethod
6 | from dataclasses import dataclass
7 | from typing import Any
8 |
9 | import datasets
10 | import transformers
11 |
12 | from prompt2model.utils import get_formatted_logger
13 |
14 | logger = get_formatted_logger("ModelExecutor")
15 |
16 |
17 | @dataclass(frozen=False)
18 | class ModelOutput:
19 | """A model output for a single example.
20 |
21 | Attributes:
22 | prediction: The prediction by the model.
23 | auxiliary_info: Any other auxiliary information provided by the model.
24 | """
25 |
26 | prediction: Any
27 | auxiliary_info: dict[str, Any]
28 |
29 |
30 | class ModelExecutor(ABC):
31 |     """An interface for executing models to generate outputs."""
32 |
33 | def __init__(
34 | self,
35 | model: transformers.PreTrainedModel,
36 | tokenizer: transformers.PreTrainedTokenizer,
37 | batch_size: int = 10,
38 | tokenizer_max_length: int = 256,
39 | sequence_max_length: int = 512,
40 | ) -> None:
41 | """Initializes a new instance of ModelExecutor.
42 |
43 | Args:
44 | model: The model to evaluate.
45 | tokenizer: The model's associated tokenizer.
46 | batch_size: The batch size to use when making predictions.
47 | tokenizer_max_length: The maximum number of tokens that
48 | tokenizer is allowed to generate.
49 | sequence_max_length: The maximum number of tokens in
50 | the input and output.
51 | """
52 | self.model = model
53 | self.tokenizer = tokenizer
54 | self.batch_size = batch_size
55 | if self.tokenizer.pad_token is None:
56 | logger.warning(
57 |                 "Trying to initialize a ModelExecutor's tokenizer without a pad_token."
58 | )
59 | self.tokenizer.pad_token = self.tokenizer.eos_token
60 | self.model.config.pad_token_id = self.model.config.eos_token_id
61 | self.tokenizer_max_length = tokenizer_max_length
62 | self.sequence_max_length = sequence_max_length
63 | if self.sequence_max_length is None:
64 | max_length = self.model.config.max_length
65 | logger.warning(
66 | (
67 | "The `max_length` in `self.model.generate` will default to "
68 | f"`self.model.config.max_length` ({max_length})"
69 | " if `sequence_max_length` is `None`."
70 | )
71 | )
72 | self.sequence_max_length = max_length
73 | if hasattr(self.model.config, "max_position_embeddings"):
74 | max_embeddings = self.model.config.max_position_embeddings
75 | if sequence_max_length is not None and sequence_max_length > max_embeddings:
76 | logger.warning(
77 | (
78 | f"The sequence_max_length ({sequence_max_length})"
79 | f" is larger than the max_position_embeddings ({max_embeddings})." # noqa: E501
80 | f" So the sequence_max_length will be set to {max_embeddings}." # noqa: E501
81 | )
82 | )
83 | self.sequence_max_length = max_embeddings
84 |
85 | @abstractmethod
86 | def make_prediction(
87 | self,
88 | test_set: datasets.Dataset,
89 | input_column: str,
90 | ) -> list[ModelOutput]:
91 | """Evaluate a model on a test set.
92 |
93 | Args:
94 | test_set: The dataset to make predictions on.
95 | input_column: The dataset column to use as input to the model.
96 |
97 | Returns:
98 | A list of model outputs, one for each element in the test set.
99 | """
100 |
101 | @abstractmethod
102 | def make_single_prediction(self, model_input: str) -> ModelOutput:
103 | """Make prediction on one example.
104 |
105 | Args:
106 | model_input: The input string to the model.
107 |
108 | Returns:
109 | A single model output, useful for exposing a model to a user interface.
110 | """
111 |
--------------------------------------------------------------------------------
/prompt2model/model_executor/generate.py:
--------------------------------------------------------------------------------
1 | """Model executor for generative models, including T5-type and GPT-type."""
2 | from __future__ import annotations # noqa FI58
3 |
4 | from typing import Any
5 |
6 | import datasets
7 | import torch
8 |
9 | from prompt2model.model_executor import ModelExecutor, ModelOutput
10 | from prompt2model.utils import get_formatted_logger
11 |
12 | logger = get_formatted_logger("ModelExecutor")
13 |
14 |
15 | class GenerationModelExecutor(ModelExecutor):
16 | """Model executor for T5-type and GPT-type models."""
17 |
18 | def generate(
19 | self,
20 | input_ids: list[torch.Tensor],
21 | attention_mask: list[torch.Tensor],
22 | hyperparameter_choices: dict[str, Any],
23 | ) -> list[torch.Tensor]:
24 | """Generates sequences of token IDs using the model.
25 |
26 | Args:
27 | input_ids: A list of token ID sequences.
28 | attention_mask: A list of binary masks indicating attended tokens.
29 | hyperparameter_choices: A dictionary of hyperparameters for inference.
30 |
31 | Returns:
32 | A list of model output tensors, one for each element in input_ids.
33 | """
34 | generate_strategy = hyperparameter_choices.get("generate_strategy", "greedy")
35 | if generate_strategy not in [
36 | "beam", # beam search.
37 | "top_k", # top_k sampling.
38 | "top_p", # top_p sampling.
39 | "greedy", # greedy search.
40 | "intersect", # If both top_k and top_p are set, the model will
41 | # sample from the intersection of the top-k tokens and the top-p tokens.
42 | ]:
43 | raise ValueError(
44 | "Only top_k/top_p/intersect sampling and beam/greedy "
45 | "search are supported for inference."
46 | )
47 | if generate_strategy == "greedy":
48 | output = self.model.generate(
49 | input_ids=input_ids,
50 | attention_mask=attention_mask,
51 | max_length=self.sequence_max_length,
52 | eos_token_id=self.model.config.eos_token_id,
53 | early_stopping=True,
54 | repetition_penalty=hyperparameter_choices.get(
55 | "repetition_penalty", 2.0
56 | ),
57 | )
58 | elif generate_strategy == "beam":
59 | output = self.model.generate(
60 | input_ids=input_ids,
61 | attention_mask=attention_mask,
62 | max_length=self.sequence_max_length,
63 | eos_token_id=self.model.config.eos_token_id,
64 | early_stopping=True,
65 | do_sample=False,
66 | repetition_penalty=hyperparameter_choices.get(
67 | "repetition_penalty", 2.0
68 | ),
69 | num_beams=hyperparameter_choices.get("num_beams", 3),
70 | )
71 | elif generate_strategy == "top_k":
72 | output = self.model.generate(
73 | input_ids=input_ids,
74 | attention_mask=attention_mask,
75 | max_length=self.sequence_max_length,
76 | eos_token_id=self.model.config.eos_token_id,
77 | early_stopping=True,
78 | do_sample=True,
79 | repetition_penalty=hyperparameter_choices.get(
80 | "repetition_penalty", 2.0
81 | ),
82 | top_k=hyperparameter_choices.get("top_k", 20),
83 | )
84 | elif generate_strategy == "top_p":
85 | output = self.model.generate(
86 | input_ids=input_ids,
87 | attention_mask=attention_mask,
88 | max_length=self.sequence_max_length,
89 | eos_token_id=self.model.config.eos_token_id,
90 | early_stopping=True,
91 | do_sample=True,
92 | repetition_penalty=hyperparameter_choices.get(
93 | "repetition_penalty", 2.0
94 | ),
95 | top_p=hyperparameter_choices.get("top_p", 0.95),
96 | )
97 | else:
98 | # For intersect sampling.
99 | output = self.model.generate(
100 | input_ids=input_ids,
101 | attention_mask=attention_mask,
102 | max_length=self.sequence_max_length,
103 | eos_token_id=self.model.config.eos_token_id,
104 | early_stopping=True,
105 | do_sample=True,
106 | repetition_penalty=hyperparameter_choices.get(
107 | "repetition_penalty", 2.0
108 | ),
109 | top_k=hyperparameter_choices.get("top_k", 20),
110 | top_p=hyperparameter_choices.get("top_p", 0.95),
111 | )
112 | return output
113 |
114 | def make_prediction(
115 | self,
116 | test_set: datasets.Dataset,
117 | input_column: str,
118 | hyperparameter_choices: dict[str, Any] = {},
119 | ) -> list[ModelOutput]:
120 | """Make predictions with a T5-type or GPT-type model on a test set.
121 |
122 | Args:
123 | test_set: The dataset to make predictions on. Note that
124 | make_single_prediction will warp single_model_input
125 | into a inference_dataset with only one element.
126 | input_column: The dataset column to use as input to the model.
127 | hyperparameter_choices: A dictionary of hyperparameter for generate.
128 |
129 | Returns:
130 | A list of model outputs, one for each element in the test set.
131 | """
132 | expected_num_examples = len(test_set)
133 | model_outputs = []
134 | longest_input = max(test_set[input_column], key=len)
135 | if (
136 | self.tokenizer_max_length is not None
137 | and len(self.tokenizer.tokenize(longest_input)) > self.tokenizer_max_length
138 | ):
139 | logger.warning(
140 | (
141 | "Truncation happened when tokenizing dataset / input string."
142 | " You should consider increasing the tokenizer_max_length."
143 | " Otherwise the truncation may lead to unexpected results."
144 | )
145 | )
146 |
147 | for start_idx in range(0, expected_num_examples, self.batch_size):
148 | end_idx = min(start_idx + self.batch_size, expected_num_examples)
149 | batch = datasets.Dataset.from_dict(test_set[start_idx:end_idx])
150 |
151 | input_texts = batch[input_column]
152 | encoded_inputs = self.tokenizer.batch_encode_plus(
153 | input_texts,
154 | truncation=True,
155 | max_length=self.tokenizer_max_length,
156 | padding=True,
157 | return_tensors="pt",
158 | )
159 | device = self.model.device
160 | input_ids = encoded_inputs["input_ids"].to(device)
161 | attention_mask = encoded_inputs["attention_mask"].to(device)
162 | output = self.generate(
163 | input_ids=input_ids,
164 | attention_mask=attention_mask,
165 | hyperparameter_choices=hyperparameter_choices,
166 | )
167 |
168 | for idx, input_text in enumerate(input_texts):
169 | logits = output[idx]
170 | decoded_output = self.tokenizer.decode(logits, skip_special_tokens=True)
171 | model_output = ModelOutput(
172 | prediction=decoded_output,
173 | auxiliary_info={
174 | "input_text": input_text,
175 | "logits": logits,
176 | },
177 | )
178 | model_outputs.append(model_output)
179 |
180 | return model_outputs
181 |
182 | def make_single_prediction(
183 | self, model_input: str, hyperparameter_choices: dict[str, Any] = {}
184 | ) -> ModelOutput:
185 |         """Make a prediction on a single example.
186 |
187 | Args:
188 | model_input: The input string to the model.
189 |             hyperparameter_choices: A dictionary of hyperparameters for inference.
190 |
191 | Returns:
192 | A single model output, useful for exposing a model to a user interface.
193 | """
194 | expected_num_examples = 1
195 | inference_dataset = datasets.Dataset.from_dict({"model_input": [model_input]})
196 | inference_column = "model_input"
197 | if len(inference_dataset) != expected_num_examples:
198 | raise ValueError(
199 | f"Expected {expected_num_examples} examples, "
200 | f"but got {len(inference_dataset)}."
201 | )
202 | model_output = self.make_prediction(
203 | inference_dataset,
204 | inference_column,
205 | hyperparameter_choices,
206 | )[0]
207 | return model_output
208 |
--------------------------------------------------------------------------------
/prompt2model/model_executor/mock.py:
--------------------------------------------------------------------------------
1 | """A dummy class to generate model outputs (for testing purposes)."""
2 | from __future__ import annotations
3 |
4 | import datasets
5 |
6 | from prompt2model.model_executor import ModelExecutor, ModelOutput
7 |
8 |
9 | class MockModelExecutor(ModelExecutor):
10 |     """A mock model executor that returns empty predictions (for testing purposes)."""
11 |
12 | def make_prediction(
13 | self,
14 | test_set: datasets.Dataset,
15 | input_column: str,
16 | ) -> list[ModelOutput]:
17 | """Mock the execution of a model on a test set.
18 |
19 | Args:
20 | test_set: The dataset to make predictions on.
21 | input_column: The dataset column to use as input to the model.
22 |
23 | Returns:
24 | An object containing model outputs.
25 | """
26 | predictions = []
27 | for _ in test_set[input_column]:
28 | model_output = ModelOutput(prediction="", auxiliary_info={})
29 | predictions.append(model_output)
30 | return predictions
31 |
32 | def make_single_prediction(self, model_input: str) -> ModelOutput:
33 | """Mock evaluation on one example.
34 |
35 | Args:
36 | model_input: The input string to the model.
37 |
38 | Returns:
39 | A single model output, useful for exposing a model to a user interface.
40 | """
41 | _ = model_input
42 | model_output = ModelOutput(prediction="", auxiliary_info={})
43 | return model_output
44 |
--------------------------------------------------------------------------------
/prompt2model/model_executor/readme.md:
--------------------------------------------------------------------------------
1 | # Model Executor
2 |
3 | ## Overview
4 |
5 | - **ModelExecutor**: An interface for executing predictions across
6 | various models.
7 | - **GenerationModelExecutor**: Tailored for generative models such as
8 | T5 (encoder-decoder) and GPT (autoregressive). It supports multiple
9 | generation strategies.
10 | - **ModelOutput**: Denotes the output from the model for a given
11 | example, consolidating the prediction and any auxiliary information.
12 |
13 | ## Getting Started
14 |
15 | - **Import the Necessary Modules**:
16 |
17 | ```python
18 | from prompt2model.model_executor import GenerationModelExecutor, ModelOutput
19 | ```
20 |
21 | - **Prepare the Input Data**:
22 |
23 | ```python
24 | input_dataset = ... # A dataset containing input examples.
25 | input_example = ... # A singular input in string format.
26 | ```
27 |
28 | - **Initialize the ModelExecutor**:
29 |
30 | ```python
31 | model = ... # A HuggingFace model instance.
32 | tokenizer = ... # A corresponding HuggingFace tokenizer.
33 | model_executor = GenerationModelExecutor(model, tokenizer)
34 | ```
35 |
36 | - **Generate Predictions**:
37 |
38 | For multiple inputs:
39 |
40 | ```python
41 | outputs = model_executor.make_prediction(
42 | test_set=input_dataset, # A dataset object.
43 | input_column="..."
44 | # The input column is the name of the column containing the input in the input_dataset.
45 | )
46 | ```
47 |
48 | For a single input:
49 |
50 | ```python
51 | test_input = "..."
52 | output = model_executor.make_single_prediction(test_input)
53 | ```
54 |
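Each call returns a `ModelOutput` (see `model_executor/base.py`), which bundles
the decoded `prediction` with `auxiliary_info` such as the input text and
logits:

```python
print(output.prediction)
print(output.auxiliary_info.keys())
```
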
55 | - **Choose a Generation Strategy**:
56 |
57 | Specify the desired decoding strategy. For more details, see the
58 | docstring of `GenerationModelExecutor.generate`.
59 |
60 | ```python
61 | hyperparameters = {"generate_strategy": "beam", "num_beams": 4}
62 | model_output = model_executor.make_single_prediction(test_input, hyperparameters)
63 | ```
64 |
--------------------------------------------------------------------------------
/prompt2model/model_retriever/__init__.py:
--------------------------------------------------------------------------------
1 | """Import all the model retriever classes."""
2 |
3 | from prompt2model.model_retriever.base import ModelRetriever
4 | from prompt2model.model_retriever.description_based_retriever import (
5 | DescriptionModelRetriever,
6 | )
7 | from prompt2model.model_retriever.mock import MockModelRetriever
8 |
9 | __all__ = ("ModelRetriever", "DescriptionModelRetriever", "MockModelRetriever")
10 |
--------------------------------------------------------------------------------
/prompt2model/model_retriever/base.py:
--------------------------------------------------------------------------------
1 | """An interface for model selection."""
2 | from __future__ import annotations
3 |
4 | from abc import ABC, abstractmethod
5 |
6 | from prompt2model.prompt_parser import PromptSpec
7 |
8 |
9 | # pylint: disable=too-few-public-methods
10 | class ModelRetriever(ABC):
11 | """Retrieve several models from HuggingFace."""
12 |
13 | @abstractmethod
14 | def retrieve(
15 | self,
16 | prompt: PromptSpec,
17 | ) -> list[str]:
18 | """Retrieve relevant models from HuggingFace.
19 |
20 | Args:
21 | prompt: A prompt to use to select relevant models.
22 |
23 | Return:
24 | A list of relevant models' HuggingFace names.
25 | """
26 |
--------------------------------------------------------------------------------
/prompt2model/model_retriever/generate_hypothetical_document.py:
--------------------------------------------------------------------------------
1 | """Tools for generating hypothetical documents from prompts."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | import logging
6 |
7 | from prompt2model.prompt_parser import PromptSpec
8 | from prompt2model.utils import API_ERRORS, api_tools, handle_api_error
9 |
10 | PROMPT_PREFIX = """HuggingFace contains models, which are each given a user-generated description. The first section of the description, delimited with two "---" lines, consists of a YAML description of the model. This may contain fields like "language" (supported by model), "datasets" (used to train the model), "tags" (e.g. tasks relevant to the model), and "metrics" (used to evaluate the model). Create a hypothetical HuggingFace model description that would satisfy a given user instruction. Here are some examples:
11 |
12 | Instruction: "Give me some translation from English to Vietnamese. Input English and output Vietnamese."
13 | Hypothetical model description:
14 | ---
15 | language:
16 | - en
17 | - vi
18 |
19 | tags:
20 | - translation
21 |
22 | license: apache-2.0
23 | ---
24 |
25 | ### eng-vie
26 |
27 | * source group: English
28 | * target group: Vietnamese
29 | * OPUS readme: [eng-vie](https://github.com/Helsinki-NLP/Tatoeba-Challenge/tree/master/models/eng-vie/README.md)
30 |
31 | * model: transformer-align
32 | * source language(s): eng
33 | * target language(s): vie vie_Hani
34 | * model: transformer-align
35 | * pre-processing: normalization + SentencePiece (spm32k,spm32k)
36 | * a sentence initial language token is required in the form of `>>id<<` (id = valid target language ID)
37 | * download original weights: [opus-2020-06-17.zip](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-vie/opus-2020-06-17.zip)
38 | * test set translations: [opus-2020-06-17.test.txt](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-vie/opus-2020-06-17.test.txt)
39 | * test set scores: [opus-2020-06-17.eval.txt](https://object.pouta.csc.fi/Tatoeba-MT-models/eng-vie/opus-2020-06-17.eval.txt)
40 |
41 | ## Benchmarks
42 |
43 | | testset | BLEU | chr-F |
44 | |-----------------------|-------|-------|
45 | | Tatoeba-test.eng.vie | 37.2 | 0.542 |
46 |
47 |
48 | ### System Info:
49 | - hf_name: eng-vie
50 |
51 | - source_languages: eng
52 |
53 | - target_languages: vie
54 |
55 | - opus_readme_url: https://github.com/Helsinki-NLP/Tatoeba-Challenge/tree/master/models/eng-vie/README.md
56 |
57 | - original_repo: Tatoeba-Challenge
58 |
59 | - tags: ['translation']
60 |
61 | - languages: ['en', 'vi']
62 |
63 | - src_constituents: {'eng'}
64 |
65 | - tgt_constituents: {'vie', 'vie_Hani'}
66 |
67 | - src_multilingual: False
68 |
69 | - tgt_multilingual: False
70 |
71 | - prepro: normalization + SentencePiece (spm32k,spm32k)
72 |
73 | - src_alpha3: eng
74 |
75 | - tgt_alpha3: vie
76 |
77 | - short_pair: en-vi
78 |
79 | - chrF2_score: 0.542
80 |
81 | - bleu: 37.2
82 |
83 | - brevity_penalty: 0.973
84 |
85 | - ref_len: 24427.0
86 |
87 | - src_name: English
88 |
89 | - tgt_name: Vietnamese
90 |
91 | - train_date: 2020-06-17
92 |
93 | - src_alpha2: en
94 |
95 | - tgt_alpha2: vi
96 |
97 | - prefer_old: False
98 |
99 | - long_pair: eng-vie
100 |
101 |
102 | Instruction: "I want to summarize things like news articles."
103 | Hypothetical model description:
104 | ---
105 | language: en
106 | license: apache-2.0
107 | tags:
108 | - pegasus
109 | - seq2seq
110 | - summarization
111 | model-index:
112 | - name: tuner007/pegasus_summarizer
113 | results:
114 | - task:
115 | type: summarization
116 | name: Summarization
117 | dataset:
118 | name: cnn_dailymail
119 | type: cnn_dailymail
120 | config: 3.0.0
121 | split: train
122 | metrics:
123 | - name: ROUGE-1
124 | type: rouge
125 | value: 36.604
126 | verified: true
127 | - name: ROUGE-2
128 | type: rouge
129 | value: 14.6398
130 | verified: true
131 | - name: ROUGE-L
132 | type: rouge
133 | value: 23.8845
134 | verified: true
135 | ---
136 |
137 | ## Model description
138 | [PEGASUS](https://github.com/google-research/pegasus) fine-tuned for summarization
139 |
140 | > Created by [Arpit Rajauria](https://twitter.com/arpit_rajauria)
141 | [](https://twitter.com/arpit_rajauria)
142 |
143 |
144 | ### Framework versions
145 |
146 | - Transformers 4.31.0
147 | - Pytorch 2.0.1+cu118
148 | - Datasets 2.13.1
149 | - Tokenizers 0.13.3
150 |
151 |
152 | Instruction: "I want to classify sentences by their sentiment (positive/negative/neutral)."
153 | Hypothetical model description:
154 | ---
155 | language: en
156 | license: apache-2.0
157 | datasets:
158 | - sst2
159 | - glue
160 | model-index:
161 | - name: distilbert-base-uncased-finetuned-sst-2-english
162 | results:
163 | - task:
164 | type: text-classification
165 | name: Text Classification
166 | dataset:
167 | name: glue
168 | type: glue
169 | config: sst2
170 | split: validation
171 | metrics:
172 | - type: accuracy
173 | value: 0.9105504587155964
174 | name: Accuracy
175 | verified: true
176 | verifyToken: eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJoYXNoIjoiN2YyOGMxYjY2Y2JhMjkxNjIzN2FmMjNiNmM2ZWViNGY3MTNmNWI2YzhiYjYxZTY0ZGUyN2M1NGIxZjRiMjQwZiIsInZlcnNpb24iOjF9.uui0srxV5ZHRhxbYN6082EZdwpnBgubPJ5R2-Wk8HTWqmxYE3QHidevR9LLAhidqGw6Ih93fK0goAXncld_gBg
177 | ---
178 |
179 | # DistilBERT base uncased finetuned SST-2
180 |
181 | ## Table of Contents
182 | - [Model Details](#model-details)
183 | - [How to Get Started With the Model](#how-to-get-started-with-the-model)
184 | - [Uses](#uses)
185 | - [Risks, Limitations and Biases](#risks-limitations-and-biases)
186 | - [Training](#training)
187 |
188 | ## Model Details
189 | **Model Description:** This model is a fine-tune checkpoint of [DistilBERT-base-uncased](https://huggingface.co/distilbert-base-uncased), fine-tuned on SST-2.
190 | This model reaches an accuracy of 91.3 on the dev set (for comparison, Bert bert-base-uncased version reaches an accuracy of 92.7).
191 | - **Developed by:** Hugging Face
192 | - **Model Type:** Text Classification
193 | - **Language(s):** English
194 | - **License:** Apache-2.0
195 | - **Parent Model:** For more details about DistilBERT, we encourage users to check out [this model card](https://huggingface.co/distilbert-base-uncased).
196 | - **Resources for more information:**
197 | - [Model Documentation](https://huggingface.co/docs/transformers/main/en/model_doc/distilbert#transformers.DistilBertForSequenceClassification)
198 | - [DistilBERT paper](https://arxiv.org/abs/1910.01108)
199 |
200 | ## Uses
201 |
202 | #### Direct Use
203 |
204 | This model can be used for topic classification. You can use the raw model for either masked language modeling or next sentence prediction, but it's mostly intended to be fine-tuned on a downstream task. See the model hub to look for fine-tuned versions on a task that interests you.
205 |
206 | #### Misuse and Out-of-scope Use
207 | The model should not be used to intentionally create hostile or alienating environments for people. In addition, the model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
208 |
209 |
210 | ## Risks, Limitations and Biases
211 |
212 | Based on a few experimentations, we observed that this model could produce biased predictions that target underrepresented populations.
213 |
214 | For instance, for sentences like `This film was filmed in COUNTRY`, this binary classification model will give radically different probabilities for the positive label depending on the country (0.89 if the country is France, but 0.08 if the country is Afghanistan) when nothing in the input indicates such a strong semantic shift.
215 |
216 | # Training
217 |
218 |
219 | #### Training Data
220 |
221 |
222 | The authors use the following Stanford Sentiment Treebank([sst2](https://huggingface.co/datasets/sst2)) corpora for the model.
223 | ```
224 | :""" # noqa: E501
225 |
226 |
227 | def generate_hypothetical_model_description(
228 |     prompt: PromptSpec, max_api_calls: int | None = None
229 | ) -> str:
230 | """Generate a hypothetical model description for the user's instruction.
231 |
232 | This method is based on HyDE by Gao et al 2022 (https://arxiv.org/abs/2212.10496).
233 |
234 | Args:
235 |         prompt: PromptSpec object containing the user's instruction.
236 |         max_api_calls: The maximum number of API calls allowed before raising an error.
236 |
237 | Returns:
238 | a hypothetical model description for the user's instruction.
239 | """
240 | if max_api_calls and max_api_calls <= 0:
241 | raise ValueError("max_api_calls must be > 0.")
242 | api_call_counter = 0
243 |
244 | instruction = prompt.instruction
245 | api_agent = api_tools.default_api_agent
246 | chatgpt_prompt = (
247 | PROMPT_PREFIX
248 | + "\n"
249 | + f'Instruction: "{instruction}"\nHypothetical model description:\n'
250 | )
251 | while True:
252 | try:
253 | chatgpt_completion = api_agent.generate_one_completion(
254 | chatgpt_prompt,
255 | temperature=0.0,
256 | presence_penalty=0.0,
257 | frequency_penalty=0.0,
258 | )
259 | return chatgpt_completion.choices[0]["message"]["content"]
260 | except API_ERRORS as e:
261 | handle_api_error(e)
262 | api_call_counter += 1
263 | if max_api_calls and api_call_counter >= max_api_calls:
264 | logging.error("Maximum number of API calls reached.")
265 | raise ValueError("Maximum number of API calls reached.") from e
266 |
--------------------------------------------------------------------------------
/prompt2model/model_retriever/mock.py:
--------------------------------------------------------------------------------
1 | """An interface for model selection."""
2 | from __future__ import annotations
3 |
4 | from prompt2model.model_retriever import ModelRetriever
5 | from prompt2model.prompt_parser import PromptSpec
6 |
7 |
8 | class MockModelRetriever(ModelRetriever):
9 | """Select a fixed model from among a set of hyperparameter choices."""
10 |
11 | def __init__(self, fixed_model_name: str):
12 | """Initialize a dummy retriever that returns a fixed model name."""
13 | self.fixed_model_name = fixed_model_name
14 |
15 | def retrieve(
16 | self,
17 | prompt: PromptSpec,
18 | ) -> list[str]:
19 | """Select an arbitrary, fixed model from HuggingFace.
20 |
21 | Args:
22 | prompt: A prompt to use to select relevant models.
23 |
24 | Return:
25 | A relevant model's HuggingFace name.
26 | """
27 | return [self.fixed_model_name]
28 |
--------------------------------------------------------------------------------
/prompt2model/model_retriever/readme.md:
--------------------------------------------------------------------------------
1 | # Model Retriever
2 |
3 | ## Overview
4 |
5 | - `ModelRetriever`: An interface for retrieving several models from
6 | HuggingFace.
7 | - `DescriptionModelRetriever` offers functions to vectorize
8 | descriptions, adjust relevance scores, build a BM25 index, and
9 | retrieve models based on a prompt.
10 | - `ModelInfo` stores a model's HuggingFace name, description,
11 | identifier, relevance score, disk size, and download count.
12 |
13 | ## Getting Started
14 |
15 | - **Import Required Modules**:
16 |
17 | ```python
18 | from prompt2model.model_retriever import DescriptionModelRetriever
19 | from prompt2model.prompt_parser import MockPromptSpec, TaskType
20 | ```
21 |
22 | - **Initialize the Prompt**:
23 |
24 | ```python
25 | prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION)
26 | prompt = "..."
27 | prompt_spec._instruction = prompt
28 | ```
29 |
30 | - **Initialize and Run the Retriever**:
31 |
32 | ```python
33 | retriever = DescriptionModelRetriever(
34 | search_index_path="path_to_bm25_search_index.pkl",
35 | model_descriptions_index_path="path_to_model_info_directory",
36 | use_bm25=True,
37 | use_HyDE=True,
38 | )
39 | top_model_name = retriever.retrieve(prompt_spec)
40 | ```
41 |
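Despite the variable name above, `retrieve` returns a list of relevant
HuggingFace model names (see `ModelRetriever.retrieve`), presumably ordered by
relevance, so the top candidate is the first element:

```python
best_model_name = top_model_name[0]  # `retrieve` returns a list of names
```
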
42 | ## Notes
43 |
44 | - Ensure that the paths provided to the retriever (e.g.,
45 | `search_index_path` and `model_descriptions_index_path`) point to the
46 | correct locations.
47 |
--------------------------------------------------------------------------------
/prompt2model/model_retriever/run_model_retriever.py:
--------------------------------------------------------------------------------
1 | """Script to run the model retriever in isolation."""
2 |
3 | from prompt2model.model_retriever import DescriptionModelRetriever
4 | from prompt2model.prompt_parser import MockPromptSpec, TaskType
5 |
6 | if __name__ == "__main__":
7 | prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION)
8 | prompt = """Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.""" # noqa E501
9 | prompt_spec._instruction = prompt
10 | retriever = DescriptionModelRetriever(
11 | search_index_path="huggingface_data/huggingface_models/bm25_search_index.pkl",
12 | model_descriptions_index_path="huggingface_data/huggingface_models/model_info/",
13 | use_bm25=True,
14 | use_HyDE=True,
15 | )
16 | top_model_name = retriever.retrieve(prompt_spec)
17 |
--------------------------------------------------------------------------------
/prompt2model/model_trainer/__init__.py:
--------------------------------------------------------------------------------
1 | """Import BaseTrainer classes."""
2 | from prompt2model.model_trainer.base import BaseTrainer
3 | from prompt2model.model_trainer.generate import GenerationModelTrainer
4 | from prompt2model.model_trainer.mock import MockTrainer
5 |
6 | __all__ = ("MockTrainer", "BaseTrainer", "GenerationModelTrainer")
7 |
--------------------------------------------------------------------------------
/prompt2model/model_trainer/base.py:
--------------------------------------------------------------------------------
1 | """A base class for trainers."""
2 | from __future__ import annotations # noqa FI58
3 |
4 | from abc import ABC, abstractmethod
5 | from typing import Any
6 |
7 | import datasets
8 | import transformers
9 | from transformers import AutoModel, AutoTokenizer
10 |
11 |
12 | # pylint: disable=too-few-public-methods
13 | class BaseTrainer(ABC):
14 | """Train a model with a fixed set of hyperparameters."""
15 |
16 | def __init__(self, pretrained_model_name: str):
17 | """Initialize a model trainer.
18 |
19 | Args:
20 | pretrained_model_name: A HuggingFace model name to use for training.
21 | """
22 | self.model = AutoModel.from_pretrained(pretrained_model_name)
23 | self.tokenizer = AutoTokenizer.from_pretrained(
24 | pretrained_model_name, padding_side="left"
25 | )
26 | self.wandb = None
27 |
28 | @abstractmethod
29 | def train_model(
30 | self,
31 | hyperparameter_choices: dict[str, Any],
32 | training_datasets: list[datasets.Dataset],
33 | validation_datasets: list[datasets.Dataset] | None = None,
34 | ) -> tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]:
35 | """Train a model with the given hyperparameters and return it."""
36 |
--------------------------------------------------------------------------------
/prompt2model/model_trainer/callback.py:
--------------------------------------------------------------------------------
1 | """The real evaluation will be conducted after each mock evaluation of Trainer."""
2 |
3 |
4 | from transformers import TrainerCallback
5 |
6 | from prompt2model.model_evaluator import Seq2SeqEvaluator
7 | from prompt2model.model_executor import GenerationModelExecutor
8 | from prompt2model.utils import get_formatted_logger
9 |
10 | logger = get_formatted_logger("ModelTrainer")
11 |
12 |
13 | class ValidationCallback(TrainerCallback):
14 | """The real evaluation will be conduted after each mock evaluation of Trainer."""
15 |
16 | def __init__(
17 | self,
18 | trainer,
19 | tokenizer,
20 | val_dataset,
21 | executor_batch_size=10,
22 | tokenizer_max_length=256,
23 | sequence_max_length=512,
24 | ) -> None:
25 | """Initializes a new instance of Model Trainer Callback.
26 |
27 | Args:
28 | trainer: Trainer instance.
29 |                 After each epoch of training, this callback will be called.
30 | tokenizer: Tokenizer to initialize model executor.
31 | val_dataset: Validation dataset to be evaluated on.
32 | executor_batch_size: The batch size for model executor to
33 | make predictions.
34 | tokenizer_max_length: The maximum number of tokens that
35 | tokenizer is allowed to generate.
36 | sequence_max_length: The maximum number of tokens in
37 | the input and output.
38 | """
39 | super().__init__()
40 | self.trainer = trainer
41 | self.tokenizer = tokenizer
42 | self.val_dataset = val_dataset
43 | self.epoch_count = 0
44 | self.val_dataset_size = len(self.val_dataset)
45 | self.executor_batch_size = executor_batch_size
46 | self.tokenizer_max_length = tokenizer_max_length
47 | self.sequence_max_length = sequence_max_length
48 |
49 | def on_epoch_end(self, args, state, control, **kwargs):
50 | """After each evaluation, this function will be called."""
51 | _ = (args, state, control, kwargs)
52 | # Suppress the unused parameters warning.
53 | self.epoch_count += 1
54 | logger.info(
55 | f"Epoch: {self.epoch_count}. Evaluate on { self.val_dataset_size} examples."
56 | )
57 |         # For multi-GPU training, the training process is split across
58 |         # multiple workers with data parallelism, so the validation dataset
59 |         # used in the callback is also sharded accordingly.
60 | model_executor = GenerationModelExecutor(
61 | model=self.trainer.model,
62 | tokenizer=self.tokenizer,
63 | batch_size=self.executor_batch_size,
64 | tokenizer_max_length=self.tokenizer_max_length,
65 | sequence_max_length=self.sequence_max_length,
66 | )
67 | model_outputs = model_executor.make_prediction(
68 | self.val_dataset,
69 | "model_input",
70 | )
71 | evaluator = Seq2SeqEvaluator()
72 | metric_values = evaluator.evaluate_model(
73 | self.val_dataset,
74 | "model_output",
75 | model_outputs,
76 | encoder_model_name="xlm-roberta-base",
77 | )
78 | logger.info(metric_values)
79 |
--------------------------------------------------------------------------------
/prompt2model/model_trainer/mock.py:
--------------------------------------------------------------------------------
1 | """This module provides a dummy trainer for testing purposes."""
2 | from __future__ import annotations # noqa FI58
3 |
4 | from typing import Any
5 |
6 | import datasets
7 | from transformers import PreTrainedModel # noqa
8 | from transformers import PreTrainedTokenizer
9 |
10 | from prompt2model.model_trainer import BaseTrainer
11 |
12 |
13 | class MockTrainer(BaseTrainer):
14 | """This dummy trainer does not actually train anything."""
15 |
16 | def train_model(
17 | self,
18 | hyperparameter_choices: dict[str, Any],
19 | training_datasets: list[datasets.Dataset],
20 | validation_datasets: list[datasets.Dataset] | None = None,
21 | ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
22 | """This dummy trainer returns the given model without any training.
23 |
24 | Args:
25 | training_datasets: A list of training datasets.
26 | hyperparameter_choices: A dictionary of hyperparameters for training.
27 |
28 | Returns:
29 | A HuggingFace model and tokenizer.
30 | """
31 | _ = training_datasets, hyperparameter_choices, validation_datasets
32 | return self.model, self.tokenizer
33 |
--------------------------------------------------------------------------------
/prompt2model/model_trainer/readme.md:
--------------------------------------------------------------------------------
1 | # Model Trainer
2 |
3 | ## Overview
4 |
5 | - `BaseTrainer`: A base class for model training.
6 | - `GenerationModelTrainer`: A specialized class for training T5-type
7 | (encoder-decoder) and GPT-type (decoder-only) models.
8 | - `ValidationCallback`: A class for performing validation during
9 | training.
10 |
11 | ## Getting Started
12 |
13 | ### Import Modules
14 |
15 | ```python
16 | from prompt2model.model_trainer.generate import GenerationModelTrainer
17 | ```
18 |
19 | ### Initialize the Trainer
20 |
21 | ```python
22 | pretrained_model_name = "..." # Replace with a HuggingFace pretrained model name.
23 | has_encoder = ...
24 | # Set to True if the model has an encoder, otherwise False.
25 | executor_batch_size = ...  # Set the batch size for model validation.
26 | tokenizer_max_length = ...  # Set the maximum length for the tokenizer.
27 | sequence_max_length = ...  # Set the maximum sequence length.
28 | trainer = GenerationModelTrainer(
29 | pretrained_model_name,
30 | has_encoder,
31 | executor_batch_size,
32 | tokenizer_max_length,
33 | sequence_max_length
34 | )
35 | # For more details, refer to the docstring of GenerationModelTrainer.
36 | ```
37 |
38 | ### Prepare Dataset and Hyperparameters
39 |
40 | ```python
41 | training_datasets = [] # A list of training datasets.
42 | validation_datasets = [] # A list of validation datasets.
43 | hyperparameter_choices = {...} # A dictionary of hyperparameters.
44 | # For more details, refer to the docstring of GenerationModelTrainer.train_model.
45 | ```
46 |
47 | ### Train the Model
48 |
49 | ```python
50 | trained_model, trained_tokenizer = trainer.train_model(
51 | hyperparameter_choices,
52 | training_datasets,
53 | validation_datasets
54 | )
55 | ```
56 |
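57 | ### Example Hyperparameters
58 |
59 | The exact keys accepted by `train_model` are documented in its docstring. As a
60 | rough sketch, the `hyperparameter_choices` placeholder above could be filled
61 | with keys like those that `OptunaParamSelector` passes to
62 | `GenerationModelTrainer.train_model`; the values shown are illustrative, not
63 | tuned recommendations.
64 |
65 | ```python
66 | hyperparameter_choices = {
67 |     "output_dir": "./checkpoint",  # Where checkpoints are written.
68 |     "num_train_epochs": 10,  # Illustrative value only.
69 |     "learning_rate": 1e-4,  # Illustrative value only.
70 |     "weight_decay": 1e-2,  # Illustrative value only.
71 | }
72 | ```
73 |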
--------------------------------------------------------------------------------
/prompt2model/param_selector/__init__.py:
--------------------------------------------------------------------------------
1 | """Import model selector classes."""
2 | from prompt2model.param_selector.base import ParamSelector
3 | from prompt2model.param_selector.mock import MockParamSelector
4 | from prompt2model.param_selector.search_with_optuna import OptunaParamSelector
5 |
6 | __all__ = ("MockParamSelector", "ParamSelector", "OptunaParamSelector")
7 |
--------------------------------------------------------------------------------
/prompt2model/param_selector/base.py:
--------------------------------------------------------------------------------
1 | """An interface for model selection."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | from abc import ABC, abstractmethod
6 | from typing import Any
7 |
8 | import datasets
9 | import transformers
10 |
11 |
12 | # pylint: disable=too-few-public-methods
13 | class ParamSelector(ABC):
14 | """Select a good model from among a set of hyperparameter choices."""
15 |
16 | @abstractmethod
17 | def select_from_hyperparameters(
18 | self,
19 | training_sets: list[datasets.Dataset],
20 | validation: datasets.Dataset,
21 | hyperparameters: dict[str, list[Any]],
22 | ) -> tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]:
23 | """Select a model among a set of hyperparameters (given or inferred).
24 |
25 | Args:
26 | training_sets: One or more training datasets for the trainer.
27 | validation: A dataset for computing validation metrics.
28 | hyperparameters: A dictionary of hyperparameter choices.
29 |
30 | Return:
31 | A model and tokenizer (with hyperparameters from given range).
32 | """
33 |
--------------------------------------------------------------------------------
/prompt2model/param_selector/mock.py:
--------------------------------------------------------------------------------
1 | """Mock model selector for testing purposes."""
2 | from __future__ import annotations
3 |
4 | from typing import Any
5 |
6 | import datasets
7 | import transformers
8 |
9 | from prompt2model.model_trainer import BaseTrainer
10 | from prompt2model.param_selector.base import ParamSelector
11 | from prompt2model.prompt_parser import PromptSpec
12 |
13 |
14 | class MockParamSelector(ParamSelector):
15 | """Uses a default set of parameters."""
16 |
17 | def __init__(self, trainer: BaseTrainer):
18 | """Initialize with train/val datasets and a prompt specification.
19 |
20 | Args:
21 | trainer: A trainer to use for training models during model selection.
22 | """
23 | self.trainer = trainer
24 |
25 | def _example_hyperparameter_choices(self) -> dict[str, Any]:
26 | """Example hyperparameters (for testing only)."""
27 | return {
28 | "optimizer": "AdamW",
29 | "learning_rate": 1e-4,
30 | }
31 |
32 | def select_from_hyperparameters(
33 | self,
34 | training_sets: list[datasets.Dataset],
35 | validation: datasets.Dataset,
36 | hyperparameters: dict[str, list[Any]],
37 | ) -> tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]:
38 | """Use a pre-defined default set of hyperparameters.
39 |
40 | Args:
41 | training_sets: One or more training datasets for the trainer.
42 | validation: A dataset for computing validation metrics.
43 |             hyperparameters: A dictionary of hyperparameter choices
44 |                 (ignored by this mock selector).
45 |
46 | Return:
47 | A model and tokenizer (trained using default hyperparameters).
48 | """
49 | single_model = self.trainer.train_model(
50 | self._example_hyperparameter_choices(), training_sets
51 | )
52 | return single_model
53 |
54 | def select_from_spec(
55 | self,
56 | training_sets: list[datasets.Dataset],
57 | validation: datasets.Dataset,
58 | prompt_spec: PromptSpec,
59 | ) -> tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]:
60 | """The MockParamSelector cannot infer hyperparameters from the spec.
61 |
62 | Args:
63 | training_sets: One or more training datasets for the trainer.
64 | validation: A dataset for computing validation metrics.
65 |             prompt_spec: A prompt to infer hyperparameters from
66 |                 (not supported by this mock selector).
67 | """
68 | raise NotImplementedError
69 |
--------------------------------------------------------------------------------
/prompt2model/param_selector/search_with_optuna.py:
--------------------------------------------------------------------------------
1 | """This module provides automatic hyperparameter selection using Optuna."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | import os
6 | from pathlib import Path
7 | from typing import Any, Optional
8 |
9 | import optuna
10 | import transformers
11 | from datasets import Dataset, concatenate_datasets
12 | from optuna.trial import Trial
13 | from transformers import PreTrainedModel # noqa
14 | from transformers import PreTrainedTokenizer
15 |
16 | from prompt2model.model_trainer.generate import GenerationModelTrainer
17 | from prompt2model.param_selector.base import ParamSelector
18 | from prompt2model.utils.config import DEFAULT_HYPERPARAMETERS_SPACE
19 |
20 |
21 | class OptunaParamSelector(ParamSelector):
22 | """Uses Optuna for searching for hyperparameters."""
23 |
24 | def __init__(self, trainer: GenerationModelTrainer, n_trials: int):
25 | """Initializes a new instance of OptunaParamSelector.
26 |
27 | Args:
28 |             trainer (GenerationModelTrainer): The trainer used to fit models.
29 |             n_trials (int): The maximum number of parameter configurations to evaluate
30 |                 during hyperparameter search.
31 | """
32 | self.generation_model_trainer = trainer
33 | self.n_trials = n_trials
34 |
35 | def optimize_hyperparameters(
36 | self,
37 | training_datasets: list[Dataset],
38 | validation: Dataset,
39 | hyperparameters: Optional[dict[str, Any]] = None,
40 | ) -> dict[str, Any]:
41 | """Select a model among a set of hyperparameters (given or inferred).
42 |
43 | Args:
44 | training_datasets (list[Dataset]): One or more training datasets
45 | to use for training models.
46 |             validation (Dataset): A dataset for computing validation metrics.
47 |             hyperparameters (Optional[dict[str, Any]], optional): The set of
48 |                 possible hyperparameter values to search over when looking for
49 |                 the optimal configuration. Defaults to None.
50 |
51 | Returns:
52 |             A dict containing the best hyperparameters found.
53 | """
54 | supported_hp_space_keys = set(DEFAULT_HYPERPARAMETERS_SPACE.keys())
55 | if hyperparameters is not None:
56 | assert set(hyperparameters.keys()).issubset(
57 | supported_hp_space_keys
58 | ), f"Only support {supported_hp_space_keys} as training parameters."
59 | hyperparameter_space = self._build_hp_space(hyperparameters)
60 |
61 | concatenated_training_dataset = concatenate_datasets(training_datasets)
62 | train_dataset = self.generation_model_trainer.tokenize_dataset(
63 | concatenated_training_dataset
64 | )
65 |
66 | if isinstance(validation, list):
67 | validation = concatenate_datasets(validation)
68 | validation = self.generation_model_trainer.tokenize_dataset(validation)
69 |
70 | def objective(trial: Trial) -> float:
71 | model = self.generation_model_trainer.model
72 | training_args = transformers.TrainingArguments(
73 | output_dir="./checkpoint",
74 | learning_rate=trial.suggest_loguniform(
75 | "learning_rate",
76 | low=hyperparameter_space["min_learning_rate"],
77 | high=hyperparameter_space["max_learning_rate"],
78 | ),
79 | weight_decay=trial.suggest_loguniform(
80 | "weight_decay",
81 | low=hyperparameter_space["min_weight_decay"],
82 | high=hyperparameter_space["max_weight_decay"],
83 | ),
84 | num_train_epochs=trial.suggest_int(
85 | "num_train_epochs",
86 | low=hyperparameter_space["min_num_train_epochs"],
87 | high=hyperparameter_space["max_num_train_epochs"],
88 | ),
89 | )
90 | objective_trainer = transformers.Trainer(
91 | model=model,
92 | args=training_args,
93 | data_collator=transformers.DataCollatorForSeq2Seq(
94 | tokenizer=self.generation_model_trainer.tokenizer
95 | ),
96 | train_dataset=train_dataset,
97 | eval_dataset=validation,
98 | )
99 |
100 | _ = objective_trainer.train()
101 | optimization_targets = objective_trainer.evaluate()
102 | return optimization_targets["eval_loss"]
103 |
104 | study = optuna.create_study(
105 | study_name="automatic_hyperparameter_search", direction="minimize"
106 | )
107 |
108 | study.optimize(func=objective, n_trials=self.n_trials, gc_after_trial=True)
109 | best_hyperparameters = {
110 | "learning_rate": float(study.best_params["learning_rate"]),
111 | "weight_decay": float(study.best_params["weight_decay"]),
112 | "num_train_epochs": int(study.best_params["num_train_epochs"]),
113 | }
114 | return best_hyperparameters
115 |
116 | def select_from_hyperparameters(
117 | self,
118 | training_datasets: list[Dataset],
119 | validation: Dataset,
120 | hyperparameters: Optional[dict[str, Any]] = None,
121 | ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
122 | """Select a model among a set of hyperparameters (given or inferred). # noqa D410
123 |
124 | Args:
125 | training_datasets: One or more training datasets for the trainer.
126 | validation: A dataset for computing validation metrics.
127 | hyperparameters: A dictionary of hyperparameter choices.
128 |
129 | If no hyperparameter_space is specified, then the default hyperparameter_space
130 |         will be chosen. Here is an example of what the space looks like:
131 | hyperparameter_space = {
132 | "min_num_train_epochs": 5,
133 | "max_num_train_epochs": 10,
134 | "save_strategy": ["epoch", "steps", "no"],
135 | "evaluation_strategy": ["epoch", "no"],
136 | "per_device_train_batch_size": [4, 8, 16, 32],
137 | "min_weight_decay": 1e-5,
138 | "max_weight_decay": 1e-1,
139 | "min_learning_rate": 1e-5,
140 | "max_learning_rate": 1e-1,
141 | }
142 | Return:
143 | A model and tokenizer (with hyperparameters from given range).
144 | """
145 | model = self.generation_model_trainer.model
146 | tokenizer = self.generation_model_trainer.tokenizer
147 | best_model_path = Path("result/trained_model")
148 | if not os.path.exists(best_model_path):
149 | os.makedirs(best_model_path, exist_ok=True)
150 |
151 | best_hyperparameters = self.optimize_hyperparameters(
152 | training_datasets=training_datasets,
153 | validation=validation,
154 | hyperparameters=hyperparameters,
155 | )
156 | final_hyperparameters = {
157 | "output_dir": "./best_model_checkpoint",
158 | **best_hyperparameters,
159 | }
160 |
161 | model, tokenizer = self.generation_model_trainer.train_model(
162 | hyperparameter_choices=final_hyperparameters,
163 | training_datasets=training_datasets,
164 | )
165 |
166 | model.save_pretrained(best_model_path)
167 | tokenizer.save_pretrained(best_model_path)
168 | return model, tokenizer
169 |
170 | def _build_hp_space(
171 | self, hyperparameter_space: Optional[dict[str, Any]] = None
172 | ) -> dict[str, Any]:
173 | if hyperparameter_space is None:
174 | return DEFAULT_HYPERPARAMETERS_SPACE
175 | hp_space = {}
176 |
177 | default_keys = list(DEFAULT_HYPERPARAMETERS_SPACE.keys())
178 | for key in list(hyperparameter_space.keys()):
179 | if key not in default_keys:
180 | print(
181 | f"Key {key} is not present in DEFAULT_HYPERPARAMETERS_SPACE. Hence, it will be ignored.", # noqa E501
182 | "However, you can expose the key to the Trainer by adding it to DEFAULT_HYPERPARAMETERS_SPACE.", # noqa E501
183 | )
184 |
185 | for key, default_value in DEFAULT_HYPERPARAMETERS_SPACE.items():
186 | hp_space[key] = hyperparameter_space.get(key, default_value)
187 | return hp_space
188 |
--------------------------------------------------------------------------------
/prompt2model/prompt_parser/__init__.py:
--------------------------------------------------------------------------------
1 | """Import PromptSpec classes."""
2 | from prompt2model.prompt_parser.base import PromptSpec, TaskType
3 | from prompt2model.prompt_parser.instr_parser import PromptBasedInstructionParser
4 | from prompt2model.prompt_parser.mock import MockPromptSpec
5 |
6 | __all__ = (
7 | "PromptSpec",
8 | "TaskType",
9 | "MockPromptSpec",
10 | "PromptBasedInstructionParser",
11 | )
12 |
--------------------------------------------------------------------------------
/prompt2model/prompt_parser/base.py:
--------------------------------------------------------------------------------
1 | """An interface for prompt parsing."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | from abc import ABC, abstractmethod
6 | from enum import Enum
7 |
8 |
9 | class TaskType(Enum):
10 | """High-level taxonomy of possible NLP model outputs."""
11 |
12 | TEXT_GENERATION = 1
13 | CLASSIFICATION = 2
14 | SEQUENCE_TAGGING = 3
15 | SPAN_EXTRACTION = 4
16 |
17 |
18 | class PromptSpec(ABC):
19 | """Parse and store structured information about the prompt."""
20 |
21 | task_type: TaskType
22 | _instruction: str | None
23 | _examples: str | None
24 |
25 | @abstractmethod
26 | def parse_from_prompt(self, prompt: str) -> None:
27 | """Populate this class by parsing a prompt."""
28 |
29 | @property
30 | def instruction(self) -> str:
31 | """Return the natural language instruction parsed from the prompt."""
32 | if self._instruction is None:
33 | raise ValueError("Instruction hasn't been parsed from the prompt.")
34 | return self._instruction
35 |
36 | @property
37 | def examples(self) -> str:
38 | """Return the natural language examples parsed from the prompt."""
39 | return self._examples or ""
40 |
--------------------------------------------------------------------------------
/prompt2model/prompt_parser/instr_parser.py:
--------------------------------------------------------------------------------
1 | """An interface for prompt parsing."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | import os
6 |
7 | from prompt2model.prompt_parser.base import PromptSpec, TaskType
8 |
9 | from prompt2model.prompt_parser.instr_parser_prompt import ( # isort: split
10 | construct_prompt_for_instruction_parsing,
11 | )
12 |
13 | from prompt2model.utils.parse_responses import parse_prompt_to_fields
14 |
15 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
16 |
17 |
18 | class PromptBasedInstructionParser(PromptSpec):
19 | """Parse the prompt to separate instructions from task demonstrations."""
20 |
21 | def __init__(self, task_type: TaskType, max_api_calls: int = 5):
22 | """Initialize the prompt spec with empty parsed fields.
23 |
24 | We initialize the "instruction" and "examples" fields with None.
25 | These fields can be populated with the parse_from_prompt method.
26 |
27 | Args:
28 | task_type: Set a constant task type to use for all prompts.
29 | max_api_calls: The maximum number of API calls allowed,
30 | or None for unlimited.
31 | """
32 | self.task_type = task_type
33 | self._instruction: str | None = None
34 | self._examples: str | None = None
35 | self.max_api_calls = max_api_calls
36 |
37 | def parse_from_prompt(self, prompt: str) -> None:
38 | """Parse prompt into specific fields, stored as class member variables.
39 |
40 | This function directly stores the parsed fields into the class's member
41 | variables `instruction` and `examples`. So it has no return value.
42 |
43 | Args:
44 | prompt: User prompt to parse into two specific fields:
45 | "instruction" and "demonstrations".
46 | """
47 | parsing_prompt_for_chatgpt = construct_prompt_for_instruction_parsing(prompt)
48 | required_keys = ["Instruction", "Demonstrations"]
49 |
50 | extraction = parse_prompt_to_fields(
51 | parsing_prompt_for_chatgpt,
52 | required_keys,
53 | max_api_calls=self.max_api_calls,
54 | )
55 | self._instruction = extraction["Instruction"]
56 | self._examples = extraction["Demonstrations"]
57 |
58 | def set_instruction_and_examples(
59 | self, instruction: str = "", examples: str = ""
60 | ) -> None:
61 | """Set the instruction and examples directly."""
62 | self._instruction = instruction
63 | self._examples = examples
64 |
--------------------------------------------------------------------------------
/prompt2model/prompt_parser/mock.py:
--------------------------------------------------------------------------------
1 | """An interface for prompt parsing."""
2 |
3 | from prompt2model.prompt_parser.base import PromptSpec, TaskType
4 |
5 |
6 | class MockPromptSpec(PromptSpec):
7 | """Mock the bebavior of PromptSpec."""
8 |
9 | def __init__(
10 | self, task_type: TaskType, instruction: str = None, examples: str = None
11 | ):
12 | """Mock the elements of PromptSpec."""
13 | self.task_type = task_type
14 | if instruction is None:
15 | self._instruction = (
16 | "Give me some translation from Chinese to English."
17 | " Input Chinese and output English."
18 | )
19 | else:
20 | self._instruction = instruction
21 | if examples is None:
22 | self._examples = (
23 | "input: '人生苦短,我用 Python', output: 'Life is short, I use Python. '"
24 | "input: '明天是周末', output: 'Tomorrow is weekend.'"
25 | )
26 | else:
27 | self._examples = examples
28 |
29 | def parse_from_prompt(self, prompt: str) -> None:
30 | """Don't parse anything."""
31 | self._instruction = prompt
32 | return None
33 |
--------------------------------------------------------------------------------
/prompt2model/prompt_parser/readme.md:
--------------------------------------------------------------------------------
1 | # Prompt Parser
2 |
3 | ## Overview
4 |
5 | - `PromptSpec`: Interface for parsing prompts into instructions and
6 | examples.
7 | - `TaskType`: Enum for classifying NLP tasks like text generation,
8 | classification, etc.
9 | - `PromptBasedInstructionParser`: A `PromptSpec` subclass that uses GPT-3.5
10 | API for parsing.
11 |
12 | ## Getting Started
13 |
14 | - Import Modules:
15 |
16 | ```python
17 | from prompt2model.prompt_parser import PromptBasedInstructionParser, TaskType
18 | ```
19 |
20 | - Setup API Key and Initialize Parser. For instance, if using OpenAI:
21 |
22 | ```bash
23 | export OPENAI_API_KEY=""
24 | ```
25 |
26 | And then initialize the Parser:
27 |
28 | ```python
29 | task_type = TaskType.TEXT_GENERATION  # Or another TaskType value.
30 | prompt_spec = PromptBasedInstructionParser(task_type)
31 | ```
32 |
33 | ### Parse the Prompt
34 |
35 | ```python
36 | prompt = ""
37 | prompt_spec.parse_from_prompt(prompt)
38 | ```
39 |
40 | ### Access Parsed Fields
41 |
42 | ```python
43 | instruction = prompt_spec.instruction  # Retrieves the parsed instruction.
44 | demonstrations = prompt_spec.examples  # Retrieves the parsed examples.
45 | ```
46 |
47 | ### Mock
48 |
49 | If you want to mock a `PromptSpec` object without parsing from a prompt,
50 | you can use the `MockPromptSpec` class.
51 |
52 | ```python
53 | prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION)
54 | instruction = """...""" # A string indicating the task description.
55 | examples = """...""" # A string indicating the examples.
56 | prompt_spec._instruction = instruction
57 | prompt_spec._examples = examples
58 | ```
59 |
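60 | ### Set Fields Without an API Call
61 |
62 | `PromptBasedInstructionParser` also provides `set_instruction_and_examples`,
63 | which fills in both fields directly without calling the API. A minimal sketch,
64 | where the strings are placeholders:
65 |
66 | ```python
67 | prompt_spec = PromptBasedInstructionParser(TaskType.TEXT_GENERATION)
68 | prompt_spec.set_instruction_and_examples(
69 |     instruction="...",  # Placeholder task description.
70 |     examples="...",  # Placeholder demonstrations.
71 | )
72 | ```
73 |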
--------------------------------------------------------------------------------
/prompt2model/run_locally.py:
--------------------------------------------------------------------------------
1 | """A script to run the prompt2model pipeline locally."""
2 | from __future__ import annotations
3 |
4 | import argparse
5 |
6 | from prompt2model.dataset_generator import DatasetSplit, MockDatasetGenerator
7 | from prompt2model.dataset_processor import MockProcessor
8 | from prompt2model.dataset_retriever import MockRetriever
9 | from prompt2model.demo_creator import mock_gradio_create
10 | from prompt2model.model_evaluator import MockEvaluator
11 | from prompt2model.model_executor import MockModelExecutor
12 | from prompt2model.model_retriever import MockModelRetriever
13 | from prompt2model.model_trainer import MockTrainer
14 | from prompt2model.param_selector import MockParamSelector
15 | from prompt2model.prompt_parser import MockPromptSpec, PromptSpec, TaskType
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument(
19 | "--prompt",
20 | type=str,
21 | nargs="+",
22 | required=True,
23 | help="Prompt (with optional few-shot examples) for language model",
24 | )
25 | parser.add_argument(
26 | "--metrics-output-path",
27 | type=str,
28 | help="Path to JSON file where we store model metrics",
29 | default="/tmp/metrics.json",
30 | )
31 |
32 |
33 | def process_input_prompt(prompt_tokens: list[str]) -> PromptSpec:
34 | """Preprocess the input prompt given by the user and parse.
35 |
36 | Args:
37 | prompt_tokens: Tokens in the prompt.
38 |
39 | Returns:
40 | A PromptSpec parsed from the processed prompt tokens.
41 |
42 | """
43 | prompt_str = " ".join(prompt_tokens).strip()
44 | start_quotations_present = False
45 | end_quotations_present = False
46 | quotation_marks = ['"', "“", "‟", "”"]
47 | for start_quote in quotation_marks:
48 | if prompt_str.startswith(start_quote):
49 | start_quotations_present = True
50 | break
51 | for end_quote in quotation_marks:
52 | if prompt_str.endswith(end_quote):
53 | end_quotations_present = True
54 | break
55 | if start_quotations_present and end_quotations_present:
56 | prompt_str = prompt_str[1:-1]
57 |
58 | prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION)
59 | prompt_spec.parse_from_prompt(prompt_str)
60 | return prompt_spec
61 |
62 |
63 | def run_skeleton(prompt_tokens: list[str], metrics_output_path: str) -> None:
64 | """Run the prompt2model pipeline locally using base/stub components."""
65 | prompt_spec = process_input_prompt(prompt_tokens)
66 |
67 | # Retrieve and generate datasets
68 | retriever = MockRetriever()
69 | retrieved_dataset_dicts = retriever.retrieve_dataset_dict(prompt_spec)
70 |
71 | generator = MockDatasetGenerator()
72 | expected_num_examples = {
73 | DatasetSplit.TRAIN: 40,
74 | DatasetSplit.VAL: 5,
75 | DatasetSplit.TEST: 5,
76 | }
77 | generated_dataset_dicts = generator.generate_dataset_dict(
78 | prompt_spec, expected_num_examples
79 | )
80 |
81 | processor = MockProcessor(has_encoder=True, eos_token="")
82 | retrieved_dataset_dicts, generated_dataset_dicts = processor.process_dataset_dict(
83 | instruction="", dataset_dicts=[retrieved_dataset_dicts, generated_dataset_dicts]
84 | )
85 |
86 | retrieved_training = [
87 | dataset_dict["train"] for dataset_dict in retrieved_dataset_dicts
88 | ]
89 |
90 | generated_training = generated_dataset_dicts[DatasetSplit.TRAIN.value]
91 | validation = generated_dataset_dicts[DatasetSplit.VAL.value]
92 | testing = generated_dataset_dicts[DatasetSplit.TEST.value]
93 | all_training = retrieved_training + [generated_training]
94 |
95 | model_retriever = MockModelRetriever("cardiffnlp/twitter-roberta-base-sentiment")
96 | retrieved_model_name = model_retriever.retrieve(prompt_spec)
97 |
98 | trainer = MockTrainer(retrieved_model_name[0])
99 | selector = MockParamSelector(trainer)
100 | model, tokenizer = selector.select_from_hyperparameters(
101 | all_training, validation, {}
102 | )
103 |
104 | model_executor = MockModelExecutor(model, tokenizer)
105 | predictions = model_executor.make_prediction(testing, "input_col")
106 |
107 | evaluator = MockEvaluator()
108 | metrics_dict = evaluator.evaluate_model(
109 | testing, "output_col", predictions, "input_col", []
110 | )
111 | evaluator.write_metrics(metrics_dict, metrics_output_path)
112 | mock_gradio_create(model_executor, prompt_spec)
113 |
114 |
115 | if __name__ == "__main__":
116 | args = parser.parse_args()
117 | run_skeleton(args.prompt, args.metrics_output_path)
118 |
--------------------------------------------------------------------------------
/prompt2model/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Import utility functions."""
2 | from prompt2model.utils.api_tools import (
3 | API_ERRORS,
4 | APIAgent,
5 | count_tokens_from_string,
6 | handle_api_error,
7 | )
8 | from prompt2model.utils.logging_utils import get_formatted_logger
9 | from prompt2model.utils.rng import seed_generator
10 | from prompt2model.utils.tevatron_utils import encode_text, retrieve_objects
11 |
12 | __all__ = ( # noqa: F401
13 | "APIAgent",
14 | "encode_text",
15 | "handle_api_error",
16 | "API_ERRORS",
17 | "retrieve_objects",
18 | "seed_generator",
19 | "count_tokens_from_string",
20 | "get_formatted_logger",
21 | )
22 |
--------------------------------------------------------------------------------
/prompt2model/utils/config.py:
--------------------------------------------------------------------------------
1 | """Place to store all the default configurations."""
2 | MAX_SUPPORTED_BATCH_SIZE = 4
3 |
4 | DEFAULT_HYPERPARAMETERS_SPACE = {
5 | "min_num_train_epochs": 5,
6 | "max_num_train_epochs": 15,
7 | "save_strategy": ["no"],
8 | "evaluation_strategy": ["no"],
9 | "per_device_train_batch_size": MAX_SUPPORTED_BATCH_SIZE,
10 | "min_weight_decay": 4e-5,
11 | "max_weight_decay": 1e-1,
12 | "min_learning_rate": 4e-5,
13 | "max_learning_rate": 1e-1,
14 | }
15 |
--------------------------------------------------------------------------------
/prompt2model/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | """Util functions for datasets."""
2 |
3 |
4 | import datasets
5 | import requests
6 |
7 | from prompt2model.utils.logging_utils import get_formatted_logger
8 |
9 | logger = get_formatted_logger("dataset_utils")
10 |
11 |
12 | def query(API_URL):
13 | """Returns a response json for a URL."""
14 | try:
15 | response = requests.get(API_URL)
16 | if response.status_code == 200:
17 | return response.json()
18 | else:
19 | logger.error(f"Error occurred in fetching size: {response.status_code}")
20 | except requests.exceptions.RequestException as e:
21 | logger.error("Error occurred in making the request: " + str(e))
22 |
23 | return {}
24 |
25 |
26 | def get_dataset_size(dataset_name):
27 | """Fetches dataset size for a dataset in MB from hugging face API."""
28 | API_URL = f"https://datasets-server.huggingface.co/size?dataset={dataset_name}"
29 | data = query(API_URL)
30 | size_dict = data.get("size", {})
31 | return (
32 | "NA"
33 | if size_dict == {}
34 | else "{:.2f}".format(size_dict["dataset"]["num_bytes_memory"] / 1024 / 1024)
35 | )
36 |
37 |
38 | def make_combined_datasets(
39 | dataset_list: list[datasets.Dataset], dataset_type: str = "input_output"
40 | ) -> datasets.Dataset:
41 | """Combine multiple datasets into one.
42 |
43 | Args:
44 | dataset_list: List of datasets to combine.
45 | dataset_type: Type of dataset to combine. Can be "text" or "input_output".
46 | "text" is for combining datasets with a single column "text".
47 | "input_output" is for combining datasets with 2 columns "input_col"
48 | and "output_col".
49 |
50 | Returns:
51 | A combined dataset.
52 |         A single column "text" if dataset_type is "text".
53 |         Two columns "input_col" and "output_col" if dataset_type is "input_output".
54 |         Raises ValueError if dataset_type is not "text" or "input_output".
55 | """
56 | if dataset_type == "text":
57 | text_col = []
58 | for dataset in dataset_list:
59 | text_col.extend(dataset["text"])
60 | return datasets.Dataset.from_dict({"text": text_col})
61 | elif dataset_type == "input_output":
62 | input_col = []
63 | output_col = []
64 | for dataset in dataset_list:
65 | input_col.extend(dataset["input_col"])
66 | output_col.extend(dataset["output_col"])
67 |
68 | dataset = datasets.Dataset.from_dict(
69 | {"input_col": input_col, "output_col": output_col}
70 | )
71 | return dataset
72 | else:
73 | raise ValueError(
74 | f"dataset_type can be either 'text' or 'input_output' but got {dataset_type}" # noqa E501
75 | )
76 |
77 |
78 | def format_train_data(train_dataset: datasets.Dataset):
79 | """Formats the train dataset for training."""
80 | final_texts = []
81 | for row in train_dataset:
82 | final_texts.append(f"{row['input_col'].strip()} {row['output_col'].strip()}")
83 | return datasets.Dataset.from_dict({"text": final_texts})
84 |
--------------------------------------------------------------------------------
/prompt2model/utils/dataset_utils_test.py:
--------------------------------------------------------------------------------
1 | """Testing dataset utility functions."""
2 | from unittest.mock import patch
3 |
4 | from prompt2model.utils import dataset_utils
5 |
6 |
7 | @patch("prompt2model.utils.dataset_utils.query")
8 | def test_get_dataset_size(mock_request):
9 | """Test function for get_dataset_size."""
10 | mock_request.return_value = {
11 | "size": {
12 | "dataset": {
13 | "dataset": "rotten_tomatoes",
14 | "num_bytes_original_files": 487770,
15 | "num_bytes_parquet_files": 881052,
16 | "num_bytes_memory": 1345449,
17 | "num_rows": 10662,
18 | },
19 | "configs": [
20 | {
21 | "dataset": "rotten_tomatoes",
22 | "config": "default",
23 | "num_bytes_original_files": 487770,
24 | "num_bytes_parquet_files": 881052,
25 | "num_bytes_memory": 1345449,
26 | "num_rows": 10662,
27 | "num_columns": 2,
28 | }
29 | ],
30 | "splits": [
31 | {
32 | "dataset": "rotten_tomatoes",
33 | "config": "default",
34 | "split": "train",
35 | "num_bytes_parquet_files": 698845,
36 | "num_bytes_memory": 1074806,
37 | "num_rows": 8530,
38 | "num_columns": 2,
39 | },
40 | {
41 | "dataset": "rotten_tomatoes",
42 | "config": "default",
43 | "split": "validation",
44 | "num_bytes_parquet_files": 90001,
45 | "num_bytes_memory": 134675,
46 | "num_rows": 1066,
47 | "num_columns": 2,
48 | },
49 | {
50 | "dataset": "rotten_tomatoes",
51 | "config": "default",
52 | "split": "test",
53 | "num_bytes_parquet_files": 92206,
54 | "num_bytes_memory": 135968,
55 | "num_rows": 1066,
56 | "num_columns": 2,
57 | },
58 | ],
59 | },
60 | "pending": [],
61 | "failed": [],
62 | "partial": False,
63 | }
64 | assert dataset_utils.get_dataset_size("rotten_tomatoes") == "1.28"
65 |
--------------------------------------------------------------------------------
/prompt2model/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | """Utils for creating formatted logger."""
2 |
3 | import logging
4 |
5 |
6 | def get_formatted_logger(logger_name: str):
7 | """Create a formatted logger.
8 |
9 | Args:
10 | logger_name: The name of the logger, usually the name
11 | of the component that uses the logger.
12 |
13 | Returns:
14 | A logger object.
15 | """
16 | logger = logging.getLogger(logger_name)
17 | # Check if the logger already has a StreamHandler to prevent adding another one.
18 | if not any(
19 | isinstance(handler, logging.StreamHandler) for handler in logger.handlers
20 | ):
21 | ch = logging.StreamHandler()
22 | formatter = logging.Formatter(
23 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
24 | )
25 | ch.setFormatter(formatter)
26 | logger.addHandler(ch)
27 | return logger
28 |
--------------------------------------------------------------------------------
/prompt2model/utils/parse_responses.py:
--------------------------------------------------------------------------------
1 | """Utility file for parsing OpenAI json responses."""
2 | from __future__ import annotations
3 |
4 | import json
5 | import re
6 | from typing import Any
7 |
8 | import openai
9 |
10 | from prompt2model.utils import api_tools, get_formatted_logger
11 | from prompt2model.utils.api_tools import API_ERRORS, handle_api_error
12 |
13 | logger = get_formatted_logger("ParseJsonResponses")
14 |
15 |
16 | def find_and_parse_json(
17 | response: openai.Completion, required_keys: list, optional_keys: list = []
18 | ) -> dict | None:
19 | """Parse stuctured fields from the API response.
20 |
21 | In case there are multiple JSON objects in the response, take the final one.
22 |
23 | Args:
24 | response: API response.
25 | required_keys: Required keys from the response
26 | optional_keys: Optional keys from the response
27 |
28 | Returns:
29 |         If the API response is a valid JSON object and contains all the
30 |         required keys, returns the parsed fields (plus any optional keys
31 |         that are present) as a dictionary.
32 |         Otherwise, returns None.
33 | """
34 | if not isinstance(response, str) and hasattr(response, "choices"):
35 | response = response.choices[0]["message"]["content"]
36 | correct_json = find_rightmost_brackets(response)
37 |
38 | if correct_json is None:
39 | logger.warning("No valid JSON found in the response.")
40 | return None
41 |
42 | try:
43 | response_json = json.loads(correct_json, strict=False)
44 | except json.decoder.JSONDecodeError:
45 | logger.warning(f"API response was not a valid JSON: {correct_json}")
46 | return None
47 |
48 | missing_keys = [key for key in required_keys if key not in response_json]
49 | if len(missing_keys) != 0:
50 | logger.warning(f'API response must contain {", ".join(required_keys)} keys')
51 | return None
52 |
53 | final_response = {}
54 | for key in required_keys + optional_keys:
55 | if key not in response_json:
56 | # This is an optional key, so exclude it from the final response.
57 | continue
58 |         if isinstance(response_json[key], str):
59 | final_response[key] = response_json[key].strip()
60 | else:
61 | final_response[key] = response_json[key]
62 | return final_response
63 |
64 |
65 | def find_rightmost_brackets(text: str) -> str | None:
66 | """Find the rightmost complete set of brackets in a string."""
67 | stack = []
68 | for i, char in enumerate(reversed(text)):
69 | if char == "}":
70 | stack.append(len(text) - i - 1)
71 | elif char == "{" and stack:
72 | start = len(text) - i - 1
73 | end = stack.pop()
74 | if not stack: # Found the rightmost complete set
75 | return text[start : end + 1]
76 | return None
77 |
78 |
79 | def parse_dataset_config_responses(response: openai.ChatCompletion) -> dict:
80 | """Parse the response to extract relevant information from dataset/configuration.
81 |
82 | LLMs can return the dataset configuration in different formats -
83 | usually either between ** ** or as a sentence.
84 |
85 | Args:
86 | response: The response containing the dataset configuration.
87 |
88 | Returns:
89 | The extracted relevant information from the dataset configuration.
90 | """
91 | if not isinstance(response, str) and hasattr(response, "choices"):
92 | response_str = response.choices[0]["message"]["content"]
93 | else:
94 | response_str = response
95 |
96 | pattern = r"\*\*(.*?)\*\*"
97 |
98 | match = re.search(pattern, response_str)
99 | dataset_config = ""
100 | if match:
101 | dataset_config = match.group(1)
102 | elif len(response_str.split()) >= 1:
103 | dataset_config = response_str.split()[-1].replace(".", "")
104 | return {"name": dataset_config}
105 |
106 |
107 | def parse_prompt_to_fields(
108 | prompt: str,
109 | required_keys: list = [],
110 | optional_keys: list = [],
111 | max_api_calls: int = 5,
112 | module_name: str = "col_selection",
113 | ) -> dict[str, Any]:
114 | """Parse prompt into specific fields, and return to the calling function.
115 |
116 |     This function calls the required API, contains the retry logic,
117 |     passes the response to the parsing function, and returns the parsed
118 |     response or raises an error.
119 |
120 | Args:
121 |         prompt: The prompt to be parsed into specific fields.
122 |         required_keys: Fields that must be present in the response.
123 |         optional_keys: Fields that may or may not be present in the response.
124 | max_api_calls: Max number of retries, defaults to 5 to avoid
125 | being stuck in an infinite loop
126 | module_name: The module this is to be used for. Currently supports
127 | rerank and col_selection
128 |
129 | Returns:
130 | Parsed Response as a dictionary.
131 |
132 | Raises:
133 | ValueError: If max_api_calls is not greater than 0.
134 | RuntimeError: If the maximum number of API calls is reached.
135 |
136 | """
137 | chat_api = api_tools.default_api_agent
138 | if max_api_calls <= 0:
139 | raise ValueError("max_api_calls must be > 0.")
140 |
141 | api_call_counter = 0
142 | last_error = None
143 | while True:
144 | api_call_counter += 1
145 | try:
146 | response: openai.ChatCompletion | Exception = (
147 | chat_api.generate_one_completion(
148 | prompt,
149 | temperature=0.01,
150 | presence_penalty=0,
151 | frequency_penalty=0,
152 | )
153 | )
154 | extraction: dict[str, Any] | None = None
155 | if module_name == "col_selection":
156 | extraction = find_and_parse_json(response, required_keys, optional_keys)
157 |
158 | elif module_name == "rerank":
159 | extraction = parse_dataset_config_responses(response)
160 | if extraction is not None:
161 | return extraction
162 | except API_ERRORS as e:
163 | last_error = e
164 | handle_api_error(e, backoff_duration=2**api_call_counter)
165 |
166 | if api_call_counter >= max_api_calls:
167 | # In case we reach maximum number of API calls, we raise an error.
168 | logger.error("Maximum number of API calls reached.")
169 | raise RuntimeError("Maximum number of API calls reached.") from last_error
170 |
171 |
172 | def make_single_api_request(prompt: str, max_api_calls: int = 10) -> str:
173 | """Prompts an LLM using the APIAgent, and returns the response.
174 |
175 |     This function calls the required API, contains the retry logic, and
176 |     returns the response text or raises an error.
177 |     Args:
178 |         prompt: The prompt to send to the LLM.
179 |         max_api_calls: Max number of retries, defaults to 10 to avoid
180 |             being stuck in an infinite loop.
181 |     Returns:
182 |         The response text; a RuntimeError is raised if all retries fail.
183 | """
184 | chat_api = api_tools.default_api_agent
185 | if max_api_calls <= 0:
186 | raise ValueError("max_api_calls must be > 0.")
187 |
188 | api_call_counter = 0
189 | last_error = None
190 | while True:
191 | api_call_counter += 1
192 | try:
193 | response: openai.ChatCompletion = chat_api.generate_one_completion(
194 | prompt=prompt, temperature=0.01, presence_penalty=0, frequency_penalty=0
195 | )
196 | if response is not None:
197 | return response.choices[0]["message"]["content"]
198 |
199 | except API_ERRORS as e:
200 | last_error = e
201 | handle_api_error(e, backoff_duration=2**api_call_counter)
202 |
203 | if api_call_counter >= max_api_calls:
204 | # In case we reach maximum number of API calls, we raise an error.
205 | logger.error("Maximum number of API calls reached.")
206 | raise RuntimeError("Maximum number of API calls reached.") from last_error
207 |
--------------------------------------------------------------------------------
/prompt2model/utils/retrieve_model_info.py:
--------------------------------------------------------------------------------
1 | """Retrieve HuggingFace model's size, description, and downloads."""
2 | import json
3 | import os
4 | import re
5 | import subprocess
6 | from pathlib import Path
7 |
8 | from huggingface_hub import HfApi
9 |
10 |
11 | def main(pretrained_model_name: str, cache_dir: str = None) -> None:
12 | """Downloads and caches a Hugging Face model's metadata.
13 |
14 | Args:
15 | pretrained_model_name: HuggingFace pretrained_model_name.
16 | cache_dir: A directory to cache the metadata.
17 | """
18 | if cache_dir is None:
19 | cache_dir = "model_info"
20 | cache_path = Path.cwd() / cache_dir
21 | cache_path.mkdir(parents=True, exist_ok=True)
22 |
23 | if len(pretrained_model_name.split("/")) == 2:
24 | _, model_name = pretrained_model_name.split("/")
25 | else:
26 | model_name = pretrained_model_name
27 |
28 | subprocess.run(
29 | ["git", "clone", f"https://huggingface.co/{pretrained_model_name}"],
30 | env=dict(os.environ, GIT_LFS_SKIP_SMUDGE="1"),
31 | stdout=subprocess.DEVNULL,
32 | stderr=subprocess.DEVNULL,
33 | )
34 |
35 | model_bin_file = Path(f"{model_name}/pytorch_model.bin")
36 |
37 | try:
38 | if model_bin_file.exists():
39 | with open(model_bin_file, "r") as file:
40 | content = file.read()
41 | size = re.search(r"size (\d+)", content).group(1) # type: ignore
42 | print(size)
43 | else:
44 | model_bin_file = Path(f"{model_name}/pytorch_model.bin.index.json")
45 | with open(model_bin_file, "r") as file:
46 | content = json.loads(file.read()) # type: ignore
47 | size = content["metadata"]["total_size"] # type: ignore
48 | print(size)
49 | except (
50 | FileNotFoundError,
51 | PermissionError,
52 | IOError,
53 | json.decoder.JSONDecodeError,
54 | ) as e:
55 | raise Exception(f"Failed to read {model_name} in {model_bin_file}: {e}")
56 |
57 | with open(f"{model_name}/README.md", "r", encoding="utf-8") as f:
58 | readme_content = f.read()
59 | print(readme_content)
60 |
61 | api = HfApi()
62 | model_meta = api.model_info(pretrained_model_name)
63 | downloads = model_meta.downloads
64 | print(pretrained_model_name, downloads)
65 |
66 | model_info = {
67 | "pretrained_model_name": pretrained_model_name,
68 | "description": readme_content,
69 | "size_bytes": size,
70 | "downloads": downloads,
71 | }
72 | model_info_path = Path(f"{cache_dir}/{model_name}.json")
73 | model_info_path.touch()
74 | with open(model_info_path, "w") as file:
75 | file.write(json.dumps(model_info))
76 |
77 | # The model must exist because it can be found by
78 | # `model_meta = api.model_info(pretrained_model_name)`
79 | subprocess.run(["rm", "-rf", model_name])
80 |
81 |
82 | if __name__ == "__main__":
83 | main("facebook/roscoe-512-roberta-base")
84 | main("gpt2")
85 |
--------------------------------------------------------------------------------
/prompt2model/utils/rng.py:
--------------------------------------------------------------------------------
1 | """Classes for setting random seeds."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | from abc import ABC, abstractmethod
6 |
7 |
8 | class SeedGenerator(ABC):
9 | """Select a good model from among a set of hyperparameter choices."""
10 |
11 | @abstractmethod
12 | def get_seed(self) -> int:
13 | """Generate a random seed."""
14 |
15 |
16 | class ConstantSeedGenerator(SeedGenerator):
17 | """A seed generator that always returns the same seed."""
18 |
19 | def __init__(self, seed: int = 2023):
20 | """Initialize with a constant seed (by default, 2023)."""
21 | self.seed = seed
22 |
23 | def get_seed(self) -> int:
24 | """Return a constant random seed."""
25 | return self.seed
26 |
27 |
28 | seed_generator = ConstantSeedGenerator()
29 |
--------------------------------------------------------------------------------
/prompt2model/utils/tevatron_utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Import Tevatron utility functions."""
2 | from prompt2model.utils.tevatron_utils.encode import encode_text
3 | from prompt2model.utils.tevatron_utils.retrieve import retrieve_objects
4 |
5 | __all__ = ("encode_text", "retrieve_objects")
6 |
--------------------------------------------------------------------------------
/prompt2model/utils/tevatron_utils/encode.py:
--------------------------------------------------------------------------------
1 | """Tools for encoding and serializing a search index with a contextual encoder."""
2 |
3 | from __future__ import annotations # noqa FI58
4 |
5 | import json
6 | import os
7 | import pickle
8 | import tempfile
9 | from contextlib import nullcontext
10 |
11 | import numpy as np
12 | import torch
13 | from tevatron.arguments import DataArguments
14 | from tevatron.data import EncodeCollator, EncodeDataset
15 | from tevatron.datasets import HFCorpusDataset, HFQueryDataset
16 | from tevatron.modeling import DenseModelForInference
17 | from torch.utils.data import DataLoader
18 | from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
19 |
20 |
21 | def load_tevatron_model(
22 | model_name_or_path: str, model_cache_dir: str | None = None
23 | ) -> tuple[DenseModelForInference, PreTrainedTokenizerBase]:
24 | """Load a Tevatron model from a model name/path.
25 |
26 | Args:
27 | model_name_or_path: The HuggingFace model name or path to the model.
28 | model_cache_dir: The directory to cache the model.
29 |
30 | Returns:
31 | A Tevatron dense retrieval model and its associated tokenizer.
32 | """
33 | config = AutoConfig.from_pretrained(
34 | model_name_or_path,
35 | cache_dir=model_cache_dir,
36 | )
37 | tokenizer = AutoTokenizer.from_pretrained(
38 | model_name_or_path,
39 | cache_dir=model_cache_dir,
40 | use_fast=False,
41 | )
42 | model = DenseModelForInference.build(
43 | model_name_or_path=model_name_or_path,
44 | config=config,
45 | cache_dir=model_cache_dir,
46 | )
47 | return model, tokenizer
48 |
49 |
50 | def encode_text(
51 | model_name_or_path: str,
52 | file_to_encode: str | None = None,
53 | text_to_encode: list[str] | str | None = None,
54 | encode_query: bool = False,
55 | encoding_file: str | None = None,
56 | max_len: int = 400,
57 | device: torch.device = torch.device("cpu"),
58 | dataloader_num_workers: int = 0,
59 | model_cache_dir: str | None = None,
60 | data_cache_dir: str = "~/.cache/huggingface/datasets",
61 | batch_size=8,
62 | fp16: bool = False,
63 | ) -> np.ndarray:
64 | """Encode a query or documents.
65 |
66 | This code is mostly duplicated from tevatron/driver/encode.py in the Tevatron
67 | repository.
68 |
69 | Args:
70 | model_name_or_path: The HuggingFace model name or path to the model.
71 | file_to_encode: JSON or JSONL file containing `"text"` fields to encode.
72 | text_to_encode: String or list of strings to encode.
73 | encode_query: Whether or not we are encoding a query or documents.
74 | encoding_file: If given, store the encoded data in this file.
75 | max_len: Truncate the input to this length (in tokens).
76 | device: Device that Torch will use to encode the text.
77 | dataloader_num_workers: Number of workers to use for the dataloader.
78 | model_cache_dir: The directory to cache the model.
79 | data_cache_dir: The directory to cache the tokenized dataset.
80 | batch_size: Batch size to use for encoding.
81 | fp16: Whether or not to run inference in fp16 for more-efficient encoding.
82 |
83 | Returns:
84 | A numpy array of shape `(expected_num_examples, embedding_dim)` containing text
85 | encoded by the specified model.
86 | """
87 | model, tokenizer = load_tevatron_model(model_name_or_path, model_cache_dir)
88 |
89 | if file_to_encode is None and text_to_encode is None:
90 | raise ValueError("Must provide either a dataset file or text to encode.")
91 | elif file_to_encode is not None and text_to_encode is not None:
92 | raise ValueError("Provide either a dataset file or text to encode, not both.")
93 |
94 | with tempfile.TemporaryDirectory() as temp_dir:
95 | if text_to_encode is not None:
96 | if isinstance(text_to_encode, str):
97 | text_to_encode = [text_to_encode]
98 | with open(
99 | os.path.join(temp_dir, "text_to_encode.json"), "w"
100 | ) as temporary_file:
101 | text_rows = [
102 | {"text_id": i, "text": text}
103 | for i, text in enumerate(text_to_encode)
104 | ]
105 | json.dump(text_rows, temporary_file)
106 | file_to_encode = temporary_file.name
107 | temporary_file.close()
108 |
109 | data_args = DataArguments(
110 | encoded_save_path=encoding_file,
111 | encode_in_path=file_to_encode,
112 | encode_is_qry=encode_query,
113 | data_cache_dir=data_cache_dir,
114 | )
115 | if encode_query:
116 | data_args.q_max_len = max_len
117 | hf_dataset = HFQueryDataset(
118 | tokenizer=tokenizer,
119 | data_args=data_args,
120 | cache_dir=data_args.data_cache_dir or model_cache_dir,
121 | )
122 | else:
123 | data_args.p_max_len = max_len
124 | hf_dataset = HFCorpusDataset(
125 | tokenizer=tokenizer,
126 | data_args=data_args,
127 | cache_dir=data_args.data_cache_dir or model_cache_dir,
128 | )
129 |
130 | encode_dataset = EncodeDataset(
131 | hf_dataset.process(1, 0), tokenizer, max_len=max_len
132 | )
133 |
134 | encode_loader = DataLoader(
135 | encode_dataset,
136 | batch_size=batch_size,
137 | collate_fn=EncodeCollator(
138 | tokenizer, max_length=max_len, padding="max_length"
139 | ),
140 | shuffle=False,
141 | drop_last=False,
142 | num_workers=dataloader_num_workers,
143 | )
144 | encoded = []
145 | lookup_indices = []
146 | model = model.to(device)
147 | model.eval()
148 |
149 | for batch_ids, batch in encode_loader:
150 | lookup_indices.extend(batch_ids)
151 | with torch.cuda.amp.autocast() if fp16 else nullcontext():
152 | with torch.no_grad():
153 | for k, v in batch.items():
154 | batch[k] = v.to(device)
155 | if data_args.encode_is_qry:
156 | model_output = model(query=batch)
157 | encoded.append(model_output.q_reps.cpu().detach().numpy())
158 | else:
159 | model_output = model(passage=batch)
160 | encoded.append(model_output.p_reps.cpu().detach().numpy())
161 |
162 | encoded = np.concatenate(encoded)
163 |
164 | if encoding_file:
165 | with open(encoding_file, "wb") as f:
166 | pickle.dump((encoded, lookup_indices), f)
167 |
168 | return encoded
169 |
--------------------------------------------------------------------------------
/prompt2model/utils/tevatron_utils/retrieve.py:
--------------------------------------------------------------------------------
1 | """Tools for doing efficient similarity search via the Tevatron/faiss libraries."""
2 | from __future__ import annotations
3 |
4 | import pickle
5 |
6 | import numpy as np
7 | from tevatron.faiss_retriever import BaseFaissIPRetriever
8 |
9 |
10 | def retrieve_objects(
11 | query_vector: np.ndarray,
12 | encoded_datasets_path: str,
13 | document_names: list[str],
14 | depth: int,
15 | ) -> list[tuple[str, float]]:
16 | """Return a ranked list of object indices and their scores.
17 |
18 | Args:
19 |         query_vector: Vector representation of the query.
20 | encoded_datasets_path: Path to file containing encoded dataset index.
21 | depth: Number of documents to return.
22 |
23 | Returns:
24 | Ranked list of object names and their inner product similarity to the query.
25 | """
26 | if query_vector.shape[0] != 1:
27 | raise ValueError("Only a single query vector is expected.")
28 | if len(query_vector.shape) != 2:
29 | raise ValueError("Query vector must be 1-D.")
30 |
31 | with open(encoded_datasets_path, "rb") as f:
32 | passage_reps, passage_lookup = pickle.load(f)
33 | retriever = BaseFaissIPRetriever(passage_reps)
34 | retriever.add(passage_reps)
35 |
36 | all_scores, all_indices = retriever.search(query_vector, depth)
37 | if not (len(all_scores) == len(all_indices) == 1):
38 | raise ValueError("Only one query's ranking should be returned.")
39 |
40 | psg_scores = all_scores[0]
41 | ranked_document_names = [document_names[passage_lookup[x]] for x in all_indices[0]]
42 | score_tuples = list(zip(ranked_document_names, psg_scores))
43 | return score_tuples
44 |
--------------------------------------------------------------------------------
/prompt2model/version.py:
--------------------------------------------------------------------------------
1 | """A version template file.
2 |
3 | This doesn't actually specify the version; the version is specified
4 | through tags on the repository.
5 | """
6 |
7 | VERSION = "0.0.0a0"
8 |
--------------------------------------------------------------------------------
/prompt_examples.md:
--------------------------------------------------------------------------------
1 | # Tips and Examples to Write a Good Prompt
2 |
3 | ## How to Write a Good Prompt
4 |
5 | A good prompt can make the generated dataset
6 | follow the format of your demonstrations exactly.
7 | It contains an instruction and few-shot examples.
8 |
9 | The instruction should contain the following:
10 |
11 | 1. The exact format description for the input
12 |    and output, e.g., a string, a dictionary, etc.
13 | 2. The exact contents of each part of the
14 |    input and their relationship, as precisely as you can.
15 | 3. The range of possible input. For example,
16 | "And the question can range from Math, Cultural,
17 | Social, Geometry, Biology, History, Sports, Technology,
18 | Science, and so on."
19 |
20 | The few-shot examples should contain the following:
21 |
22 | 1. Use `=` rather than other ambiguous symbols like `:`.
23 | 2. Avoid unnecessary line breaks at the beginning.
24 | For example, `input=""` is better than breaking
25 | the line after `=`.
26 | 3. Use `input` rather than `Input`; likewise,
27 |    prefer `output` over `Output`.
28 | 4. Wrap the `input` and `output` values in a string with `""`.
29 |
30 | Though the examples are optional, we strongly
31 | suggest including them to guide the format and
32 | content for the generator.
33 |
34 | Also, we recommend providing several precise examples
35 | in the specified format and asking ChatGPT
36 | about the format and scope of your examples.
37 |
38 | ## Examples of Good Prompts
39 |
40 | Here are some examples of good prompts:
41 |
42 | ### Question Answering
43 |
44 | ```text
45 | """Your task is to generate an answer to a natural question. In this task, the input is a string that consists of both a question and a context passage. The context is a descriptive passage related to the question and contains the answer. And the question can range from Math, Cultural, Social, Geometry, Biology, History, Sports, Technology, Science, and so on.
46 |
47 | Here are examples with input questions and context passages, along with their expected outputs:
48 |
49 | input="Question: What city did Super Bowl 50 take place in? Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50."
50 | output="Santa Clara"
51 |
52 | input="Question: What river runs through Warsaw? Context: Warsaw (Polish: Warszawa [varˈʂava] ( listen); see also other names) is the capital and largest city of Poland. It stands on the Vistula River in east-central Poland, roughly 260 kilometres (160 mi) from the Baltic Sea and 300 kilometres (190 mi) from the Carpathian Mountains. Its population is estimated at 1.740 million residents within a greater metropolitan area of 2.666 million residents, which makes Warsaw the 9th most-populous capital city in the European Union. The city limits cover 516.9 square kilometres (199.6 sq mi), while the metropolitan area covers 6,100.43 square kilometres (2,355.39 sq mi)."
53 | output="Vistula River"
54 |
55 | input="Question: The Ottoman empire controlled territory on three continents, Africa, Asia and which other? Context: The Ottoman Empire was an imperial state that lasted from 1299 to 1923. During the 16th and 17th centuries, in particular at the height of its power under the reign of Suleiman the Magnificent, the Ottoman Empire was a powerful multinational, multilingual empire controlling much of Southeast Europe, Western Asia, the Caucasus, North Africa, and the Horn of Africa. At the beginning of the 17th century the empire contained 32 provinces and numerous vassal states. Some of these were later absorbed into the empire, while others were granted various types of autonomy during the course of centuries."
56 | output="Europe"
57 | """
58 | ```
59 |
60 | ### Temporal Expression Normalization
61 |
62 | ```text
63 | """Temporal date expressions are commonly used to refer to specific time periods. Your task is to identify these temporal date expressions and provide the exact dates they refer to.
64 |
65 | For this task, the input is a string containing two specific elements: a posted date in the format "[Posted: YYYY-MM-DD]" and a sentence or statement that contains various temporal date references (e.g., early December, the end of the year, today, August, last Christmas, next Month, etc).
66 |
67 | Your program should output a string that maps the time period references mentioned in the input to their corresponding dates, following these strict rules:
68 |
69 | 1. If temporal date references are found, the output should use either "YYYY-MM-DD", "YYYY-MM", or "YYYY" to represent the exact date.
70 | - If multiple time period references are found, separate them using '|'.
71 | 2. If no temporal date reference is found or the referred date is ambiguous, the output should just be 'N/A', i.e., output="N/A".
72 |
73 | Here are some examples:
74 |
75 | input="[Posted: 1998-09-07] Tourism industry revenues reportedly dropped to $300 million last year, down from $450 million the year before."
76 | output="last year == 1997"
77 |
78 | input="[Posted: 2013-09-27] Eat! @mepangilinan"
79 | output="N/A"
80 |
81 | input="[Posted: 1989-10-30] Rated single-B-1 by Moody's Investors Service Inc. and single-B-plus by Standard amp Poor's Corp., the issue will be sold through underwriters led by Goldman, Sachs amp Co. Hertz Corp. -- $100 million of senior notes due Nov. 1, 2009, priced at par to yield 9%."
82 | output="Nov. 1, 2009 == 2009-11-01"
83 |
84 | input="[Posted: 2014-07-11] So out of place with this early transfer business."
85 | output="N/A"
86 |
87 | input="[Posted: 2013-10-25] Quote of the Day: '#Yoga is what you learn on your way down!"
88 | output="the Day == 2013-10-25"
89 |
90 | input="[Posted: 2021-06-15] Google plans to develop PALM 2 model in the first quarter of next year."
91 | output="N/A"
92 |
93 | input="[Posted: 2013-03-22] We will release a new github repository in the next three months."
94 | output="the next three month == 2013-04"
95 |
96 | input="[Posted: 2022-05-17] The company's fiscal year starts on July 1st and ends on June 30th."
97 | output="July 1st == 2022-07-01 | June 30th == 2022-06-30"
98 |
99 | input="[Posted: 2013-03-22] This flu season started in early December, a month earlier than usual, and peaked by the end of year."
100 | output="N/A"
101 |
102 | input="[Posted: 1989-10-30] The issue, which is puttable back to the company in 1999, was priced at a spread of 110 basis points above the Treasury's 10-year note."
103 | output="1999 == 1999"
104 |
105 | input="[Posted: 2022-04-15] The company announced that they will release their new product at the end of next month."
106 | output="the end of next month == 2022-05-31"
107 |
108 | input="[Posted: 2022-03-15] The teacher is going to release a new assignment in a few days."
109 | output="N/A"
110 | """
111 | ```
112 |
113 | ### Japanese-to-Python Generation
114 |
115 | ```text
116 | """Pythonで1行のコードを生成し、StackOverflowの日本語の質問を解決してください。コメントや式は含めないでください。インポート文も不要です。
117 |
118 | このタスクでは、入力は日本語のテキストで、変数名や操作が記述されています。出力は、そのタスクを達成するためのPythonの1行のコードです。コメントや式は含めないでください。インポート文も不要です。
119 |
120 | input="スペースで区切られた入力`stdin`を変数に格納して表示する"
121 | output="for line in stdin: a = line.rstrip().split(' ') print(a)"
122 |
123 | input="リスト`word_list'内に出現する単語を数える"
124 | output="Counter(word_list)"
125 |
126 | input="tweepyインスタンス`api`を使い、文字列`word`を含んだツイートを検索し、結果をリストとして得る"
127 | output="search = api.search(q=word)"
128 |
129 | input="データベースの設定を表示する"
130 | output="print(settings.DATABASES)"
131 |
132 | input="ネストされているリスト`li`を見やすく表示する"
133 | output="pprint.pprint(li)"
134 |
135 | input="HTMLファイル'test.html'を開き、テキストオブジェクト'text'をutf-8で保存する"
136 | output="f = open('test.html', 'w') f.write(text.encode('utf-8'))"
137 | """
138 | ```
139 |
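140 | ## A Minimal Template
141 | 
142 | Putting the rules above together, a prompt skeleton
143 | looks like the following. The angle-bracket
144 | placeholders are illustrative only; replace them with
145 | your own task description and demonstrations.
146 | 
147 | ```text
148 | """<Describe the task, the exact format of the input
149 | and output, and the range of possible inputs.>
150 | 
151 | Here are some examples:
152 | 
153 | input="<an example input>"
154 | output="<its expected output>"
155 | 
156 | input="<another example input>"
157 | output="<its expected output>"
158 | """
159 | ```
160 | 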
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "hatchling",
4 | ]
5 | build-backend = "hatchling.build"
6 |
7 | [project]
8 | name = "prompt2model"
9 | authors = [
10 | {name = "Vijay Viswanathan", email = "vijayv@andrew.cmu.edu"},
11 | {name = "Chenyang Zhao", email = "zhaochen20@mails.tsinghua.edu.cn"},
12 | ]
13 | description = "A library for distilling models from prompts."
14 | readme = "README.md"
15 | repository = "https://github.com/neulab/prompt2model"
16 | requires-python = ">=3.9"
17 | license = {file = "LICENSE"}
18 | classifiers = [
19 | "Programming Language :: Python :: 3.9",
20 | "Programming Language :: Python :: 3.10",
21 | "Programming Language :: Python :: 3.11",
22 | ]
23 | dependencies = [
24 | "transformers",
25 | "datasets",
26 | "pandas",
27 | "fastapi",
28 | "gradio==3.38.0",
29 | "torch",
30 | "pytest",
31 | "openai",
32 | "sentencepiece",
33 | "bert_score",
34 | "sacrebleu",
35 | "evaluate",
36 | "tevatron",
37 | "faiss-cpu",
38 | "mdtex2html",
39 | "scikit-learn",
40 | "retriv",
41 | "tiktoken",
42 | "aiolimiter",
43 | "pyfiglet",
44 | "termcolor",
45 | "psutil",
46 | "protobuf",
47 | "nest-asyncio",
48 | "litellm",
49 | "peft"
50 | ]
51 |
52 | dynamic = ["version"]
53 |
54 | [project.optional-dependencies]
55 | test = [
56 | "pytest>=6.0.0",
57 | ]
58 | dev = [
59 | "pytest",
60 | "pre-commit"
61 | ]
62 |
63 | [tool.hatch.build]
64 | include = [
65 | "*.py",
66 | ]
67 | exclude = [
68 | "*_test.py",
69 | "test_*.py",
70 | ]
71 | only-packages = true
72 |
73 | [tool.hatch.build.targets.wheel]
74 | packages = ["prompt2model"]
75 |
76 | [tool.hatch.version]
77 | path = "prompt2model/version.py"
78 |
79 | [tool.pytest.ini_options]
80 | pythonpath = ["prompt2model/test_helpers"]
81 | testpaths = ["prompt2model/tests"]
82 |
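83 | # The optional extras declared above can be installed with, for example,
84 | # `pip install "prompt2model[test]"` or, from a checkout, `pip install -e ".[dev]"`.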
--------------------------------------------------------------------------------
/scripts/dataset_index/preprocessing.py:
--------------------------------------------------------------------------------
1 | """Filtering out datasets before we do heavy processing on them."""
2 | import argparse
3 | import json
4 | import os
5 | from typing import Any
6 |
7 | from huggingface_hub import list_datasets
8 |
9 |
10 | def parse_arguments():
11 | """Parse command line arguments for the script.
12 |
13 | Returns:
14 | argparse.Namespace: An object with the following attributes:
15 | - unprocessed_datasets_file (str): Filename for unprocessed datasets.
16 | - preprocessed_datasets_file (str): Filename for preprocessed datasets.
17 | - min_words_in_desc (int): Minimum words in a dataset description.
18 | - min_downloads (int): Minimum downloads for a dataset.
19 |
20 | """
21 | parser = argparse.ArgumentParser(description="Process dataset files.")
22 | parser.add_argument(
23 | "--unprocessed_datasets_file",
24 | type=str,
25 | default="unprocessed.json",
26 | help="File name for unprocessed datasets",
27 | )
28 | parser.add_argument(
29 | "--preprocessed_datasets_file",
30 | type=str,
31 | default="preprocessed_datasets.json",
32 | help="File name for preprocessed datasets",
33 | )
34 | parser.add_argument(
35 | "--min_words_in_desc",
36 | type=int,
37 | default=4,
38 | help="Minimum number of words in description",
39 | )
40 | parser.add_argument(
41 | "--min_downloads", type=int, default=10, help="Minimum number of downloads"
42 | )
43 |
44 | args = parser.parse_args()
45 | return args
46 |
47 |
48 | def load_datasets(file_path: str) -> list[dict[str, Any]]:
49 | """Load all huggingface datasets from a JSON file.
50 |
51 |     Check if the unfiltered dataset file exists; if not, generate it.
52 |     Caching it to disk is helpful for multiple iterations of preprocessing.
53 |
54 | Args:
55 | file_path: File path of unfiltered datasets.
56 |
57 | Returns:
58 | List of datasets from HuggingFace. This doesn't have configs or rows.
59 | """
60 |     if not os.path.exists(file_path):
61 |         # Convert to dicts so the return type matches the JSON-loading branch.
62 |         all_datasets = [ob.__dict__ for ob in list_datasets()]
63 |         with open(file_path, "w") as f:
64 |             json.dump(all_datasets, f)
65 |         return all_datasets
66 | else:
67 | with open(file_path, "r") as file:
68 | return json.load(file)
69 |
70 |
71 | def filter_datasets(
72 |     datasets: list[dict[str, Any]], min_downloads: int, min_words_in_desc: int
73 | ) -> list[dict[str, Any]]:
74 | """Filter datasets based on specific criteria.
75 |
76 |     Filter out a dataset if its description is None, if its description has
77 |     fewer than `min_words_in_desc` words, if its downloads are below
78 |     `min_downloads`, or if its description duplicates another dataset's.
79 |
80 | Args:
81 | datasets: List of datasets from HuggingFace.
82 |
83 | Returns:
84 | Datasets filtered based on above criteria.
85 |
86 | """
87 | filtered_datasets = []
88 | descr_none = descr_small = downloads_less = common_descr = 0
89 | unique_descriptions: set[str] = set()
90 |
91 | for dataset_info in datasets:
92 | description = dataset_info.get("description")
93 |
94 | if not description:
95 | descr_none += 1
96 | continue
97 | if len(description.split()) < min_words_in_desc:
98 | descr_small += 1
99 | continue
100 | if dataset_info.get("downloads", 0) < min_downloads:
101 | downloads_less += 1
102 | continue
103 | if description in unique_descriptions:
104 | common_descr += 1
105 | continue
106 |
107 | filtered_datasets.append(dataset_info)
108 | unique_descriptions.add(description)
109 |
110 | print(f"{descr_none=}, {descr_small=}, {downloads_less=}, {common_descr=}")
111 |
112 | return filtered_datasets
113 |
114 |
115 | def main(args):
116 | """Main function to load and filter datasets."""
117 | datasets = load_datasets(args.unprocessed_datasets_file)
118 | filtered_datasets = filter_datasets(
119 | datasets=datasets,
120 | min_downloads=args.min_downloads,
121 | min_words_in_desc=args.min_words_in_desc,
122 | )
123 | with open(args.preprocessed_datasets_file, "w") as f:
124 | json.dump(filtered_datasets, f)
125 |
126 |
127 | if __name__ == "__main__":
128 | args = parse_arguments()
129 | main(args)
130 |
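131 | # Illustrative invocation (the flag values shown are the argparse defaults above):
132 | #   python scripts/dataset_index/preprocessing.py \
133 | #       --unprocessed_datasets_file unprocessed.json \
134 | #       --preprocessed_datasets_file preprocessed_datasets.json \
135 | #       --min_words_in_desc 4 --min_downloads 10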
--------------------------------------------------------------------------------
/test_helpers/__init__.py:
--------------------------------------------------------------------------------
1 | """Import mock classes used in unit tests."""
2 | from test_helpers.mock_api import (
3 | MockBatchDifferentCompletions,
4 | MockCompletion,
5 | UnknownGpt3Exception,
6 | mock_batch_api_response_identical_completions,
7 | )
8 | from test_helpers.mock_retrieval import create_test_search_index
9 | from test_helpers.model_and_tokenizer import (
10 | create_gpt2_model_and_tokenizer,
11 | create_t5_model_and_tokenizer,
12 | )
13 |
14 | __all__ = (
15 | "MockCompletion",
16 | "UnknownGpt3Exception",
17 | "MockBatchDifferentCompletions",
18 | "create_gpt2_model_and_tokenizer",
19 | "create_t5_model_and_tokenizer",
20 | "create_test_search_index",
21 | "mock_batch_api_response_identical_completions",
22 | "are_dataset_dicts_identical",
23 | "are_datasets_identical",
24 | "MockBatchResponseDifferentCompletions",
25 | )
26 |
--------------------------------------------------------------------------------
/test_helpers/dataset_index_tiny.json:
--------------------------------------------------------------------------------
1 | {"squad": {"name": "squad", "description": "Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.", "evaluation_metadata": [{"config": "plain_text", "task": "question-answering", "task_id": "extractive_question_answering", "splits": {"train_split": "train", "eval_split": "validation"}, "col_mapping": {"question": "question", "context": "context", "answers": {"text": "text", "answer_start": "answer_start"}}, "metrics": [{"type": "squad", "name": "SQuAD"}]}]}, "trivia_qa": {"name": "trivia_qa", "description": "TriviaqQA is a reading comprehension dataset containing over 650K\nquestion-answer-evidence triples. TriviaqQA includes 95K question-answer\npairs authored by trivia enthusiasts and independently gathered evidence\ndocuments, six per question on average, that provide high quality distant\nsupervision for answering the questions.", "evaluation_metadata": {}}, "search_qa": {"name": "search_qa", "description": "We publicly release a new large-scale dataset, called SearchQA, for machine comprehension, or question-answering. Unlike recently released datasets, such as DeepMind\nCNN/DailyMail and SQuAD, the proposed SearchQA was constructed to reflect a full pipeline of general question-answering. That is, we start not from an existing article\nand generate a question-answer pair, but start from an existing question-answer pair, crawled from J! Archive, and augment it with text snippets retrieved by Google.\nFollowing this approach, we built SearchQA, which consists of more than 140k question-answer pairs with each pair having 49.6 snippets on average. Each question-answer-context\n tuple of the SearchQA comes with additional meta-data such as the snippet's URL, which we believe will be valuable resources for future research. We conduct human evaluation\n as well as test two baseline methods, one simple word selection and the other deep learning based, on the SearchQA. We show that there is a meaningful gap between the human\n and machine performances. This suggests that the proposed dataset could well serve as a benchmark for question-answering.", "evaluation_metadata": {}}}
--------------------------------------------------------------------------------
/test_helpers/mock_api.py:
--------------------------------------------------------------------------------
1 | """Tools for mocking API responses (for testing purposes)."""
2 |
3 | from __future__ import annotations
4 |
5 | import openai
6 |
7 | from prompt2model.utils.api_tools import APIAgent
8 |
9 |
10 | class MockCompletion:
11 | """Mock completion object."""
12 |
13 | def __init__(self, content: str | None = None, responses_per_request: int = 1):
14 | """Initialize a new instance of `MockCompletion` class.
15 |
16 | Args:
17 | content: The mocked content to be returned, i.e.,
18 | `json.dumps({"comment": "This is a great movie!",
19 | "label": 1})`.
20 | responses_per_request: Number of responses
21 | for each request.
22 | """
23 |         # Generate `responses_per_request` identical responses per API call.
24 | if content is not None:
25 | # Mock a ChatCompletion with identical responses.
26 | self.choices = [{"message": {"content": content}}] * responses_per_request
27 | else:
28 | # Mock a ChatCompletion with different responses.
29 | # Only used in mock_batch_api_response_with_different_completion.
30 | # The choice will be replaced later in the function.
31 | self.choices = []
32 |
33 | def __repr__(self):
34 | """Return a string representation.
35 |
36 | Returns:
37 | _string: A string representation of the object, including its choices.
38 | """
39 | _string = f""
40 | return _string
41 |
42 |
43 | class MockBatchDifferentCompletions:
44 | """Mock batch completion object."""
45 |
46 | def __init__(self, length: int = 4) -> None:
47 | """Init a new instance of `MockBatchDifferentCompletions`.
48 |
49 | Args:
50 | length: Length of the batch completions.
51 |
52 | This class is designed to simulate the response of APIAgent and test the
53 | generation process of the PromptBasedDatasetGenerator with
54 | `filter_duplicated_examples` set to True in
55 | `dataset_generator_with_filter_test`.
56 |
57 | The class works in conjunction with PromptBasedDatasetGenerator with
58 | batch_size = 2, responses_per_request = 3, expected_num_examples
59 | = 5, and filter_duplicated_examples = True.
60 |
61 | Explanation of the generation process:
62 |
63 | In the first API call, the generator produces 2 * 3 = 6 responses.
64 | After filtering duplicates, the generated_dataset will be:
65 |
66 | Dataset.from_dict({
67 | "input_col": ["1", "2"],
68 | "output_col": ["a", "a"],
69 | })
70 |
71 | The second API call reduces batch_size to 1 and generates 3 more
72 | responses. After filtering duplicates, the generated_dataset will be:
73 |
74 | Dataset.from_dict({
75 | "input_col": ["1", "2", "3"],
76 | "output_col": ["a", "a", "a"],
77 | })
78 |
79 | The third API call again uses batch_size = 1 and generates another
80 | 3 responses. After filtering duplicates, the generated_dataset will be:
81 |
82 | Dataset.from_dict({
83 | "input_col": ["1", "2", "3"],
84 | "output_col": ["b", "a", "a"],
85 | })
86 |
87 |         The fourth API call also uses batch_size = 1 and generates the final
88 | 3 responses. After filtering duplicates, the generated_dataset will be:
89 |
90 | Dataset.from_dict({
91 | "input_col": ["1", "2", "3", "4", "5"],
92 | "output_col": ["b", "a", "a", "c", "a"],
93 | })
94 |
95 |         The fifth API call is specifically designed for
96 |         testing dataset_dict generation.
97 | """
98 | assert length == 4 or length == 5
99 | self.mock_completions: list[list[MockCompletion]] = []
100 | self.current_index = 0
101 | mock_completion_1 = MockCompletion()
102 | mock_completion_1.choices = [
103 | {"message": {"content": '{"input": "1", "output": "a"}'}},
104 | {"message": {"content": '{"input": "1", "output": "b"}'}},
105 | {"message": {"content": '{"input": "1", "output": "a"}'}},
106 | ]
107 | mock_completion_2 = MockCompletion()
108 | mock_completion_2.choices = [
109 | {"message": {"content": '{"input": "1", "output": "c"}'}},
110 | {"message": {"content": '{"input": "2", "output": "a"}'}},
111 | {"message": {"content": '{"input": "2", "output": "b"}'}},
112 | ]
113 | self.mock_completions.append(
114 | [
115 | mock_completion_1,
116 | mock_completion_2,
117 | ]
118 | )
119 | mock_completion_3 = MockCompletion()
120 | mock_completion_3.choices = [
121 | {"message": {"content": '{"input": "3", "output": "a"}'}},
122 | {"message": {"content": '{"input": "3", "output": "a"}'}},
123 | {"message": {"content": '{"input": "3", "output": "b"}'}},
124 | ]
125 | self.mock_completions.append([mock_completion_3])
126 |
127 | mock_completion_4 = MockCompletion()
128 | mock_completion_4.choices = [
129 | {"message": {"content": '{"input": "1", "output": "b"}'}},
130 | {"message": {"content": '{"input": "1", "output": "b"}'}},
131 | {"message": {"content": '{"input": "1", "output": "b"}'}},
132 | ]
133 | self.mock_completions.append([mock_completion_4])
134 | mock_completion_5 = MockCompletion()
135 | mock_completion_5.choices = [
136 | {"message": {"content": '{"input": "4", "output": "c"}'}},
137 | {"message": {"content": '{"input": "4", "output": "c"}'}},
138 | {"message": {"content": '{"input": "5", "output": "a"}'}},
139 | ]
140 | self.mock_completions.append([mock_completion_5])
141 | if length == 5:
142 | self.mock_completions.append(
143 | [
144 | mock_completion_1,
145 | mock_completion_2,
146 | ]
147 | )
148 |
149 |
150 | def mock_batch_api_response_identical_completions(
151 | prompts: list[str],
152 | content: str,
153 | temperature: float,
154 | presence_penalty: float = 0,
155 | frequency_penalty: float = 0,
156 | responses_per_request: int = 5,
157 | requests_per_minute: int = 80,
158 | ) -> list[MockCompletion]:
159 | """Generate a batch of mock completion objects.
160 |
161 | This function creates a batch of `MockCompletion`
162 |     objects, each with a `content` attribute set to an LLM completion string.
163 |
164 | Args:
165 | prompts: A batch of mocked prompts that won't be used.
166 | content: The example string to be returned.
167 | temperature: A mocked temperature.
168 | presence_penalty: A mocked presence penalty.
169 | frequency_penalty: A mocked frequency penalty.
170 | responses_per_request: Number of responses for each request.
171 | requests_per_minute: Number of requests per minute to allow.
172 |
173 | Returns:
174 |         A list of mock completion objects simulating ChatCompletion API responses.
175 | """
176 | _ = prompts, temperature, presence_penalty, frequency_penalty, requests_per_minute
177 | mock_completions = [
178 | MockCompletion(content=content, responses_per_request=responses_per_request)
179 | for _ in prompts
180 | ]
181 | return mock_completions
182 |
183 |
184 | class MockAPIAgent(APIAgent):
185 | """A mock API agent that always returns the same content."""
186 |
187 | def __init__(self, default_content):
188 | """Initialize the API agent."""
189 | self.generate_one_call_counter = 0
190 | self.generate_batch_call_counter = 0
191 | self.default_content = default_content
192 |
193 | def generate_one_completion(
194 | self,
195 | prompt: str,
196 | temperature: float = 0,
197 | presence_penalty: float = 0,
198 | frequency_penalty: float = 0,
199 | token_buffer: int = 300,
200 | ) -> openai.Completion:
201 | """Return a mocked object and increment the counter."""
202 | self.generate_one_call_counter += 1
203 | return MockCompletion(content=self.default_content)
204 |
205 | async def generate_batch_completion(
206 | self,
207 | prompts: list[str],
208 | temperature: float = 1,
209 | responses_per_request: int = 5,
210 | requests_per_minute: int = 80,
211 | token_buffer: int = 300,
212 | ) -> list[openai.Completion]:
213 | """Return a mocked object and increment the counter."""
214 | self.generate_batch_call_counter += 1
215 | return [MockCompletion(content=self.default_content) for _ in prompts]
216 |
217 |
218 | class UnknownGpt3Exception(Exception):
219 | """This is a newly-defined exception for testing purposes."""
220 |
221 | pass
222 |
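223 | # In unit tests, these mocks typically stand in for the real APIAgent, e.g. by
224 | # patching "prompt2model.utils.APIAgent.generate_one_completion" or
225 | # "prompt2model.utils.APIAgent.generate_batch_completion" with MockCompletion
226 | # objects as side effects (see tests/dataset_transformer_test.py for an example).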
--------------------------------------------------------------------------------
/test_helpers/mock_retrieval.py:
--------------------------------------------------------------------------------
1 | """Tools for creating a mock search index."""
2 |
3 | import pickle
4 |
5 | import numpy as np
6 |
7 |
8 | def create_test_search_index(index_file_name: str) -> None:
9 | """Utility function to create a test search index.
10 |
11 | This search index represents 3 models, each represented with a hand-written vector.
12 | Given a query of [0, 0, 1], the 3rd model will be the most similar.
13 |
14 | Args:
15 | index_file_name: The name of the file to which a pickled index will be written.
16 | """
17 | mock_model_encodings = np.array([[0.9, 0, 0], [0, 0.9, 0], [0, 0, 0.9]])
18 | mock_lookup_indices = [0, 1, 2]
19 | with open(index_file_name, "wb") as f:
20 | pickle.dump((mock_model_encodings, mock_lookup_indices), f)
21 |
--------------------------------------------------------------------------------
/test_helpers/model_and_tokenizer.py:
--------------------------------------------------------------------------------
1 | """Tools for creating a GPT-2 model with padding tokenizer."""
2 |
3 | from collections import namedtuple
4 |
5 | from transformers import AutoModelForCausalLM, AutoTokenizer, T5ForConditionalGeneration
6 |
7 | ModelAndTokenizer = namedtuple("ModelAndTokenizer", ["model", "tokenizer"])
8 |
9 |
10 | def create_gpt2_model_and_tokenizer(full_size: bool = False) -> ModelAndTokenizer:
11 | """Create a GPT2 model with its padding tokenizer for batched input.
12 |
13 | Args:
14 | full_size: Whether to use the full size of the GPT-2 model. Defaults to False.
15 |         Note that the full-size GPT-2 model may occupy too much memory
16 |         and lead to out-of-memory errors.
17 |
18 | Returns:
19 | gpt2_model_and_tokenizer: A namedtuple with gpt2 model and tokenizer.
20 | """
21 | if not full_size:
22 | gpt2_model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
23 | gpt2_tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
24 | else:
25 | gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2")
26 | gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
27 | if gpt2_tokenizer.pad_token is None:
28 | gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
29 | if gpt2_model.config.pad_token_id is None:
30 | gpt2_model.config.pad_token_id = gpt2_tokenizer.eos_token_id
31 | gpt2_model_and_tokenizer = ModelAndTokenizer(gpt2_model, gpt2_tokenizer)
32 | return gpt2_model_and_tokenizer
33 |
34 |
35 | def create_t5_model_and_tokenizer(full_size: bool = False) -> ModelAndTokenizer:
36 | """Create a T5 model with its padding tokenizer for batched input.
37 |
38 | Args:
39 | full_size: Whether to use the full size of the T5 model. Defaults to False.
40 |         Note that the full-size T5 model may occupy too much memory
41 |         and lead to out-of-memory errors.
42 |
43 | Returns:
44 | t5_model_and_tokenizer: A namedtuple with t5 model and tokenizer.
45 | """
46 | if not full_size:
47 | t5_model = T5ForConditionalGeneration.from_pretrained(
48 | "google/t5-efficient-tiny"
49 | )
50 | t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-efficient-tiny")
51 | else:
52 | t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
53 | t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")
54 | t5_model_and_tokenizer = ModelAndTokenizer(t5_model, t5_tokenizer)
55 | return t5_model_and_tokenizer
56 |
--------------------------------------------------------------------------------
/test_helpers/model_info_tiny/gpt2.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_model_name": "gpt2",
3 | "description": "unidirectional autoregressive language model",
4 | "size_bytes": "800000000",
5 | "downloads": 100,
6 | "_comment": "THIS FILE SHOULD ONLY BE USED FOR TESTING PURPOSES."
7 | }
--------------------------------------------------------------------------------
/test_helpers/model_info_tiny/t5-base.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_model_name": "t5-base",
3 | "description": "text to text generator",
4 | "size_bytes": "800000000",
5 | "downloads": 10000,
6 | "_comment": "THIS FILE SHOULD ONLY BE USED FOR TESTING PURPOSES."
7 | }
--------------------------------------------------------------------------------
/test_helpers/model_info_tiny/xlnet-base-cased.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_model_name": "xlnet-base-cased",
3 | "description": "bidirectional autoregressive language model",
4 | "size_bytes": "1600000000",
5 | "downloads": 100000,
6 | "_comment": "THIS FILE SHOULD ONLY BE USED FOR TESTING PURPOSES."
7 | }
--------------------------------------------------------------------------------
/test_helpers/test_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for testing."""
2 | from contextlib import contextmanager
3 |
4 |
5 | @contextmanager
6 | def temp_setattr(obj, attr, value):
7 | """Temporarily set an attribute on an object."""
8 | original = getattr(obj, attr, None)
9 | setattr(obj, attr, value)
10 | try:
11 | yield
12 | finally:
13 | if original is not None:
14 | setattr(obj, attr, original)
15 | else:
16 | delattr(obj, attr)
17 |
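21 | # Illustrative use (the `config` object and attribute name are hypothetical):
22 | #
23 | #     with temp_setattr(config, "use_gpu", False):
24 | #         ...  # code here sees config.use_gpu == False
25 | 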
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Init file for test scripts."""
2 |
--------------------------------------------------------------------------------
/tests/dataset_transformer_test.py:
--------------------------------------------------------------------------------
1 | """Tests for dataset_transformer."""
2 |
3 | from functools import partial
4 | from unittest.mock import patch
5 |
6 | from datasets import Dataset, DatasetDict
7 |
8 | from prompt2model.dataset_transformer import PromptBasedDatasetTransformer
9 | from prompt2model.prompt_parser import MockPromptSpec, TaskType
10 | from test_helpers import MockCompletion, mock_batch_api_response_identical_completions
11 |
12 | TASK_EXPLANATION = MockCompletion(
13 | """This is a question answering task. The model is given a context and a question, and it must provide the answer.""" # noqa E501
14 | )
15 | PLAN_RESPONSE = MockCompletion(
16 | """plan:
17 | 1. Combine "context" and "question" into "input".
18 | 2. Combine "answer" into "output"."""
19 | )
20 |
21 | TRANSFORMED_DATA = partial(
22 | mock_batch_api_response_identical_completions,
23 | content='{"input": "context: Albany, the state capital, is the sixth-largest city in the State of New York.\nquestion: What is the capital of New York?", "output": "Albany"}', # noqa E501
24 | )
25 |
26 |
27 | @patch(
28 | "prompt2model.utils.APIAgent.generate_batch_completion",
29 | side_effect=TRANSFORMED_DATA,
30 | )
31 | @patch(
32 | "prompt2model.utils.APIAgent.generate_one_completion",
33 | side_effect=[TASK_EXPLANATION, PLAN_RESPONSE],
34 | )
35 | def test_transform_data(mock_one_completion, mock_batch_completion):
36 | """Test transform_data."""
37 | dataset_transformer = PromptBasedDatasetTransformer(
38 | num_points_to_transform=1, max_allowed_failed_transforms=0
39 | ) # noqa: E501
40 | prompt_spec = MockPromptSpec(
41 | TaskType.TEXT_GENERATION,
42 | instruction="instruction",
43 | examples="example1\nexample2",
44 | )
45 | mock_dataset = Dataset.from_dict(
46 | {
47 | "question": ["What is the capital of New York?"],
48 | "context": [
49 | "Albany, the state capital, is the sixth-largest city in the State of New York." # noqa E501
50 | ],
51 | "answer": ["Albany"],
52 | }
53 | )
54 | # Create a mock DatasetDict consisting of the same example in each split.
55 | dataset = DatasetDict(
56 | {"train": mock_dataset, "val": mock_dataset, "test": mock_dataset}
57 | )
58 | inputs, outputs = dataset_transformer.transform_data(
59 | prompt_spec=prompt_spec,
60 | dataset=dataset["train"],
61 | )
62 | assert inputs == [
63 | "context: Albany, the state capital, is the sixth-largest city in the State of New York.\nquestion: What is the capital of New York?" # noqa E501
64 | ] # noqa E501
65 | assert outputs == ["Albany"]
66 |
--------------------------------------------------------------------------------
/tests/demo_creator_test.py:
--------------------------------------------------------------------------------
1 | """Test the create_gradio function with two configurations."""
2 |
3 | import gc
4 |
5 | import gradio as gr
6 |
7 | from prompt2model.demo_creator import create_gradio
8 | from prompt2model.model_executor import GenerationModelExecutor
9 | from prompt2model.prompt_parser import MockPromptSpec, TaskType
10 | from test_helpers import create_gpt2_model_and_tokenizer, create_t5_model_and_tokenizer
11 |
12 |
13 | def test_create_gradio_with_gpt2():
14 | """Test the `create_gradio` method with a GPT2 model."""
15 | # Create a GPT-2 model and tokenizer.
16 | gpt2_model_and_tokenizer = create_gpt2_model_and_tokenizer()
17 | gpt2_model = gpt2_model_and_tokenizer.model
18 | gpt2_tokenizer = gpt2_model_and_tokenizer.tokenizer
19 |
20 | # Create a GenerationModelExecutor.
21 | gpt2_executor = GenerationModelExecutor(
22 | model=gpt2_model,
23 | tokenizer=gpt2_tokenizer,
24 | batch_size=1,
25 | )
26 |
27 |     # Create a mock prompt spec.
28 | gpt2_prompt_parser = MockPromptSpec(TaskType.TEXT_GENERATION)
29 |
30 | # Create Gradio interface.
31 | interface_gpt2 = create_gradio(gpt2_executor, gpt2_prompt_parser)
32 |
33 | # Perform assertions.
34 | assert isinstance(interface_gpt2, gr.Blocks)
35 | gc.collect()
36 |
37 |
38 | def test_create_gradio_with_t5():
39 | """Test the `create_gradio` method with a T5 model."""
40 | # Create T5 model and tokenizer
41 | t5_model_and_tokenizer = create_t5_model_and_tokenizer()
42 | t5_model = t5_model_and_tokenizer.model
43 | t5_tokenizer = t5_model_and_tokenizer.tokenizer
44 |
45 | # Create a GenerationModelExecutor.
46 | t5_executor = GenerationModelExecutor(
47 | model=t5_model,
48 | tokenizer=t5_tokenizer,
49 | batch_size=1,
50 | )
51 |
52 |     # Create a mock prompt spec.
53 | t5_prompt_parser = MockPromptSpec(task_type="generation")
54 |
55 | # Create Gradio interface.
56 | interface_t5 = create_gradio(t5_executor, t5_prompt_parser)
57 |
58 | # Perform assertions.
59 | assert isinstance(interface_t5, gr.Blocks)
60 | gc.collect()
61 |
--------------------------------------------------------------------------------
/tests/param_selector_test.py:
--------------------------------------------------------------------------------
1 | """Testing hyperparameter optimization with different configurations."""
2 |
3 | import gc
4 | import logging
5 | import os
6 |
7 | import datasets
8 |
9 | from prompt2model.model_trainer.generate import GenerationModelTrainer
10 | from prompt2model.param_selector.search_with_optuna import OptunaParamSelector
11 |
12 | os.environ["WANDB_MODE"] = "dryrun"
13 | logger = logging.getLogger("AutoHyperparamOptimization")
14 |
15 |
16 | def test_optimize_hyperparameters():
17 | """Tests whether the hyperparameter optimization is working correctly or not."""
18 | # Create a simple training dataset.
19 | training_datasets = [
20 | datasets.Dataset.from_dict(
21 | {
22 | "model_input": [
23 | "Given a product review, predict the sentiment score associated with it.\nExample:\nIt isn’t my fav lip balm, but it’s up there. It moisturises really well and the lemon isn’t strong or over powering.\nLabel:\n", # noqa E501
24 | ],
25 | "model_output": ["4"],
26 | }
27 | ),
28 | ]
29 |
30 | validation_datasets = [
31 | datasets.Dataset.from_dict(
32 | {
33 | "model_input": [
34 | "Given a product review, predict the sentiment score associated with it.\nExample:\nBroke me out and gave me awful texture all over my face. I typically have clear skin and after using this product my skin HATED it. Could work for you though.\nLabel:\n", # noqa E501
35 | ],
36 | "model_output": ["2"],
37 | }
38 | ),
39 | ]
40 |
41 | param_selector = OptunaParamSelector(
42 | n_trials=1,
43 | trainer=GenerationModelTrainer(
44 | "patrickvonplaten/t5-tiny-random", has_encoder=True
45 | ),
46 | )
47 | best_hyperparameters = param_selector.optimize_hyperparameters(
48 | training_datasets=training_datasets,
49 | validation=validation_datasets,
50 | hyperparameters={
51 | "min_num_train_epochs": 1,
52 | "max_num_train_epochs": 1,
53 | "save_strategy": ["epoch"],
54 | "evaluation_strategy": ["epoch"],
55 | "per_device_train_batch_size": [2],
56 | "min_weight_decay": 4e-5,
57 | "max_weight_decay": 1e-1,
58 | "min_learning_rate": 4e-5,
59 | "max_learning_rate": 1e-1,
60 | },
61 | )
62 | assert isinstance(best_hyperparameters, dict)
63 | gc.collect()
64 |
--------------------------------------------------------------------------------
/tests/run_locally_test.py:
--------------------------------------------------------------------------------
1 | """Testing integration of components locally."""
2 |
3 | import gc
4 | import os
5 | import tempfile
6 |
7 | from prompt2model.run_locally import run_skeleton
8 |
9 |
10 | def test_integration():
11 | """Check that a end-to-end run with a single prompt doesn't throw an error."""
12 | prompt = ["Test prompt"]
13 | with tempfile.TemporaryDirectory() as tmpdirname:
14 | metrics_output_path = os.path.join(tmpdirname, "metrics.json")
15 | run_skeleton(prompt, metrics_output_path)
16 | gc.collect()
17 |
--------------------------------------------------------------------------------
/tests/tevatron_utils_test.py:
--------------------------------------------------------------------------------
1 | """Testing DatasetGenerator through PromptBasedDatasetGenerator."""
2 |
3 | import gc
4 | import json
5 | import os
6 | import pickle
7 | import tempfile
8 |
9 | import numpy as np
10 | import pytest
11 | from tevatron.modeling import DenseModelForInference
12 | from transformers import PreTrainedTokenizerBase
13 |
14 | from prompt2model.utils.tevatron_utils import encode_text, retrieve_objects
15 | from prompt2model.utils.tevatron_utils.encode import load_tevatron_model
16 |
17 | TINY_MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2"
18 |
19 |
20 | def test_load_tevatron_model():
21 | """Test loading a small Tevatron model."""
22 | model, tokenizer = load_tevatron_model(TINY_MODEL_NAME)
23 | assert isinstance(model, DenseModelForInference)
24 | assert isinstance(tokenizer, PreTrainedTokenizerBase)
25 | gc.collect()
26 |
27 |
28 | def test_encode_text_from_string():
29 | """Test encoding text from a string into a vector."""
30 | text = "This is an example sentence"
31 | encoded = encode_text(TINY_MODEL_NAME, text_to_encode=text)
32 | assert encoded.shape == (1, 128)
33 | gc.collect()
34 |
35 |
36 | def test_encode_text_from_file():
37 | """Test encoding text from a file into a vector."""
38 | text_rows = [
39 | {"text_id": 0, "text": "This is an example sentence"},
40 | {"text_id": 1, "text": "This is another example sentence"},
41 | ]
42 | with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as f:
43 | json.dump(text_rows, f)
44 | f.seek(0)
45 | encoded = encode_text(TINY_MODEL_NAME, file_to_encode=f.name)
46 | assert encoded.shape == (2, 128)
47 | gc.collect()
48 |
49 |
50 | def test_encode_text_from_file_store_to_file():
51 | """Test encoding text from a file into a vector, then stored to file."""
52 | text_rows = [
53 | {"text_id": 0, "text": "This is an example sentence"},
54 | {"text_id": 1, "text": "This is another example sentence"},
55 | ]
56 | with tempfile.TemporaryDirectory() as tempdir:
57 | with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as f:
58 | json.dump(text_rows, f)
59 | f.seek(0)
60 | encoding_file_path = os.path.join(tempdir, "encoding.pkl")
61 | encoded = encode_text(
62 | TINY_MODEL_NAME, file_to_encode=f.name, encoding_file=encoding_file_path
63 | )
64 | assert encoded.shape == (2, 128)
65 | encoded_vectors, encoded_indices = pickle.load(
66 | open(encoding_file_path, "rb")
67 | )
68 | assert (encoded == encoded_vectors).all()
69 | assert encoded_indices == [0, 1]
70 | gc.collect()
71 |
72 |
73 | def test_encode_text_error_from_no_string_or_file():
74 | """Test that either a string or a file must be passed to encode."""
75 | with pytest.raises(ValueError):
76 | _ = encode_text(TINY_MODEL_NAME)
77 | gc.collect()
78 |
79 |
80 | def test_encode_text_error_from_both_string_and_file():
81 | """Test that either a string or a file, but not both, must be passed to encode."""
82 | text = "This is an example sentence"
83 | file = "/tmp/test.txt"
84 | with pytest.raises(ValueError):
85 | _ = encode_text(TINY_MODEL_NAME, file_to_encode=file, text_to_encode=text)
86 | gc.collect()
87 |
88 |
89 | def test_retrieve_objects():
90 | """Test retrieval against a list of vectors."""
91 | mock_query_vector = np.array([[0.0, 0.0, 1.0, 0.0]])
92 | # The query vector matches the third row in the search collection.
93 | mock_search_collection = np.array(
94 | [
95 | [1.0, 0.0, 0.0, 0.0],
96 | [0.0, 1.0, 0.0, 0.0],
97 | [0.0, 0.0, 1.0, 0.0],
98 | [0.0, 0.0, 0.0, 1.0],
99 | ]
100 | )
101 | document_names = ["a", "b", "c", "d"]
102 | mock_vector_indices = [0, 1, 2, 3]
103 | with tempfile.TemporaryDirectory() as tmpdir:
104 | search_index_pickle = os.path.join(tmpdir, "search_index.pkl")
105 | pickle.dump(
106 | (mock_search_collection, mock_vector_indices),
107 | open(search_index_pickle, "wb"),
108 | )
109 | results = retrieve_objects(
110 | mock_query_vector, search_index_pickle, document_names, depth=3
111 | )
112 | assert (
113 | len(results) == 3
114 | ), "The number of results should match the provided depth."
115 |
116 | # Verify that the index of the first retrieved document matches the document
117 |         # that we know matches the query vector.
118 | first_retrieved_document, _ = results[0]
119 | assert first_retrieved_document == "c"
120 |
121 | # Verify that the first retrieved document has the greatest retrieval score.
122 | sorted_results = sorted(results, key=lambda x: x[1], reverse=True)
123 | assert sorted_results[0][0] == first_retrieved_document
124 | gc.collect()
125 |
--------------------------------------------------------------------------------