├── .github └── workflows │ └── stale.yml ├── .gitignore ├── .readthedocs.yaml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── docs ├── Makefile ├── _static │ └── custom.css ├── conf.py ├── examples │ ├── add_new_modules.md │ ├── basic.md │ ├── dyval.md │ ├── multimodal.md │ ├── prompt_attack.md │ └── prompt_engineering.md ├── index.rst ├── leaderboard │ ├── advprompt.md │ ├── dyval.md │ └── pe.md ├── make.bat ├── reference │ ├── dataload │ │ ├── dataload.rst │ │ ├── dataset.rst │ │ └── index.rst │ ├── dyval │ │ ├── DAG │ │ │ ├── code_dag.rst │ │ │ ├── dag.rst │ │ │ ├── describer.rst │ │ │ ├── index.rst │ │ │ ├── logic_dag.rst │ │ │ └── math_dag.rst │ │ ├── dyval_dataset.rst │ │ ├── dyval_utils.rst │ │ └── index.rst │ ├── index.rst │ ├── metrics │ │ ├── eval.rst │ │ └── index.rst │ ├── models │ │ ├── index.rst │ │ └── models.rst │ └── utils │ │ ├── dataprocess.rst │ │ └── index.rst ├── requirements.txt └── start │ ├── installation.md │ └── intro.md ├── examples ├── add_new_modules.md ├── basic.ipynb ├── dyval.ipynb ├── efficient_multi_prompt_eval.ipynb ├── mpa.ipynb ├── multimodal.ipynb ├── prompt_attack.ipynb └── prompt_engineering.ipynb ├── imgs ├── prompt_attack_attention.png ├── promptbench.png └── promptbench_logo.png ├── promptbench ├── __init__.py ├── config.py ├── dataload │ ├── __init__.py │ ├── dataload.py │ └── dataset.py ├── dyval │ ├── DAG │ │ ├── __init__.py │ │ ├── code_dag.py │ │ ├── dag.py │ │ ├── describer.py │ │ ├── logic_dag.py │ │ └── math_dag.py │ ├── __init__.py │ ├── dyval_dataset.py │ └── dyval_utils.py ├── metrics │ ├── __init__.py │ ├── bleu │ │ ├── bleu.py │ │ ├── bleu_.py │ │ └── tokenizer_13a.py │ ├── cider │ │ ├── cider.py │ │ └── cider_scorer.py │ ├── eval.py │ ├── squad_v2 │ │ ├── compute_score.py │ │ └── squad_v2.py │ └── vqa │ │ └── eval_vqa.py ├── models │ ├── __init__.py │ └── models.py ├── mpa │ ├── .DS_Store │ ├── __init__.py │ ├── agent.py │ ├── dataprocess.py │ └── mpa_prompts.py ├── prompt_attack │ ├── README.md │ ├── __init__.py │ ├── attack.py │ ├── goal_function.py │ ├── label_constraint.py │ ├── search.py │ └── transformations.py ├── prompt_engineering │ ├── __init__.py │ ├── base.py │ ├── chain_of_thought.py │ ├── emotion_prompt.py │ ├── expert_prompting.py │ ├── generated_knowledge.py │ └── least_to_most.py ├── prompteval │ ├── __init__.py │ ├── efficient_eval.py │ └── methods.py ├── prompts │ ├── __init__.py │ ├── adv_prompts │ │ ├── Readme.md │ │ ├── chatgpt_fewshot.md │ │ ├── chatgpt_zeroshot.md │ │ ├── t5_fewshot.md │ │ ├── t5_zeroshot.md │ │ ├── ul2_fewshot.md │ │ ├── ul2_zeroshot.md │ │ ├── vicuna_fewshot.md │ │ └── vicuna_zeroshot.md │ ├── few_shot_examples.yaml │ ├── method_oriented.py │ ├── prompt.py │ ├── role_oriented.py │ ├── semantic_atk_prompts.py │ └── task_oriented.py └── utils │ ├── __init__.py │ ├── dataprocess.py │ ├── defense.py │ └── visualize.py ├── requirements.txt └── setup.py /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time. 2 | # 3 | # You can adjust the behavior by modifying this file. 
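# Note: the timing is not configured below, so the action's defaults apply; if needed, the
# `days-before-stale` and `days-before-close` inputs can be set under `with:` to change how
# long an issue or PR stays inactive before it is marked stale and then closed.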
4 | # For more information, see: 5 | # https://github.com/actions/stale 6 | name: Mark stale issues and pull requests 7 | 8 | on: 9 | schedule: 10 | - cron: '28 6 * * *' 11 | 12 | jobs: 13 | stale: 14 | 15 | runs-on: ubuntu-latest 16 | permissions: 17 | issues: write 18 | pull-requests: write 19 | 20 | steps: 21 | - uses: actions/stale@v5 22 | with: 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | stale-issue-message: 'Stale issue message' 25 | stale-pr-message: 'Stale pull request message' 26 | stale-issue-label: 'no-issue-activity' 27 | stale-pr-label: 'no-pr-activity' 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *results*/ 2 | logs*/ 3 | *.sh 4 | promptbench/data/ 5 | llama*.yaml 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 
115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 166 | #.idea/ 167 | amlt.yaml 168 | test_pe.py 169 | test_pe.sh 170 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.8" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the "docs/" directory with Sphinx 19 | sphinx: 20 | configuration: docs/conf.py 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | # formats: 24 | # - pdf 25 | # - epub 26 | 27 | # Optional but recommended, declare the Python requirements required 28 | # to build your documentation 29 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 30 | python: 31 | install: 32 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
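# For example, `make html` expands to `sphinx-build -M html . _build` and writes the HTML docs to _build/html.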
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* CSS rules to make .docutils.align-default tables scrollable */ 2 | .docutils.align-default { 3 | overflow-x: auto; 4 | display: block; 5 | max-width: 100%; 6 | } 7 | 8 | .docutils.align-default table { 9 | white-space: normal; 10 | width: auto; 11 | display: block; 12 | } 13 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | import sphinx_rtd_theme 16 | 17 | sys.path.insert(0, os.path.abspath('..')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'promptbench' 23 | copyright = '2023, Kaijie Zhu' 24 | author = 'Kaijie Zhu' 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = '0.0.1' 28 | 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | html_theme = "sphinx_rtd_theme" 36 | extensions = [ 37 | 'sphinx.ext.autodoc', # For docstring support 38 | 'myst_parser', 39 | 'sphinx.ext.napoleon', # For Google and NumPy style docstrings 40 | 'sphinx_autodoc_typehints', # For type hints support 41 | ] 42 | 43 | 44 | # Add any paths that contain templates here, relative to this directory. 45 | templates_path = ['_templates'] 46 | 47 | # List of patterns, relative to source directory, that match files and 48 | # directories to ignore when looking for source files. 49 | # This pattern also affects html_static_path and html_extra_path. 50 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 51 | 52 | 53 | # -- Options for HTML output ------------------------------------------------- 54 | 55 | # The theme to use for HTML and HTML Help pages. See the documentation for 56 | # a list of builtin themes. 57 | # 58 | # html_theme = 'alabaster' 59 | 60 | # Add any paths that contain custom static files (such as style sheets) here, 61 | # relative to this directory. They are copied after the builtin static files, 62 | # so a file named "default.css" will overwrite the builtin "default.css". 
63 | html_static_path = ['_static'] 64 | 65 | html_css_files = [ 66 | 'custom.css', 67 | ] 68 | 69 | html_theme = "sphinx_rtd_theme" 70 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] -------------------------------------------------------------------------------- /docs/examples/add_new_modules.md: -------------------------------------------------------------------------------- 1 | # Adding new modules 2 | 3 | Each module in promptbench can be easily extended. In the following, we provide basic guidelines for customizing your own datasets, models, prompt engineering methods, and evaluation metrics. 4 | 5 | ## Add new datasets 6 | Adding new datasets involves two steps: 7 | 8 | - Implementing a New Dataset Class: Datasets should be implemented in `dataload/dataset.py` and inherit from the `Dataset` class. For your custom dataset, implement the `__init__` method to load your dataset. We recommend organizing your data samples as dictionaries to facilitate the input process. 9 | - Adding an Interface: After customizing the dataset class, register it in the `DatasetLoader` class within `dataload.py`. 10 | 11 | 12 | 13 | ## Add new models 14 | Similar to adding new datasets, the addition of new models also consists of two steps. 15 | - Implementing a New Model Class: Models should be implemented in `models/models.py`, inheriting from the `LLMModel` class. In your customized model, you should implement `self.tokenizer` and `self.model`. You may also customize your own `predict` function for inference. If the `predict` function is not customized, the default `predict` function inherited from `LLMModel` will be used. 16 | - Adding an Interface: After customizing the model class, register it in the `_create_model` function within the `LLMModel` class in `__init__.py`. 17 | 18 | 19 | 20 | ## Add new prompt engineering methods 21 | Adding new prompt engineering methods follows the same two steps as above. 22 | 23 | - Implementing a New Method Class: Methods should be implemented in the `prompt_engineering` module. First, create a new `.py` file for your method. 24 | Then implement two functions: `__init__` and `query`. For unified management, note two points: 1. all methods should inherit from the `Base` class, which contains common code for prompt engineering methods; 2. prompts used in methods should be stored in `prompts/method_oriented.py`. 25 | - Adding an Interface: After implementing a new method, register it in `METHOD_MAP`, which maps method names to their corresponding classes. 26 | 27 | 28 | ## Add new metrics and input/output process functions 29 | New evaluation metrics should be implemented as static functions in the `Eval` class within the `metrics` module. Similarly, new input/output processing functions should be implemented as static functions in the `InputProcess` and `OutputProcess` classes in the `utils` module. -------------------------------------------------------------------------------- /docs/examples/basic.md: -------------------------------------------------------------------------------- 1 | # Basic Usage 2 | 3 | This example will walk you through the basic usage of PromptBench. We hope that you can get familiar with the APIs and use them in your own projects later. 4 | 5 | First, there is a unified import of `import promptbench as pb` that easily imports the package.
6 | 7 | 8 | ```python 9 | import promptbench as pb 10 | ``` 11 | 12 | /home/v-zhukaijie/anaconda3/envs/promptbench/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html 13 | from .autonotebook import tqdm as notebook_tqdm 14 | 15 | 16 | ## Load dataset 17 | 18 | First, PromptBench supports easy load of datasets. 19 | 20 | 21 | ```python 22 | # print all supported datasets in promptbench 23 | print('All supported datasets: ') 24 | print(pb.SUPPORTED_DATASETS) 25 | 26 | # load a dataset, sst2, for instance. 27 | # if the dataset is not available locally, it will be downloaded automatically. 28 | dataset = pb.DatasetLoader.load_dataset("sst2") 29 | # dataset = pb.DatasetLoader.load_dataset("mmlu") 30 | # dataset = pb.DatasetLoader.load_dataset("un_multi") 31 | # dataset = pb.DatasetLoader.load_dataset("iwslt2017", ["ar-en", "de-en", "en-ar"]) 32 | # dataset = pb.DatasetLoader.load_dataset("math", "algebra__linear_1d") 33 | # dataset = pb.DatasetLoader.load_dataset("bool_logic") 34 | # dataset = pb.DatasetLoader.load_dataset("valid_parenthesesss") 35 | 36 | # print the first 5 examples 37 | dataset[:5] 38 | ``` 39 | 40 | All supported datasets: 41 | ['sst2', 'cola', 'qqp', 'mnli', 'mnli_matched', 'mnli_mismatched', 'qnli', 'wnli', 'rte', 'mrpc', 'mmlu', 'squad_v2', 'un_multi', 'iwslt2017', 'math', 'bool_logic', 'valid_parentheses', 'gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking', 'last_letter_concat', 'numersense', 'qasc'] 42 | 43 | 44 | 45 | 46 | 47 | [{'content': "it 's a charming and often affecting journey . ", 'label': 1}, 48 | {'content': 'unflinchingly bleak and desperate ', 'label': 0}, 49 | {'content': 'allows us to hope that nolan is poised to embark a major career as a commercial yet inventive filmmaker . ', 50 | 'label': 1}, 51 | {'content': "the acting , costumes , music , cinematography and sound are all astounding given the production 's austere locales . ", 52 | 'label': 1}, 53 | {'content': "it 's slow -- very , very slow . ", 'label': 0}] 54 | 55 | 56 | 57 | ## Load models 58 | 59 | Then, you can easily load LLM models via promptbench. 60 | 61 | 62 | ```python 63 | # print all supported models in promptbench 64 | print('All supported models: ') 65 | print(pb.SUPPORTED_MODELS) 66 | 67 | # load a model, flan-t5-large, for instance. 68 | model = pb.LLMModel(model='google/flan-t5-large', max_new_tokens=10, temperature=0.0001, device='cuda') 69 | # model = pb.LLMModel(model='llama2-13b-chat', max_new_tokens=10, temperature=0.0001) 70 | ``` 71 | 72 | All supported models: 73 | ['google/flan-t5-large', 'llama2-7b', 'llama2-7b-chat', 'llama2-13b', 'llama2-13b-chat', 'llama2-70b', 'llama2-70b-chat', 'phi-1.5', 'phi-2', 'palm', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-1106-preview', 'gpt-3.5-turbo-1106', 'vicuna-7b', 'vicuna-13b', 'vicuna-13b-v1.3', 'google/flan-ul2', 'gemini-pro'] 74 | 75 | 76 | You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 77 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 
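Before moving on to prompt construction, it can help to sanity-check the loaded model with a single direct call. The sketch below is illustrative: it reuses the callable `model(input_text)` interface that appears in the evaluation loop later in this example, and the example sentence is taken from the dataset printed above; the exact output text will depend on the model and decoding settings.

```python
# Quick sanity check: call the loaded model directly on one formatted input.
# The same model(input_text) call is used in the evaluation loop below.
input_text = "Classify the sentence as positive or negative: it 's a charming and often affecting journey . "
raw_pred = model(input_text)
print(raw_pred)  # expected to be something like "positive" for flan-t5-large
```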
78 | 79 | 80 | ## Construct prompts 81 | 82 | Prompts are the key interaction interface to LLMs. You can easily construct a prompt by call the Prompt API. 83 | 84 | 85 | ```python 86 | # Prompt API supports a list, so you can pass multiple prompts at once. 87 | prompts = pb.Prompt(["Classify the sentence as positive or negative: {content}", 88 | "Determine the emotion of the following sentence as positive or negative: {content}" 89 | ]) 90 | ``` 91 | 92 | You may need to define the projection function for the model output. 93 | Since the output format defined in your prompts may be different from the model output. 94 | For example, for sst2 dataset, the label are '0' and '1' to represent 'negative' and 'positive'. 95 | But the model output is 'negative' and 'positive'. 96 | So we need to define a projection function to map the model output to the label. 97 | 98 | 99 | ```python 100 | def proj_func(pred): 101 | mapping = { 102 | "positive": 1, 103 | "negative": 0 104 | } 105 | return mapping.get(pred, -1) 106 | ``` 107 | 108 | ## Perform evaluation using prompts, datasets, and models 109 | 110 | Finally, you can perform standard evaluation using the loaded prompts, datasets, and labels. 111 | 112 | 113 | ```python 114 | from tqdm import tqdm 115 | for prompt in prompts: 116 | preds = [] 117 | labels = [] 118 | for data in tqdm(dataset): 119 | # process input 120 | input_text = pb.InputProcess.basic_format(prompt, data) 121 | label = data['label'] 122 | raw_pred = model(input_text) 123 | # process output 124 | pred = pb.OutputProcess.cls(raw_pred, proj_func) 125 | preds.append(pred) 126 | labels.append(label) 127 | 128 | # evaluate 129 | score = pb.Eval.compute_cls_accuracy(preds, labels) 130 | print(f"{score:.3f}, {prompt}") 131 | ``` 132 | 133 | 0%| | 0/872 [00:00. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 71 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 72 | 73 | 74 | 75 | ```python 76 | # for each prompt, evaluate the score 77 | for prompt in prompts: 78 | score = {} 79 | 80 | # three orders of the dataset: topological, reversed, random 81 | for order in ["topological"]: 82 | # for order in ["topological", "reversed", "random"]: 83 | data = dataset[order] 84 | preds = [] 85 | answers = [] 86 | 87 | for d in data[:1]: 88 | input_text = pb.InputProcess.basic_format(prompt, d) 89 | raw_pred = model(input_text) 90 | 91 | # dyval preds are processed differently, please refer to the source code /promptbench/dyval/dyval_utils.py 92 | pred = process_dyval_preds(raw_pred) 93 | preds.append(pred) 94 | answers.append(d["answers"]) 95 | 96 | print(f"Input: {input_text}") 97 | print(f"Raw Pred: {raw_pred}") 98 | print(f"Pred: {pred}") 99 | print(f"Answer: {d['answers']}") 100 | print("\n") 101 | 102 | score[order] = dyval_evaluate(dataset.dataset_type, preds, answers) 103 | 104 | print(score) 105 | ``` 106 | 107 | /home/v-zhukaijie/anaconda3/envs/promptbench/lib/python3.8/site-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. 
You should set `do_sample=True` or unset `temperature`. 108 | warnings.warn( 109 | 2023-11-29 05:05:39.043831: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA 110 | To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. 111 | 112 | 113 | Input: Here is a description of an arithmetic problem: 114 | The value of aae is 10. 115 | The value of aab is 10. 116 | The value of aaa is 8. 117 | aac gets its value by dividing the value of aaa by those of aab. 118 | The value of aad is 6. 119 | aaf gets its value by multiplying together the value of aad and aae. 120 | aag gets its value by multiplying together the value of aac and aaf. 121 | Compute the result of aag. If the solution cannot be calculated, answer 'N/A'. Ensure your result is within a relative precision of 0.0001 (or 0.01%) compared to the ground truth value. Ensure your final result begins with '<<<' and ends with '>>>', for example, if the answer is 1, your final result should be <<<1>>>. 122 | Raw Pred: Answer: 1>>> 123 | Pred: 124 | Answer: 48.0 125 | 126 | 127 | {'topological': 0.0} 128 | 129 | -------------------------------------------------------------------------------- /docs/examples/multimodal.md: -------------------------------------------------------------------------------- 1 | # Multi-Modal Models 2 | 3 | This example will walk you throught the basic usage of MULTI-MODAL models in PromptBench. We hope that you can get familiar with the APIs and use it in your own projects later. 4 | 5 | First, there is a unified import of `import promptbench as pb` that easily imports the package. 6 | 7 | 8 | ```python 9 | import promptbench as pb 10 | ``` 11 | 12 | /anaconda/envs/promptbench_1/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html 13 | from .autonotebook import tqdm as notebook_tqdm 14 | 15 | 16 | ## Load dataset 17 | 18 | First, PromptBench supports easy load of datasets. 19 | 20 | 21 | ```python 22 | # print all supported datasets in promptbench 23 | print('All supported datasets: ') 24 | print(pb.SUPPORTED_DATASETS_VLM) 25 | 26 | # load a dataset, MMMMU, for instance. 27 | # if the dataset is not available locally, it will be downloaded automatically. 28 | dataset = pb.DatasetLoader.load_dataset("mmmu") 29 | 30 | # print the first 5 examples 31 | for idx in range(5): 32 | print(dataset[idx]) 33 | ``` 34 | 35 | All supported datasets: 36 | ['vqav2', 'nocaps', 'science_qa', 'math_vista', 'ai2d', 'mmmu', 'chart_qa'] 37 | Images already saved to local, loading file: /home/v-mingxia/promptbench/promptbench/data/mmmu/validation.json 38 | {'images': [], 'image_paths': ['/home/v-mingxia/promptbench/promptbench/data/mmmu/validation/0_image_1.png'], 'answer': 'B', 'question': ' Baxter Company has a relevant range of production between 15,000 and 30,000 units. The following cost data represents average variable costs per unit for 25,000 units of production. 
If 30,000 units are produced, what are the per unit manufacturing overhead costs incurred?\nA: $6\nB: $7\nC: $8\nD: $9'} 39 | {'images': [], 'image_paths': ['/home/v-mingxia/promptbench/promptbench/data/mmmu/validation/1_image_1.png'], 'answer': 'C', 'question': 'Assume accounts have normal balances, solve for the one missing account balance: Dividends. Equipment was recently purchased, so there is neither depreciation expense nor accumulated depreciation. \nA: $194,815\nB: $182,815\nC: $12,000\nD: $9,000'} 40 | {'images': [], 'image_paths': ['/home/v-mingxia/promptbench/promptbench/data/mmmu/validation/2_image_1.png'], 'answer': 'B', 'question': 'Maxwell Software, Inc., has the following mutually exclusive projects.Suppose the company uses the NPV rule to rank these two projects. Which project should be chosen if the appropriate discount rate is 15 percent?\nA: Project A\nB: Project B'} 41 | {'images': [], 'image_paths': ['/home/v-mingxia/promptbench/promptbench/data/mmmu/validation/3_image_1.png'], 'answer': 'D', 'question': "Each situation below relates to an independent company's Owners' Equity. Calculate the missing values of company 2.\nA: $1,620\nB: $12,000\nC: $51,180\nD: $0"} 42 | {'images': [], 'image_paths': ['/home/v-mingxia/promptbench/promptbench/data/mmmu/validation/4_image_1.png'], 'answer': 'B', 'question': 'The following data show the units in beginning work in process inventory, the number of units started, the number of units transferred, and the percent completion of the ending work in process for conversion. Given that materials are added at the beginning of the process, what are the equivalent units for conversion costs for each quarter using the weighted-average method? Assume that the quarters are independent.\nA: 132,625\nB: 134,485\nC: 135,332\nD: 132,685'} 43 | 44 | 45 | ## Load models 46 | 47 | Then, you can easily load VLM models via promptbench. 48 | 49 | 50 | ```python 51 | # print all supported models in promptbench 52 | print('All supported models: ') 53 | print(pb.SUPPORTED_MODELS_VLM) 54 | 55 | # load a model, llava-1.5-7b, for instance. 56 | model = pb.VLMModel(model='llava-hf/llava-1.5-7b-hf', max_new_tokens=2048, temperature=0.0001, device='cuda') 57 | ``` 58 | 59 | All supported models: 60 | ['Salesforce/blip2-opt-2.7b', 'Salesforce/blip2-opt-6.7b', 'Salesforce/blip2-flan-t5-xl', 'Salesforce/blip2-flan-t5-xxl', 'llava-hf/llava-1.5-7b-hf', 'llava-hf/llava-1.5-13b-hf', 'gemini-pro-vision', 'gpt-4-vision-preview', 'Qwen/Qwen-VL', 'Qwen/Qwen-VL-Chat', 'qwen-vl-plus', 'qwen-vl-max', 'internlm/internlm-xcomposer2-vl-7b'] 61 | 62 | 63 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 64 | Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00, 1.48s/it] 65 | 66 | 67 | ## Construct prompts 68 | 69 | Prompts are the key interaction interface to VLMs. You can easily construct a prompt by call the Prompt API. 70 | 71 | 72 | ```python 73 | # Prompt API supports a list, so you can pass multiple prompts at once. 74 | prompts = pb.Prompt([ 75 | "You are a helpful assistant. Here is the question:{question}\nANSWER:", 76 | "USER:{question}\nANSWER:", 77 | ]) 78 | ``` 79 | 80 | ## Perform evaluation using prompts, datasets, and models 81 | 82 | Finally, you can perform standard evaluation using the loaded prompts, datasets, and labels. 
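If you first want to check that the multimodal inputs are wired up correctly, you can score a single sample before launching the full loop below. This sketch only reuses pieces already shown in this example (`dataset[idx]` indexing, `pb.InputProcess.basic_format`, the `model(images, text)` call, and `pb.OutputProcess.pattern_split`); writing the prompt template as a plain string here is assumed to behave the same as iterating the `pb.Prompt` object, as done in the loop. For API-based models such as GPT-4V, pass `data['image_paths']` instead of `data['images']`, as noted in the loop.

```python
# Illustrative single-sample check before running the full evaluation loop.
sample = dataset[0]
prompt = "USER:{question}\nANSWER:"                       # one of the prompts defined above
input_text = pb.InputProcess.basic_format(prompt, sample)
raw_pred = model(sample['images'], input_text)            # VLM call: images first, then text
pred = pb.OutputProcess.pattern_split(raw_pred, 'ANSWER:')
print(pred, '| ground truth:', sample['answer'])
```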
83 | 84 | 85 | ```python 86 | from tqdm import tqdm 87 | for prompt in prompts: 88 | preds = [] 89 | labels = [] 90 | for data in tqdm(dataset): 91 | # process input 92 | input_text = pb.InputProcess.basic_format(prompt, data) 93 | input_images = data['images'] # please use data['image_paths'] instead of data['images'] for models that only support image path/url, such as GPT-4v 94 | label = data['answer'] 95 | raw_pred = model(input_images, input_text) 96 | # process output 97 | pred = pb.OutputProcess.pattern_split(raw_pred, 'ANSWER:') 98 | preds.append(pred) 99 | labels.append(label) 100 | 101 | # evaluate 102 | score = pb.Eval.compute_cls_accuracy(preds, labels) 103 | print(f"{score:.3f}, {repr(prompt)}") 104 | ``` 105 | 106 | 0%| | 0/900 [00:00. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 72 | Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. 73 | 74 | 75 | ['textbugger', 'deepwordbug', 'textfooler', 'bertattack', 'checklist', 'stresstest', 'semantic'] 76 | 77 | 78 | 79 | ```python 80 | # create attack, specify the model, dataset, prompt, evaluation function, and unmodifiable words 81 | # verbose=True means that the attack will print the intermediate results 82 | attack = Attack(model_t5, "stresstest", dataset, prompt, eval_func, unmodifiable_words, verbose=True) 83 | 84 | # print attack result 85 | print(attack.attack()) 86 | ``` 87 | 88 | These words (if they appear in the prompt) are not allowed to be attacked: 89 | ["positive'", "negative'", 'content'] 90 | 91 | 92 | /home/v-zhukaijie/anaconda3/envs/promptbench/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. 93 | warnings.warn( 94 | 95 | 96 | -------------------------------------------------- 97 | Current prompt is: As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: 98 | Question: {content} 99 | Answer: 100 | Current accuracy is: 1.0 101 | -------------------------------------------------- 102 | 103 | -------------------------------------------------- 104 | Modifiable words: ['As', 'a', 'sentiment', 'classifier', 'determine', 'whether', 'the', 'following', 'text', 'is', 'or', 'Please', 'classify', 'Question', 'Answer'] 105 | -------------------------------------------------- 106 | 107 | -------------------------------------------------- 108 | Current prompt is: As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: 109 | Question: {content} 110 | Answer and false is not true : 111 | Current accuracy is: 1.0 112 | -------------------------------------------------- 113 | 114 | -------------------------------------------------- 115 | Current prompt is: As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. 
Please classify: 116 | Question: {content} 117 | Answer and true is true and true is true and true is true and true is true and true is true : 118 | Current accuracy is: 1.0 119 | -------------------------------------------------- 120 | 121 | -------------------------------------------------- 122 | Current prompt is: As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: 123 | Question: {content} 124 | Answer and true is true : 125 | Current accuracy is: 1.0 126 | -------------------------------------------------- 127 | 128 | {'original prompt': "As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \nQuestion: {content}\nAnswer:", 'original score': 1.0, 'attacked prompt': "As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \nQuestion: {content}\nAnswer and false is not true :", 'attacked score': 1.0, 'PDR': 0.0} 129 | 130 | 131 | 132 | ```python 133 | 134 | ``` 135 | -------------------------------------------------------------------------------- /docs/examples/prompt_engineering.md: -------------------------------------------------------------------------------- 1 | # Prompt Engineering 2 | 3 | This example will walk you through the basic usage of PromptBench. We hope that you can get familiar with the APIs and use it in your own projects later. 4 | 5 | First, there is a unified import of `import promptbench as pb` that easily imports the package. 6 | 7 | 8 | ```python 9 | import promptbench as pb 10 | ``` 11 | 12 | ## Load dataset 13 | 14 | First, PromptBench supports easy load of datasets. 15 | 16 | 17 | ```python 18 | # print all supported datasets in promptbench 19 | print('All supported datasets: ') 20 | print(pb.SUPPORTED_DATASETS) 21 | 22 | # load a dataset, sst2, for instance. 23 | # if the dataset is not available locally, it will be downloaded automatically. 24 | dataset_name = "gsm8k" 25 | dataset = pb.DatasetLoader.load_dataset(dataset_name) 26 | 27 | # print the first 5 examples 28 | dataset[:5] 29 | ``` 30 | 31 | All supported datasets: 32 | ['cola', 'sst2', 'qqp', 'mnli', 'mnli_matched', 'mnli_mismatched', 'qnli', 'wnli', 'rte', 'mrpc', 'mmlu', 'squad_v2', 'un_multi', 'iwslt', 'math', 'bool_logic', 'valid_parentheses', 'gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking'] 33 | 34 | 35 | 36 | 37 | 38 | [{'content': "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", 39 | 'label': '18'}, 40 | {'content': 'A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?', 41 | 'label': '3'}, 42 | {'content': 'Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?', 43 | 'label': '70000'}, 44 | {'content': 'James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?', 45 | 'label': '540'}, 46 | {'content': "Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy. She gives the chickens their feed in three separate meals. 
In the morning, she gives her flock of chickens 15 cups of feed. In the afternoon, she gives her chickens another 25 cups of feed. How many cups of feed does she need to give her chickens in the final meal of the day if the size of Wendi's flock is 20 chickens?", 47 | 'label': '20'}] 48 | 49 | 50 | 51 | ## Load models 52 | 53 | Then, you can easily load LLM models via promptbench. 54 | 55 | 56 | ```python 57 | # print all supported models in promptbench 58 | print('All supported models: ') 59 | print(pb.SUPPORTED_MODELS) 60 | 61 | # load a model, gpt-3.5-turbo, for instance. 62 | # If model is openai/palm, need to provide openai_key/palm_key 63 | # If model is llama, vicuna, need to provide model dir 64 | model = pb.LLMModel(model='gpt-3.5-turbo', 65 | openai_key = 'sk-xxx', 66 | max_new_tokens=150) 67 | ``` 68 | 69 | All supported models: 70 | ['google/flan-t5-large', 'llama2-7b', 'llama2-7b-chat', 'llama2-13b', 'llama2-13b-chat', 'llama2-70b', 'llama2-70b-chat', 'phi-1.5', 'palm', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-1106-preview', 'gpt-3.5-turbo-1106', 'vicuna-7b', 'vicuna-13b', 'vicuna-13b-v1.3', 'google/flan-ul2'] 71 | 72 | 73 | You can use different methods to predict 74 | 75 | 76 | ```python 77 | # load method 78 | # print all methods and their supported datasets 79 | print('All supported methods: ') 80 | print(pb.SUPPORTED_METHODS) 81 | print('Supported datasets for each method: ') 82 | print(pb.METHOD_SUPPORT_DATASET) 83 | 84 | # load a method, emotion_prompt, for instance. 85 | method = pb.PEMethod(method='emotion_prompt', 86 | dataset=dataset_name, 87 | verbose=True, # if True, print the detailed prompt and response 88 | prompt_id = 1 # for emotion_prompt 89 | ) 90 | ``` 91 | 92 | All supported methods: 93 | ['CoT', 'ZSCoT', 'least_to_most', 'generated_knowledge', 'expert_prompting', 'emotion_prompt', 'baseline'] 94 | Supported datasets for each method: 95 | {'CoT': ['gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking'], 'ZSCoT': ['gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking'], 'expert_prompting': ['gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking'], 'emotion_prompt': ['gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking'], 'least_to_most': ['gsm8k', 'last_letter_concat'], 'generated_knowledge': ['csqa', 'numersense', 'qasc'], 'baseline': ['gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking', 'last_letter_concat', 'numersense', 'qasc']} 96 | 97 | 98 | 99 | ```python 100 | results = method.test(dataset, 101 | model, 102 | num_samples=5 # if don't set the num_samples, method will use all examples in the dataset 103 | ) 104 | 105 | results 106 | ``` 107 | 108 | 0%| | 1/1319 [00:03<1:11:51, 3.27s/it] 109 | 110 | Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? 111 | This is very important to my career. 112 | Please output your answer at the end as ## 113 | 114 | Janet eats 3 eggs for breakfast every morning, so she has 16 - 3 = <<16-3=13>>13 eggs remaining. 115 | She bakes muffins with 4 eggs, so she has 13 - 4 = <<13-4=9>>9 eggs remaining. 116 | She sells these 9 eggs at the farmers' market for a total of 9 * $2 = $<<9*2=18>>18. Answer: ##18 117 | 118 | 119 | 0%| | 2/1319 [00:05<1:01:19, 2.79s/it] 120 | 121 | A robe takes 2 bolts of blue fiber and half that much white fiber. 
How many bolts in total does it take? 122 | This is very important to my career. 123 | Please output your answer at the end as ## 124 | 125 | A robe takes 2 bolts of blue fiber and half that much white fiber. This means it takes 2 bolts of blue fiber and 1 bolt of white fiber (half of 2 bolts). Therefore, in total, it takes 2 + 1 = 3 bolts of fiber. 126 | 127 | So, the answer is ##3. 128 | 129 | 130 | 0%| | 3/1319 [00:09<1:14:10, 3.38s/it] 131 | 132 | Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make? 133 | This is very important to my career. 134 | Please output your answer at the end as ## 135 | 136 | The increase in value of the house after repairs is 150% of the original value, which is 150/100 * $80,000 = $<<150/100*80000=120000>>120,000. 137 | So the total value of the house after repairs is $80,000 + $120,000 = $<<80000+120000=200000>>200,000. 138 | Therefore, Josh made a profit of $200,000 - ($80,000 + $50,000) = $<<200000-(80000+50000)=70000>>70,000. 139 | Answer: ##70,000 140 | 141 | 142 | 0%| | 4/1319 [00:12<1:11:11, 3.25s/it] 143 | 144 | James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week? 145 | This is very important to my career. 146 | Please output your answer at the end as ## 147 | 148 | James runs 60 meters per sprint and he runs 3 sprints per session, so he runs 60*3 = <<60*3=180>>180 meters per session. 149 | He runs 3 sessions per week, so he runs 180*3 = <<180*3=540>>540 meters per week. Answer: ##540. 150 | 151 | 152 | 0%| | 5/1319 [00:17<1:15:20, 3.44s/it] 153 | 154 | Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy. She gives the chickens their feed in three separate meals. In the morning, she gives her flock of chickens 15 cups of feed. In the afternoon, she gives her chickens another 25 cups of feed. How many cups of feed does she need to give her chickens in the final meal of the day if the size of Wendi's flock is 20 chickens? 155 | This is very important to my career. 156 | Please output your answer at the end as ## 157 | 158 | In the morning, Wendi gives each chicken 15/20 = <<15/20=0.75>>0.75 cups of feed. 159 | In the afternoon, Wendi gives each chicken 25/20 = <<25/20=1.25>>1.25 cups of feed. 160 | So far, Wendi has given each chicken a total of 0.75 + 1.25 = <<0.75+1.25=2>>2 cups of feed. 161 | Therefore, Wendi needs to give her chickens a total of 3 - 2 = <<3-2=1>>1 cup of feed in the final meal of the day. Answer: \boxed{1}. 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 0.4 171 | 172 | 173 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. promptbench documentation master file, created by 2 | sphinx-quickstart on Fri Nov 24 04:09:16 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to promptbench's documentation! 7 | ======================================= 8 | 9 | **PromptBench** is a unified library for evaluating and understanding large language models. Please refer to `PromptBench `_ for the code. 10 | 11 | .. image:: ../imgs/promptbench.png 12 | :align: center 13 | :width: 75% 14 | 15 | .. 
toctree:: 16 | :maxdepth: 1 17 | :caption: Get Started 18 | 19 | start/intro 20 | start/installation 21 | 22 | .. toctree:: 23 | :maxdepth: 1 24 | :caption: Examples 25 | 26 | examples/basic 27 | examples/multimodal 28 | examples/dyval 29 | examples/prompt_attack 30 | examples/prompt_engineering 31 | examples/add_new_modules 32 | 33 | .. toctree:: 34 | :maxdepth: 1 35 | :caption: Reference 36 | 37 | reference/dataload/index 38 | reference/dyval/index 39 | reference/metrics/index 40 | reference/models/index 41 | reference/utils/index 42 | 43 | .. toctree:: 44 | :maxdepth: 1 45 | :caption: Leaderboards 46 | 47 | leaderboard/advprompt 48 | leaderboard/dyval 49 | leaderboard/pe 50 | 51 | 52 | Indices and tables 53 | ================== 54 | 55 | * :ref:`genindex` 56 | * :ref:`modindex` 57 | * :ref:`search` 58 | -------------------------------------------------------------------------------- /docs/leaderboard/advprompt.md: -------------------------------------------------------------------------------- 1 | # Adversarial Prompt Leaderboard 2 | 3 | PromptBench can evaluate the adversarial robustness of LLMs to prompts. More information can be found at [PromptBench: Towards Evaluating the Robustness of Large Language Models on Adversarial Prompts](https://arxiv.org/abs/2306.04528). 4 | 5 | Please contact us if you want the results of your models shown in this leaderboard. 6 | 7 | 8 | [[All results of LLMs](#all-results-of-llms)] [[All results of Prompts](#all-results-of-prompts)] [[View by Models](#attack-results-view-by-models)] [[View by Datasets](#attack-results-view-by-datasets)] 9 | 10 | 11 | 12 | 13 | ### All Results of LLMs 14 | 15 | | Model | SST-2 | CoLA | QQP | MPRC | MNLI | QNLI | RTE | WNLI | MMLU | SQuAD v2 | IWSLT | UN Multi | Math | Avg | 16 | |:--------:|:---------:|:---------:|:---------:|:---------:|:----------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:----------:|:---------:|:---------:| 17 | | T5-Large | 0.04±0.11 | 0.16±0.19 | 0.09±0.15 | 0.17±0.26 | 0.08±0.13 | 0.33±0.25 | 0.08±0.13 | 0.13±0.14 | 0.11±0.18 | 0.05±0.12 | 0.14±0.17 | 0.13±0.14 | 0.24±0.21 | 0.13±0.19 | 18 | | Vicuna | 0.83±0.26 | 0.81±0.22 | 0.51±0.41 | 0.52±0.40 | 0.67±0.38 | 0.87±0.19 | 0.78±0.23 | 0.78±0.27 | 0.41±0.24 | - | - | - | - | 0.69±0.34 | 19 | | LLaMA2 | 0.24±0.33 | 0.38±0.32 | 0.59±0.33 | 0.84±0.27 | 0.32±0.32 | 0.51±0.39 | 0.68±0.39 | 0.73±0.37 | 0.28±0.24 | - | - | - | - | 0.51±0.39 | 20 | | UL2 | 0.03±0.12 | 0.13±0.20 | 0.02±0.04 | 0.06±0.10 | 0.06±0.12 | 0.05±0.11 | 0.02±0.04 | 0.04±0.03 | 0.05±0.11 | 0.10±0.18 | 0.15±0.11 | 0.05±0.05 | 0.21±0.21 | 0.08±0.14 | 21 | | ChatGPT | 0.17±0.29 | 0.21±0.31 | 0.16±0.30 | 0.22±0.29 | 0.13±0.18 | 0.25±0.31 | 0.09±0.13 | 0.14±0.12 | 0.14±0.18 | 0.22±0.28 | 0.17±0.26 | 0.12±0.18 | 0.33±0.31 | 0.18±0.26 | 22 | | GPT-4 | 0.24±0.38 | 0.13±0.23 | 0.16±0.38 | 0.04±0.06 | -0.03±0.02 | 0.05±0.23 | 0.03±0.05 | 0.04±0.04 | 0.04±0.04 | 0.27±0.31 | 0.07±0.14 | -0.02±0.01 | 0.02±0.18 | 0.08±0.21 | 23 | 24 | 25 | 26 | ### All Results of Prompts 27 | 28 | | Model | SST-2 | CoLA | QQP | MPRC | MNLI | QNLI | RTE | WNLI | MMLU | SQuAD v2 | IWSLT | UN Multi | Math | Avg | 29 | |:-------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:| 30 | | ZS-task | 0.31±0.39 | 0.43±0.35 | 0.43±0.42 | 0.44±0.44 | 0.29±0.35 | 0.46±0.39 | 0.33±0.39 | 0.36±0.36 | 0.25±0.23 | 0.16±0.26 | 0.18±0.22 | 0.17±0.18 | 0.33±0.26 | 0.33±0.36 | 31 | | 
ZS-role | 0.28±0.35 | 0.43±0.38 | 0.34±0.43 | 0.51±0.43 | 0.26±0.33 | 0.51±0.40 | 0.35±0.40 | 0.39±0.39 | 0.22±0.26 | 0.20±0.28 | 0.24±0.25 | 0.15±0.16 | 0.39±0.30 | 0.34±0.37 | 32 | | FS-task | 0.22±0.38 | 0.24±0.28 | 0.16±0.21 | 0.24±0.32 | 0.19±0.29 | 0.30±0.34 | 0.31±0.39 | 0.37±0.41 | 0.18±0.23 | 0.06±0.11 | 0.08±0.09 | 0.04±0.07 | 0.16±0.18 | 0.21±0.31 | 33 | | FS-role | 0.24±0.39 | 0.25±0.36 | 0.14±0.20 | 0.23±0.30 | 0.21±0.33 | 0.32±0.36 | 0.27±0.38 | 0.33±0.38 | 0.14±0.20 | 0.07±0.12 | 0.11±0.10 | 0.04±0.07 | 0.17±0.17 | 0.21±0.31 | 34 | 35 | 36 | 37 | ### Attack Results View by Models 38 | 39 | | Model | TextBugger | DeepWordBug | TextFoller | BertAttack | CheckList | StressTest | Semantic | 40 | |:--------:|:----------:|:-----------:|:----------:|:----------:|:----------:|:----------:|:---------:| 41 | | T5-Large | 0.09±0.10 | 0.13±0.18 | 0.20±0.24 | 0.21±0.24 | 0.04±0.08 | 0.18±0.24 | 0.10±0.09 | 42 | | Vicuna | 0.81±0.25 | 0.69±0.30 | 0.80±0.26 | 0.84±0.23 | 0.64±0.27 | 0.29±0.40 | 0.74±0.25 | 43 | | LLaMA2 | 0.67±0.36 | 0.41±0.34 | 0.68±0.36 | 0.74±0.33 | 0.34±0.33 | 0.20±0.30 | 0.66±0.35 | 44 | | UL2 | 0.04±0.06 | 0.03±0.04 | 0.14±0.20 | 0.16±0.22 | 0.04±0.07 | 0.06±0.09 | 0.06±0.08 | 45 | | ChatGPT | 0.14±0.20 | 0.08±0.13 | 0.32±0.35 | 0.34±0.34 | 0.07±0.13 | 0.06±0.12 | 0.26±0.22 | 46 | | GPT-4 | 0.03±0.10 | 0.02±0.08 | 0.18±0.19 | 0.27±0.40 | -0.02±0.09 | 0.03±0.15 | 0.03±0.16 | 47 | | Avg | 0.21±0.30 | 0.16±0.26 | 0.31±0.33 | 0.33±0.34 | 0.12±0.23 | 0.11±0.23 | 0.22±0.26 | 48 | 49 | 50 | 51 | ### Attack Results View by Datasets 52 | 53 | 54 | | Model | TextBugger | DeepWordBug | TextFoller | BertAttack | CheckList | StressTest | Semantic | 55 | |:--------:|:----------:|:-----------:|:----------:|:----------:|:---------:|:----------:|:---------:| 56 | | SST-2 | 0.25±0.39 | 0.18±0.33 | 0.35±0.41 | 0.34±0.44 | 0.22±0.36 | 0.15±0.31 | 0.28±0.35 | 57 | | CoLA | 0.39±0.40 | 0.27±0.32 | 0.43±0.35 | 0.45±0.38 | 0.23±0.30 | 0.18±0.25 | 0.34±0.37 | 58 | | QQP | 0.30±0.38 | 0.22±0.31 | 0.31±0.36 | 0.33±0.38 | 0.18±0.30 | 0.06±0.25 | 0.40±0.39 | 59 | | MPRC | 0.37±0.42 | 0.34±0.41 | 0.37±0.41 | 0.42±0.38 | 0.24±0.37 | 0.25±0.33 | 0.39±0.39 | 60 | | MNLI | 0.32±0.40 | 0.18±0.29 | 0.32±0.39 | 0.34±0.36 | 0.14±0.24 | 0.10±0.25 | 0.22±0.24 | 61 | | QNLI | 0.38±0.39 | 0.40±0.35 | 0.50±0.39 | 0.52±0.38 | 0.25±0.39 | 0.23±0.33 | 0.40±0.35 | 62 | | RTE | 0.33±0.41 | 0.25±0.35 | 0.37±0.44 | 0.40±0.42 | 0.18±0.32 | 0.17±0.24 | 0.42±0.40 | 63 | | WNLI | 0.39±0.42 | 0.31±0.37 | 0.41±0.43 | 0.41±0.40 | 0.24±0.32 | 0.20±0.27 | 0.49±0.39 | 64 | | MMLU | 0.21±0.24 | 0.12±0.16 | 0.21±0.20 | 0.40±0.30 | 0.13±0.18 | 0.03±0.15 | 0.20±0.19 | 65 | | SQuAD V2 | 0.09±0.17 | 0.05±0.08 | 0.25±0.29 | 0.31±0.32 | 0.02±0.03 | 0.02±0.04 | 0.08±0.09 | 66 | | IWSLT | 0.08±0.14 | 0.10±0.12 | 0.27±0.30 | 0.12±0.18 | 0.10±0.10 | 0.17±0.19 | 0.18±0.14 | 67 | | UN Multi | 0.06±0.08 | 0.08±0.12 | 0.15±0.19 | 0.10±0.16 | 0.06±0.07 | 0.09±0.11 | 0.15±0.18 | 68 | | Math | 0.18±0.17 | 0.14±0.13 | 0.49±0.36 | 0.42±0.32 | 0.15±0.11 | 0.13±0.08 | 0.23±0.13 | 69 | | Avg | 0.21±0.30 | 0.17±0.26 | 0.31±0.33 | 0.33±0.34 | 0.12±0.23 | 0.11±0.23 | 0.22±0.26 | 70 | 71 | -------------------------------------------------------------------------------- /docs/leaderboard/dyval.md: -------------------------------------------------------------------------------- 1 | # Dynamic Evaluation Benchmark 2 | 3 | DyVal is a new dynamic evaluation protocol for LLMs. 
More information can be found at [DyVal: Graph-informed Dynamic Evaluation of Large Language Models](https://arxiv.org/abs/2309.17167). 4 | 5 | Please contact us if you want the results of your models shown in this leaderboard. 6 | 7 | [[All results](#all-results)] [[View by Complexity](#view-by-complexity)] 8 | 9 | 10 | ### All results 11 | 12 | | Model | Arithmetic | Linear Equation | Boolean Logic | Deductive Logic | Abductive Logic | Reachability | Max Sum Path | 13 | |:---------------:|:----------:|:---------------:|:-------------:|:---------------:|:---------------:|:------------:|:------------:| 14 | | Vicuna-13B v1.3 | 0.79 | - | 50.76 | 37.04 | 21.10 | 21.58 | - | 15 | | LLaMA2-13B Chat | 8.33 | - | 16.15 | 35.72 | 7.73 | 28.05 | - | 16 | | ChatGPT | 84.50 | 26.63 | 97.34 | 66.56 | 52.49 | 56.09 | 13.63 | 17 | | GPT4 | 89.88 | 45.03 | 99.33 | 93.92 | 66.33 | 79.02 | 23.36 | 18 | 19 | ### View by Complexity 20 | 21 | 22 | #### Complexity 1 23 | 24 | | Model | Arithmetic | Linear Equation | Boolean Logic | Deductive Logic | Abductive Logic | Reachability | Max Sum Path | 25 | |:---------------:|:----------:|:---------------:|:-------------:|:---------------:|:---------------:|:------------:|:------------:| 26 | | Vicuna-13B v1.3 | 1.89 | - | 81.33 | 25.73 | 44.51 | 21.60 | - | 27 | | LLaMA2-13B Chat | 25.07 | - | 19.20 | 50.27 | 1.82 | 27.62 | - | 28 | | ChatGPT | 95.27 | 36.22 | 99.09 | 81.96 | 41.78 | 62.27 | 28.14 | 29 | | GPT4 | 99.00 | 57.05 | 100.00 | 94.45 | 89.29 | 87.22 | 31.56 | 30 | 31 | 32 | #### Complexity 2 33 | 34 | | Model | Arithmetic | Linear Equation | Boolean Logic | Deductive Logic | Abductive Logic | Reachability | Max Sum Path | 35 | |:---------------:|:----------:|:---------------:|:-------------:|:---------------:|:---------------:|:------------:|:------------:| 36 | | Vicuna-13B v1.3 | 0.73 | - | 55.11 | 43.87 | 22.42 | 21.84 | - | 37 | | LLaMA2-13B Chat | 4.44 | - | 14.51 | 40.38 | 16.56 | 29.25 | - | 38 | | ChatGPT | 91.60 | 29.39 | 98.33 | 64.75 | 56.62 | 54.84 | 12.95 | 39 | | GPT4 | 95.11 | 42.61 | 99.78 | 96.06 | 63.61 | 86.33 | 30.45 | 40 | 41 | 42 | #### Complexity 3 43 | 44 | | Model | Arithmetic | Linear Equation | Boolean Logic | Deductive Logic | Abductive Logic | Reachability | Max Sum Path | 45 | |:---------------:|:----------:|:---------------:|:-------------:|:---------------:|:---------------:|:------------:|:------------:| 46 | | Vicuna-13B v1.3 | 0.47 | - | 37.02 | 42.77 | 17.47 | 21.69 | - | 47 | | LLaMA2-13B Chat | 2.20 | - | 17.18 | 28.78 | 9.42 | 27.38 | - | 48 | | ChatGPT | 77.62 | 24.31 | 96.84 | 62.80 | 58.27 | 53.64 | 7.47 | 49 | | GPT4 | 85.95 | 43.78 | 99.00 | 94.78 | 57.67 | 71.17 | 18.33 | 50 | 51 | 52 | #### Complexity 4 53 | 54 | | Model | Arithmetic | Linear Equation | Boolean Logic | Deductive Logic | Abductive Logic | Reachability | Max Sum Path | 55 | |:---------------:|:----------:|:---------------:|:-------------:|:---------------:|:---------------:|:------------:|:------------:| 56 | | Vicuna-13B v1.3 | 0.09 | - | 29.58 | 36.29 | 0.0 | 21.18 | - | 57 | | LLaMA2-13B Chat | 1.60 | - | 13.71 | 23.47 | 3.13 | 27.96 | - | 58 | | ChatGPT | 71.51 | 16.60 | 95.11 | 56.73 | 53.29 | 53.62 | 5.98 | 59 | | GPT4 | 79.44 | 36.67 | 98.56 | 90.39 | 54.78 | 71.33 | 13.11 | 60 | 61 | -------------------------------------------------------------------------------- /docs/leaderboard/pe.md: -------------------------------------------------------------------------------- 1 | # Prompt Engineering Benchmark 2 | 3 | The Prompt Engineering 
Module collects a variety of prompting methods and evaluates their performance across multiple datasets. This module currently supports models including GPT-3.5-turbo and GPT-4-1106. 4 | 5 | Please contact us if you want the results of your models shown in this leaderboard. 6 | 7 | ### All Results 8 | 9 | | Model | benchmark | baseline | CoT | CoT(zero-shot) | expert prompting | emotion prompt | least to most | 10 | |-----------------|-------------------------|----------|-------|----------------|------------------|----------------|---------------| 11 | | GPT3.5 -Turbo | gsm8k | 47.15 | 40.33 | 18.5 | 21.15 | 57.24 | | 12 | | GPT3.5 -Turbo | bigbench_date | 57.99 | 49.32 | 80.49 | 61.79 | 66.12 | | 13 | | GPT3.5 -Turbo | bigbench_object_tracking| 39.2 | 63.2 | 66 | 56.53 | 29.87 | | 14 | | GPT3.5 -Turbo | csqa | 72.48 | 67.81 | 65.85 | 74.45 | 70.68 | | 15 | | GPT3.5 -Turbo | last-letter-concat | 7.2 | | | | | 79.8 | 16 | | GPT4-1106 | gsm8k | 92.19 | 85.89 | 87.34 | 88.7 | 90.83 | | 17 | | GPT4-1106 | bigbench_date | 87.8 | 92.14 | 87.53 | 87.26 | 87.8 | | 18 | | GPT4-1106 | bigbench_object_tracking| 96.27 | 90.26 | 99.07 | 98.93 | 95.73 | | 19 | | GPT4-1106 | csqa | 79.69 | 85.59 | 79.85 | 79.85 | 80.34 | | 20 | | GPT4-1106 | last-letter-concat | 25.2 | | | | | 96.2 | 21 | 22 | *"This is very important to my career." is used in emotion prompt* 23 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/reference/dataload/dataload.rst: -------------------------------------------------------------------------------- 1 | dataload.dataload 2 | ============================= 3 | 4 | .. automodule:: promptbench.dataload.dataload 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/dataload/dataset.rst: -------------------------------------------------------------------------------- 1 | dataload.dataset 2 | ============================ 3 | 4 | .. automodule:: promptbench.dataload.dataset 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/dataload/index.rst: -------------------------------------------------------------------------------- 1 | dataload 2 | ==================== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 1 6 | 7 | dataload 8 | dataset 9 | -------------------------------------------------------------------------------- /docs/reference/dyval/DAG/code_dag.rst: -------------------------------------------------------------------------------- 1 | dyval.DAG.code_dag 2 | ============================== 3 | 4 | .. automodule:: promptbench.dyval.DAG.code_dag 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/dyval/DAG/dag.rst: -------------------------------------------------------------------------------- 1 | dyval.DAG.dag 2 | ========================= 3 | 4 | .. automodule:: promptbench.dyval.DAG.dag 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/dyval/DAG/describer.rst: -------------------------------------------------------------------------------- 1 | dyval.DAG.describer 2 | =============================== 3 | 4 | .. automodule:: promptbench.dyval.DAG.describer 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/dyval/DAG/index.rst: -------------------------------------------------------------------------------- 1 | dyval.DAG 2 | ===================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | code_dag 8 | dag 9 | describer 10 | logic_dag 11 | math_dag 12 | -------------------------------------------------------------------------------- /docs/reference/dyval/DAG/logic_dag.rst: -------------------------------------------------------------------------------- 1 | dyval.DAG.logic_dag 2 | =============================== 3 | 4 | .. automodule:: promptbench.dyval.DAG.logic_dag 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/dyval/DAG/math_dag.rst: -------------------------------------------------------------------------------- 1 | dyval.DAG.math_dag 2 | ============================== 3 | 4 | .. automodule:: promptbench.dyval.DAG.math_dag 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/dyval/dyval_dataset.rst: -------------------------------------------------------------------------------- 1 | dyval.dyval_dataset 2 | =============================== 3 | 4 | .. automodule:: promptbench.dyval.dyval_dataset 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/dyval/dyval_utils.rst: -------------------------------------------------------------------------------- 1 | dyval.dyval_utils 2 | ============================= 3 | 4 | .. automodule:: promptbench.dyval.dyval_utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/dyval/index.rst: -------------------------------------------------------------------------------- 1 | dyval 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | dyval_dataset 8 | dyval_utils 9 | DAG/index -------------------------------------------------------------------------------- /docs/reference/index.rst: -------------------------------------------------------------------------------- 1 | Reference 2 | ========= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 3 6 | 7 | dataload/index 8 | dyval/index 9 | metrics/index 10 | models/index 11 | utils/index -------------------------------------------------------------------------------- /docs/reference/metrics/eval.rst: -------------------------------------------------------------------------------- 1 | metrics.eval 2 | ======================== 3 | 4 | .. automodule:: promptbench.metrics.eval 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/metrics/index.rst: -------------------------------------------------------------------------------- 1 | metrics 2 | =================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | eval -------------------------------------------------------------------------------- /docs/reference/models/index.rst: -------------------------------------------------------------------------------- 1 | models 2 | ====== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | models -------------------------------------------------------------------------------- /docs/reference/models/models.rst: -------------------------------------------------------------------------------- 1 | models.models 2 | ========================= 3 | 4 | .. automodule:: promptbench.models.models 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/utils/dataprocess.rst: -------------------------------------------------------------------------------- 1 | promptbench.utils.dataprocess 2 | ============================= 3 | 4 | .. automodule:: promptbench.utils.dataprocess 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/reference/utils/index.rst: -------------------------------------------------------------------------------- 1 | utils 2 | ===== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | dataprocess -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | autocorrect==2.6.1 2 | accelerate==0.25.0 3 | datasets==2.15.0 4 | myst-parser 5 | nltk==3.8.1 6 | openai==1.3.7 7 | sentencepiece==0.1.99 8 | tokenizers==0.15.0 9 | torch==2.2.0 10 | tqdm==4.66.3 11 | transformers==4.38.0 12 | Sphinx==7.1.2 13 | sphinx-autodoc-typehints==1.25.2 14 | sphinx-markdown-builder==0.6.5 15 | sphinx-rtd-theme==1.3.0 16 | sphinxcontrib-applehelp==1.0.4 17 | sphinxcontrib-devhelp==1.0.2 18 | sphinxcontrib-htmlhelp==2.0.1 19 | sphinxcontrib-jquery==4.1 20 | sphinxcontrib-jsmath==1.0.1 21 | sphinxcontrib-qthelp==1.0.3 22 | sphinxcontrib-serializinghtml==1.1.5 -------------------------------------------------------------------------------- /docs/start/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | There are two ways to install promptbench. If you want to simply use as it is, install via `pip`. If you want to make any changes and play around, install it from source. 4 | 5 | We recommend to build virtual environment via [anaconda/miniconda](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#activating-an-environment) or [python virtual environment](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/) to better manage your python library. 
6 | 7 | 8 | 9 | ## Install via `pip` 10 | We provide a Python package *promptbench* for users who want to start evaluation quickly. Simply run 11 | ```sh 12 | pip install promptbench 13 | ``` 14 | 15 | 16 | ## Install via GitHub 17 | 18 | First, clone the repo: 19 | ```sh 20 | git clone git@github.com:microsoft/promptbench.git 21 | ``` 22 | 23 | Then, change into the project directory: 24 | 25 | ```sh 26 | cd promptbench 27 | ``` 28 | 29 | To install the required packages, you can create a conda environment: 30 | 31 | ```sh 32 | conda create --name promptbench python=3.9 33 | ``` 34 | 35 | then use `pip` to install the required packages: 36 | 37 | ```sh 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | Note that this installs only the basic Python packages. To run Prompt Attacks, you also need to install `textattack`. 42 | -------------------------------------------------------------------------------- /docs/start/intro.md: -------------------------------------------------------------------------------- 1 | # promptbench Introduction 2 | **PromptBench** is a unified library for evaluating and understanding large language models. 3 | 4 | 5 | ## What does promptbench currently contain? 6 | 1. **Quick access to model performance:** We provide a user-friendly interface to quickly build models, load datasets, and evaluate model performance. 7 | 2. **Prompt Engineering:** promptbench collects a variety of prompting methods (e.g., Chain-of-Thought, expert prompting, EmotionPrompt, and least-to-most) and evaluates their performance across multiple datasets. 8 | 3. **Evaluating adversarial prompts:** promptbench integrates [prompt attacks](https://arxiv.org/abs/2306.04528) [1] so that researchers can simulate black-box adversarial prompt attacks on models and evaluate their performance. 9 | 4. **Dynamic evaluation to mitigate potential test data contamination:** we integrate the dynamic evaluation framework DyVal [2], which generates evaluation samples on the fly with controlled complexity. 10 | 11 | 12 | ## Where should I get started? 13 | If you want to 14 | 1. **evaluate your model on existing benchmarks:** please refer to `examples/basic.ipynb` to construct your evaluation pipeline. For a multi-modal evaluation pipeline, please refer to `examples/multimodal.ipynb`. 15 | 2. **test the effects of different prompting techniques:** please refer to `examples/prompt_engineering.ipynb` for the supported prompting methods. 16 | 3. **examine the robustness against prompt attacks:** please refer to `examples/prompt_attack.ipynb` to construct the attacks. 17 | 4. **use DyVal for evaluation:** please refer to `examples/dyval.ipynb` to construct DyVal datasets. 18 | 19 | -------------------------------------------------------------------------------- /examples/add_new_modules.md: -------------------------------------------------------------------------------- 1 | Each module in promptbench can be easily extended. In the following, we provide basic guidelines for customizing your own datasets, models, prompt engineering methods, and evaluation metrics. 2 | 3 | ## Add new datasets 4 | Adding new datasets involves two steps: 5 | 6 | - Implementing a New Dataset Class: Datasets should be implemented in `dataload/dataset.py` and inherit from the `Dataset` class. For your custom dataset, implement the `__init__` method to load your dataset. We recommend organizing your data samples as dictionaries to facilitate the input process. 7 | - Adding an Interface: After customizing the dataset class, register it in the `DatasetLoader` class within `dataload.py`. 8 | 9 | 10 | 11 | ## Add new models 12 | Similar to adding new datasets, the addition of new models also consists of two steps. 13 | - Implementing a New Model Class: Models should be implemented in `models/models.py`, inheriting from the `LLMModel` class (see the illustrative sketch below).
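A minimal sketch of such a class is given below. This is only an illustration: `MyCustomModel`, the checkpoint name, and the constructor arguments are hypothetical, and the exact attributes and hooks expected by `LLMModel` should be checked against `models/models.py`.

```python
# Illustrative sketch only: the class name, checkpoint, and constructor are hypothetical.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from promptbench.models import LLMModel


class MyCustomModel(LLMModel):
    def __init__(self, model_name="my-org/my-checkpoint", max_new_tokens=64):
        # Depending on LLMModel's own __init__, you may need to call super().__init__ here.
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        # The two attributes the guideline above asks for:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def predict(self, input_text, **kwargs):
        # Optional: override the default inference logic inherited from LLMModel.
        inputs = self.tokenizer(input_text, return_tensors="pt")
        outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
```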
In your customized model, you should implement `self.tokenizer` and `self.model`, as in the sketch above. You may also customize your own `predict` function for inference; if it is not customized, the default `predict` function inherited from `LLMModel` will be used. 14 | - Adding an Interface: After customizing the model class, register it in the `_create_model` function within the `LLMModel` class and in the `MODEL_LIST` dictionary in `__init__.py`. 15 | 16 | 17 | 18 | ## Add new prompt engineering methods 19 | Adding new methods in prompt engineering is similar to the steps for adding new datasets and models. 20 | 21 | - Implementing a New Method Class: Methods should be implemented in the `prompt_engineering` module. First, create a new `.py` file for your method. 22 | Then implement two functions: `__init__` and `query`. For unified management, note two points: 1. all methods should inherit from the `Base` class, which contains common code for prompt engineering methods; 2. prompts used in the methods should be stored in `prompts/method_oriented.py`. 23 | - Adding an Interface: After implementing a new method, register it in `METHOD_MAP`, which maps method names to their corresponding classes. 24 | 25 | 26 | ## Add new metrics and input/output process functions 27 | New evaluation metrics should be implemented as static functions in the `Eval` class within the `metrics` module. Similarly, new input/output processing functions should be implemented as static functions in the `InputProcess` and `OutputProcess` classes in the `utils` module. 28 | -------------------------------------------------------------------------------- /examples/prompt_attack.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# First install textattack, tensorflow, and tensorflow_hub\n", 10 | "!pip install textattack tensorflow tensorflow_hub" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "name": "stderr", 20 | "output_type": "stream", 21 | "text": [ 22 | "/home/v-zhukaijie/anaconda3/envs/promptbench/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 23 | " from .autonotebook import tqdm as notebook_tqdm\n", 24 | "2023-12-24 03:45:05.172891: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", 25 | "2023-12-24 03:45:05.172945: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", 26 | "2023-12-24 03:45:05.173987: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", 27 | "2023-12-24 03:45:05.180461: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", 28 | "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 29 | "2023-12-24 03:45:06.000286: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "import promptbench as pb\n", 35 | "from promptbench.models import LLMModel\n", 36 | "from promptbench.prompt_attack import Attack" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stderr", 46 | "output_type": "stream", 47 | "text": [ 48 | "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n", 49 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" 50 | ] 51 | }, 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "['textbugger', 'deepwordbug', 'textfooler', 'bertattack', 'checklist', 'stresstest', 'semantic']\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "# create model\n", 62 | "model_t5 = LLMModel(model='google/flan-t5-large')\n", 63 | "\n", 64 | "# create dataset\n", 65 | "dataset = pb.DatasetLoader.load_dataset(\"sst2\")\n", 66 | "\n", 67 | "# try part of the dataset\n", 68 | "dataset = dataset[:10]\n", 69 | "\n", 70 | "# create prompt\n", 71 | "prompt = \"As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. 
Please classify: \\nQuestion: {content}\\nAnswer:\"\n", 72 | "\n", 73 | "# define the projection function required by the output process\n", 74 | "def proj_func(pred):\n", 75 | " mapping = {\n", 76 | " \"positive\": 1,\n", 77 | " \"negative\": 0\n", 78 | " }\n", 79 | " return mapping.get(pred, -1)\n", 80 | "\n", 81 | "# define the evaluation function required by the attack\n", 82 | "# if the prompt does not require any dataset, for example, \"write a poem\", you still need to include the dataset parameter\n", 83 | "def eval_func(prompt, dataset, model):\n", 84 | " preds = []\n", 85 | " labels = []\n", 86 | " for d in dataset:\n", 87 | " input_text = pb.InputProcess.basic_format(prompt, d)\n", 88 | " raw_output = model(input_text)\n", 89 | "\n", 90 | " output = pb.OutputProcess.cls(raw_output, proj_func)\n", 91 | " preds.append(output)\n", 92 | "\n", 93 | " labels.append(d[\"label\"])\n", 94 | " \n", 95 | " return pb.Eval.compute_cls_accuracy(preds, labels)\n", 96 | " \n", 97 | "# define the unmodifiable words in the prompt\n", 98 | "# for example, the labels \"positive\" and \"negative\" are unmodifiable, and \"content\" is modifiable because it is a placeholder\n", 99 | "# if your labels are enclosed with '', you need to add \\' to the unmodifiable words (due to one feature of textattack)\n", 100 | "unmodifiable_words = [\"positive\\'\", \"negative\\'\", \"content\"]\n", 101 | "\n", 102 | "# print all supported attacks\n", 103 | "print(Attack.attack_list())" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 3, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "These words (if they appear in the prompt) are not allowed to be attacked:\n", 116 | "[\"positive'\", \"negative'\", 'content']\n" 117 | ] 118 | }, 119 | { 120 | "name": "stderr", 121 | "output_type": "stream", 122 | "text": [ 123 | "/home/v-zhukaijie/anaconda3/envs/promptbench/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", 124 | " warnings.warn(\n" 125 | ] 126 | }, 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "--------------------------------------------------\n", 132 | "Current prompt is: As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \n", 133 | "Question: {content}\n", 134 | "Answer:\n", 135 | "Current accuracy is: 1.0\n", 136 | "--------------------------------------------------\n", 137 | "\n", 138 | "--------------------------------------------------\n", 139 | "Modifiable words: ['As', 'a', 'sentiment', 'classifier', 'determine', 'whether', 'the', 'following', 'text', 'is', 'or', 'Please', 'classify', 'Question', 'Answer']\n", 140 | "--------------------------------------------------\n", 141 | "\n", 142 | "--------------------------------------------------\n", 143 | "Current prompt is: As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. 
Please classify: \n", 144 | "Question: {content}\n", 145 | "Answer and false is not true :\n", 146 | "Current accuracy is: 1.0\n", 147 | "--------------------------------------------------\n", 148 | "\n", 149 | "--------------------------------------------------\n", 150 | "Current prompt is: As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \n", 151 | "Question: {content}\n", 152 | "Answer and true is true and true is true and true is true and true is true and true is true :\n", 153 | "Current accuracy is: 1.0\n", 154 | "--------------------------------------------------\n", 155 | "\n", 156 | "--------------------------------------------------\n", 157 | "Current prompt is: As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \n", 158 | "Question: {content}\n", 159 | "Answer and true is true :\n", 160 | "Current accuracy is: 1.0\n", 161 | "--------------------------------------------------\n", 162 | "\n", 163 | "{'original prompt': \"As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \\nQuestion: {content}\\nAnswer:\", 'original score': 1.0, 'attacked prompt': \"As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: \\nQuestion: {content}\\nAnswer and false is not true :\", 'attacked score': 1.0, 'PDR': 0.0}\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "# create attack, specify the model, dataset, prompt, evaluation function, and unmodifiable words\n", 169 | "# verbose=True means that the attack will print the intermediate results\n", 170 | "attack = Attack(model_t5, \"stresstest\", dataset, prompt, eval_func, unmodifiable_words, verbose=True)\n", 171 | "\n", 172 | "# print attack result\n", 173 | "print(attack.attack())" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "promptbench", 187 | "language": "python", 188 | "name": "python3" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.9.18" 201 | } 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 2 205 | } 206 | -------------------------------------------------------------------------------- /imgs/prompt_attack_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/promptbench/fcda538bd779ad11612818e0645a387a462b5c3b/imgs/prompt_attack_attention.png -------------------------------------------------------------------------------- /imgs/promptbench.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/promptbench/fcda538bd779ad11612818e0645a387a462b5c3b/imgs/promptbench.png -------------------------------------------------------------------------------- /imgs/promptbench_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/promptbench/fcda538bd779ad11612818e0645a387a462b5c3b/imgs/promptbench_logo.png 
-------------------------------------------------------------------------------- /promptbench/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .models import LLMModel, VLMModel, SUPPORTED_MODELS, SUPPORTED_MODELS_VLM 5 | from .prompt_engineering import PEMethod, SUPPORTED_METHODS, METHOD_SUPPORT_DATASET 6 | from .dataload import DatasetLoader, SUPPORTED_DATASETS, SUPPORTED_DATASETS_VLM 7 | from .prompts import Prompt 8 | from .utils import InputProcess, OutputProcess 9 | from .metrics import Eval 10 | from .dyval import DyValDataset -------------------------------------------------------------------------------- /promptbench/dataload/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .dataset import * 5 | from .dataload import * -------------------------------------------------------------------------------- /promptbench/dataload/dataload.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .dataset import * 5 | 6 | SUPPORTED_DATASETS = [ 7 | "sst2", "cola", "qqp", 8 | "mnli", "mnli_matched", "mnli_mismatched", 9 | "qnli", "wnli", "rte", "mrpc", 10 | "mmlu", "squad_v2", "un_multi", "iwslt2017", "math", 11 | "bool_logic", "valid_parentheses", 12 | "gsm8k", "csqa", "bigbench_date", "bigbench_object_tracking", 13 | "last_letter_concat", "numersense", "qasc", 14 | "bbh", "drop", "arc-easy", "arc-challenge", 15 | ] 16 | 17 | SUPPORTED_DATASETS_VLM = [ 18 | "vqav2", "nocaps", "science_qa", 19 | "math_vista", "ai2d", "mmmu", "chart_qa" 20 | ] 21 | 22 | class DatasetLoader: 23 | 24 | @staticmethod 25 | def load_dataset(dataset_name, task=None, supported_languages=None): 26 | """ 27 | Load and return the specified dataset. 28 | 29 | This function acts as a factory method, returning the appropriate dataset object 30 | based on the provided dataset name. 31 | 'math', 'un_multi' and 'iwslt' require additional arguments to specify the languages used in the dataset. 32 | 33 | Args: 34 | dataset_name (str): The name of the dataset to load. 35 | task: str: Additional arguments required by 'math'. 36 | Please visit https://huggingface.co/datasets/math_dataset/ to see the supported tasks for math. 37 | supported_languages: list: Additional arguments required by 'iwslt'. 38 | Please visit https://huggingface.co/datasets/iwslt2017 to see the supported languages for iwslt. 39 | e.g. supported_languages=['de-en', 'ar-en'] for German-English and Arabic-English translation. 40 | Returns: 41 | Dataset object corresponding to the given dataset_name. 42 | The dataset object is an instance of a list, each element is a dictionary. Please refer to each dataset's documentation for details. 43 | 44 | Raises: 45 | NotImplementedError: If the dataset_name does not correspond to any known dataset. 
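        Examples:
            A minimal usage sketch, mirroring the example notebooks (the keys inside each
            sample dict depend on the dataset; for SST-2 they include "content" and "label"):

            >>> import promptbench as pb
            >>> dataset = pb.DatasetLoader.load_dataset("sst2")
            >>> sample = dataset[0]
            >>> prompt = "Classify the sentence: {content}"
            >>> input_text = pb.InputProcess.basic_format(prompt, sample)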
46 | """ 47 | # GLUE datasets 48 | if dataset_name in ["cola", "sst2", "qqp", "mnli", "mnli_matched", "mnli_mismatched", 49 | "qnli", "wnli", "rte", "mrpc"]: 50 | return GLUE(dataset_name) 51 | elif dataset_name == 'mmlu': 52 | return MMLU() 53 | elif dataset_name == "squad_v2": 54 | return SQUAD_V2() 55 | elif dataset_name == 'un_multi': 56 | return UnMulti() 57 | elif dataset_name == 'iwslt2017': 58 | return IWSLT(supported_languages) 59 | elif dataset_name in 'math': 60 | return Math(task) 61 | elif dataset_name == 'bool_logic': 62 | return BoolLogic() 63 | elif dataset_name == 'valid_parentheses': 64 | return ValidParentheses() 65 | elif dataset_name == 'gsm8k': 66 | return GSM8K() 67 | elif dataset_name == 'csqa': 68 | return CSQA() 69 | elif dataset_name == 'qasc': 70 | return QASC() 71 | elif 'bigbench' in dataset_name: 72 | return BigBench(dataset_name) 73 | elif dataset_name == 'bbh': 74 | return BBH() 75 | elif dataset_name == 'drop': 76 | return DROP() 77 | elif dataset_name == 'arc-easy': 78 | return ARC('ARC-Easy') 79 | elif dataset_name == 'arc-challenge': 80 | return ARC('ARC-Challenge') 81 | elif dataset_name == 'vqav2': 82 | return VQAv2() 83 | elif dataset_name =='nocaps': 84 | return NoCaps() 85 | elif dataset_name =='math_vista': 86 | return MathVista() 87 | elif dataset_name == 'ai2d': 88 | return AI2D() 89 | elif dataset_name == 'mmmu': 90 | return MMMU() 91 | elif dataset_name == 'chart_qa': 92 | return ChartQA() 93 | elif dataset_name == 'science_qa': 94 | return ScienceQA() 95 | else: 96 | # If the dataset name doesn't match any known datasets, raise an error 97 | raise NotImplementedError(f"Dataset '{dataset_name}' is not supported.") 98 | 99 | -------------------------------------------------------------------------------- /promptbench/dyval/DAG/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/promptbench/fcda538bd779ad11612818e0645a387a462b5c3b/promptbench/dyval/DAG/__init__.py -------------------------------------------------------------------------------- /promptbench/dyval/DAG/code_dag.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import random 5 | from .dag import GeneralDAG 6 | from .describer import GeneralDAGDescriber 7 | 8 | class CodeDAG(GeneralDAG): 9 | """ 10 | A specialized Directed Acyclic Graph (DAG) for representing and analyzing code structures. 11 | 12 | This class extends GeneralDAG and provides additional functionality specific to code analysis, 13 | including reachability and maximum sum path calculations. 14 | 15 | Parameters: 16 | ----------- 17 | num_nodes : int 18 | The number of nodes in the DAG. 19 | min_links_per_node : int, optional 20 | The minimum number of links per node (default is 1). 21 | max_links_per_node : int, optional 22 | The maximum number of links per node (default is 3). 23 | 24 | Methods: 25 | -------- 26 | reachability(start, end) 27 | Determines whether there is a path from the start node to the end node. 28 | max_sum_path(start, end) 29 | Finds the path from start to end having the maximum sum of node values. 
30 | """ 31 | def __init__(self, num_nodes, min_links_per_node=1, max_links_per_node=3): 32 | super().__init__(num_nodes, min_links_per_node, max_links_per_node, add_cycles=0) 33 | 34 | def reachability(self, start, end): 35 | descriptions = [] 36 | visited = set() 37 | stack = [start] 38 | 39 | descriptions.append(f"Starting the search process from node {start.name} with the goal to reach node {end.name}.") 40 | 41 | while stack: 42 | node = stack.pop() 43 | 44 | descriptions.append(f"Checking node {node.name}.") 45 | 46 | if node == end: 47 | descriptions.append(f"Successfully reached the target node {end.name}.") 48 | return True, "\n".join(descriptions) 49 | if node not in visited: 50 | visited.add(node) 51 | unvisited_children = [child for child in node.children if child not in visited] 52 | if unvisited_children: 53 | child_names = ', '.join(child.name for child in unvisited_children) 54 | descriptions.append(f"Exploring children of node {node.name}: {child_names}.") 55 | else: 56 | descriptions.append(f"Node {node.name} has no unvisited children. Moving back.") 57 | stack.extend(unvisited_children) 58 | else: 59 | descriptions.append(f"Node {node.name} has already been visited. Skipping.") 60 | 61 | descriptions.append(f"Exhausted all possible paths without reaching node {end.name}.") 62 | return False, "\n".join(descriptions) 63 | 64 | def max_sum_path(self, start, end): 65 | descriptions = [] # This will hold our natural language descriptions. 66 | queue = [(start, start.value)] # Each element in the queue is a tuple (node, current_sum) 67 | max_sum = float('-inf') 68 | 69 | descriptions.append(f"Starting the search for the maximum sum path from node {start.name} to node {end.name}.") 70 | 71 | while queue: 72 | node, cur_sum = queue.pop(0) 73 | 74 | descriptions.append(f"Reaching node {node.name} with current sum of {cur_sum}.") 75 | 76 | if node == end: 77 | if cur_sum > max_sum: 78 | max_sum = max(max_sum, cur_sum) 79 | descriptions.append(f"Found a path to node {end.name} with a new maximum sum of {cur_sum}.") 80 | else: 81 | descriptions.append(f"Found a path to node {end.name} with a sum of {cur_sum}, which is less than current maximum sum.") 82 | 83 | else: 84 | if len(node.children) == 0: 85 | descriptions.append(f"Node {node.name} has no children. Moving back.") 86 | else: 87 | child_names = ', '.join(child.name for child in node.children) 88 | descriptions.append(f"Now, we explore the children of node {node.name}: {child_names}.") 89 | 90 | for child in node.children: 91 | queue.append((child, cur_sum + child.value)) 92 | 93 | if max_sum != float('-inf'): 94 | descriptions.append(f"The maximum sum from node {start.name} to node {end.name} is {max_sum}.") 95 | else: 96 | descriptions.append(f"There is no path from node {start.name} to node {end.name}.") 97 | 98 | return max_sum if max_sum != float('-inf') else "N/A", "\n".join(descriptions) 99 | 100 | 101 | class CodeDAGDescriber(GeneralDAGDescriber): 102 | """ 103 | Describer class for CodeDAG, providing natural language descriptions of various DAG properties and questions. 104 | 105 | This class extends GeneralDAGDescriber to work specifically with CodeDAG instances. 106 | 107 | Parameters: 108 | ----------- 109 | dag_obj : CodeDAG 110 | The CodeDAG instance to describe. 111 | dataset_type : str 112 | The type of dataset for the DAG (e.g., 'reachability', 'max_sum_path'). 113 | add_rand_desc : int, optional 114 | Random description addition parameter (default is 0). 
115 | 116 | Methods: 117 | -------- 118 | describe_reachability() 119 | Generates a natural language description of a reachability question in the DAG. 120 | describe_max_sum_path() 121 | Generates a natural language description of a maximum sum path question in the DAG. 122 | describe_question() 123 | Generates a natural language description based on the dataset type. 124 | describe_answer() 125 | Returns the answer to the described question. 126 | describe_inference_steps() 127 | Returns the natural language inference steps for the described question. 128 | """ 129 | def __init__(self, dag_obj: CodeDAG, dataset_type, add_rand_desc=0): 130 | super().__init__(dag_obj, add_rand_desc, delete_desc=0) 131 | self.dataset_type = dataset_type 132 | 133 | def describe_reachability(self): 134 | start, end = random.sample(self.dag_obj.nodes, 2) 135 | answer, inference_steps = self.dag_obj.reachability(start, end) 136 | return f"Can {end.name} be reached starting from {start.name}?", answer, inference_steps 137 | 138 | def describe_max_sum_path(self): 139 | # Randomly select two distinct nodes 140 | node1, node2 = random.sample(self.dag_obj.nodes, 2) 141 | 142 | if self.dag_obj.sorted_node_names.index(node1.name) < self.dag_obj.sorted_node_names.index(node2.name): 143 | start, end = node2, node1 144 | else: 145 | start, end = node1, node2 146 | 147 | answer, inference_steps = self.dag_obj.max_sum_path(start, end) 148 | 149 | return f"What's the maximum sum path from {start.name} to {end.name}?", answer, inference_steps 150 | 151 | def describe_question(self): 152 | descriptions = self._describe_question() 153 | if self.dataset_type == "reachability": 154 | desc, answer, inference_steps = self.describe_reachability() 155 | 156 | elif self.dataset_type == "max_sum_path": 157 | value_desc = "" 158 | for node in self.dag_obj.nodes: 159 | value_desc += f"The value of {node.name} is {node.value}\n" 160 | desc, answer, inference_steps = self.describe_max_sum_path() 161 | desc = value_desc + desc 162 | 163 | self.inference_steps = inference_steps 164 | 165 | for order, description in descriptions.items(): 166 | descriptions[order] = description + "\n" + desc 167 | 168 | self.answer = answer 169 | return descriptions 170 | 171 | def describe_answer(self): 172 | return self.answer 173 | 174 | def describe_inference_steps(self): 175 | return self.inference_steps -------------------------------------------------------------------------------- /promptbench/dyval/DAG/describer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import random 5 | from .dag import BaseDAG, GeneralDAG, TreeDAG 6 | 7 | class BaseDAGDescriber: 8 | """ 9 | Base class for creating descriptions of Directed Acyclic Graphs (DAGs). 10 | 11 | This class provides methods to traverse and describe DAGs in various orders and to manipulate descriptions. 12 | 13 | Parameters: 14 | ----------- 15 | dag_obj : BaseDAG 16 | The DAG object to be described. 17 | add_rand_desc : int, optional 18 | The number of random descriptions to add. 19 | delete_desc : int, optional 20 | The number of descriptions to delete. 21 | 22 | Methods: 23 | -------- 24 | describe_question_node(node) 25 | Describes a node in the DAG for question formation. 26 | generate_rand_description() 27 | Generates random descriptions based on the DAG type. 
28 | topological_traversal(desc_func) 29 | Traverses the DAG in topological order and applies a description function. 30 | reverse_topological_traversal(topo_desc) 31 | Reverses the order of topological descriptions. 32 | random_traversal(topo_desc) 33 | Shuffles the topological descriptions randomly. 34 | add_rand_description(desc) 35 | Adds random descriptions to the existing description list. 36 | delete_description(desc) 37 | Deletes descriptions from the existing description list. 38 | _describe_question() 39 | Describes the DAG for question formation in various traversal orders. 40 | """ 41 | 42 | def __init__(self, dag_obj: BaseDAG, add_rand_desc=0, delete_desc=0): 43 | self.dag_obj = dag_obj 44 | self.add_rand_desc = add_rand_desc 45 | self.delete_desc = delete_desc 46 | 47 | def describe_question_node(self, node): 48 | # Describe a node in the DAG to form a question for test set 49 | # return a string 50 | raise NotImplementedError 51 | 52 | def generate_rand_description(self): 53 | raise NotImplementedError 54 | 55 | def topological_traversal(self, desc_func): 56 | node_names = self.dag_obj.topological_sort() 57 | 58 | descriptions = [] 59 | for node_name in node_names: 60 | for node in self.dag_obj.nodes: 61 | if node.name == node_name: 62 | descriptions.append(desc_func(node)) 63 | return descriptions 64 | 65 | def reverse_topological_traversal(self, topo_desc): 66 | reversed_desc = topo_desc.copy() 67 | reversed_desc.reverse() 68 | return reversed_desc 69 | 70 | def random_traversal(self, topo_desc): 71 | rand_desc = topo_desc.copy() 72 | random.shuffle(rand_desc) 73 | return rand_desc 74 | 75 | def add_rand_description(self, desc): 76 | # Generate a random description, the generation depends on the type of DAG 77 | for _ in range(self.add_rand_desc): 78 | rand_desc = self.generate_rand_description() 79 | for cur_desc in rand_desc: 80 | desc.insert(random.randint(0, len(desc)), cur_desc) 81 | 82 | def delete_description(self, desc): 83 | for _ in range(self.delete_desc): 84 | desc.pop(random.randint(0, len(desc) - 1)) 85 | 86 | def _describe_question(self): 87 | descriptions = {} 88 | 89 | topo_desc = self.topological_traversal(self.describe_question_node) 90 | 91 | self.delete_description(topo_desc) 92 | self.add_rand_description(topo_desc) 93 | 94 | reversed_desc = self.reverse_topological_traversal(topo_desc) 95 | rand_desc = self.random_traversal(topo_desc) 96 | descriptions["topological"] = "\n".join(topo_desc) 97 | descriptions["reversed"] = "\n".join(reversed_desc) 98 | descriptions["random"] = "\n".join(rand_desc) 99 | 100 | return descriptions 101 | 102 | 103 | class GeneralDAGDescriber(BaseDAGDescriber): 104 | """ 105 | A describer class for GeneralDAG instances. 106 | 107 | Inherits from BaseDAGDescriber and provides specific implementations for describing GeneralDAG nodes. 108 | 109 | Parameters: 110 | ----------- 111 | dag_obj : GeneralDAG 112 | The GeneralDAG instance to describe. 113 | add_rand_desc : int, optional 114 | The number of random descriptions to add (inherited). 115 | delete_desc : int, optional 116 | The number of descriptions to delete (inherited). 117 | 118 | Methods: 119 | -------- 120 | describe_question_node(node) 121 | Provides a description for a GeneralDAG node. 122 | generate_rand_description() 123 | Generates random descriptions specific to GeneralDAG. 124 | describe_answer() 125 | Describes the DAG for answer formation (not implemented yet). 
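    Example:
    --------
    A rough sketch (assumes `dag` is an existing GeneralDAG instance):

    >>> describer = GeneralDAGDescriber(dag)
    >>> descriptions = describer._describe_question()
    >>> sorted(descriptions.keys())
    ['random', 'reversed', 'topological']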
126 | """ 127 | 128 | def __init__(self, dag_obj: GeneralDAG, add_rand_desc=0, delete_desc=0): 129 | super().__init__(dag_obj, add_rand_desc, delete_desc) 130 | 131 | def describe_question_node(self, node): 132 | child_names = ", ".join([child.name for child in node.children]) 133 | description = f"{node.name} points to: ({child_names if child_names else 'None'})." 134 | return description 135 | 136 | def generate_rand_description(self): 137 | rand_desc = [] 138 | nodes = self.dag_obj.generate_dag(num_nodes=3) 139 | for node in nodes: 140 | rand_desc.append(self.describe_question_node(node)) 141 | 142 | return rand_desc 143 | 144 | def describe_answer(self): 145 | # Describe the DAG to form a answer for training set 146 | # return a string 147 | raise NotImplementedError 148 | 149 | 150 | class TreeDAGDescriber(BaseDAGDescriber): 151 | """ 152 | A describer class for TreeDAG instances. 153 | 154 | Inherits from BaseDAGDescriber and provides specific implementations for describing TreeDAG nodes. 155 | 156 | Parameters: 157 | ----------- 158 | dag_obj : TreeDAG 159 | The TreeDAG instance to describe. 160 | add_rand_desc : int, optional 161 | The number of random descriptions to add (inherited). 162 | delete_desc : int, optional 163 | The number of descriptions to delete (inherited). 164 | trainset : bool 165 | Indicates if the describer is used for training set generation. 166 | 167 | Methods: 168 | -------- 169 | describe_inference_node(node) 170 | Provides a description for a TreeDAG node for inference. 171 | generate_rand_description() 172 | Generates random descriptions specific to TreeDAG. 173 | describe_inference_steps() 174 | Describes the inference steps based on the DAG's topology. 175 | describe_answer() 176 | Provides the answer based on the root value of the TreeDAG. 177 | describe_question() 178 | Describes the DAG for question formation in various traversal orders. 179 | """ 180 | 181 | def __init__(self, dag_obj: TreeDAG, add_rand_desc=0, delete_desc=0, trainset=False): 182 | self.trainset = trainset 183 | super().__init__(dag_obj, add_rand_desc, delete_desc) 184 | 185 | def describe_inference_node(self, node): 186 | # Describe a node in the DAG to form a answer for training set 187 | # return a string 188 | raise NotImplementedError 189 | 190 | def generate_rand_description(self): 191 | rand_desc = [] 192 | root = self.dag_obj.generate_tree(depth=2) 193 | rand_desc.append(self.describe_question_node(root)) 194 | for child in root.children: 195 | rand_desc.append(self.describe_question_node(child)) 196 | 197 | return rand_desc 198 | 199 | def describe_inference_steps(self): 200 | return "\n".join(self.topological_traversal(self.describe_inference_node)) 201 | 202 | def describe_answer(self): 203 | return self.dag_obj.root.value 204 | 205 | def describe_question(self): 206 | return self._describe_question() -------------------------------------------------------------------------------- /promptbench/dyval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from .dyval_dataset import DyValDataset 5 | from .dyval_utils import * 6 | 7 | 8 | DYVAL_DATASETS = [ 9 | "arithmetic", 10 | "linear_equation", 11 | "bool_logic", 12 | "deductive_logic", 13 | "abductive_logic", 14 | "reachability", 15 | "max_sum_path" 16 | ] -------------------------------------------------------------------------------- /promptbench/dyval/dyval_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from tqdm import tqdm 5 | from .dyval_utils import process_dyval_training_sample 6 | 7 | class DyValDataset: 8 | """ 9 | A class for creating and managing datasets for various types of Directed Acyclic Graph (DAG) tasks. 10 | 11 | This class can generate datasets for arithmetic, Boolean logic, linear equations, deductive logic, abductive logic, reachability, and max sum path problems using different types of DAGs. 12 | 13 | Parameters: 14 | ----------- 15 | dataset_type : str 16 | The type of dataset to be generated (e.g., 'arithmetic', 'bool_logic'). 17 | is_trainset : bool, optional 18 | Specifies whether the dataset is a training set (default is False). 19 | num_samples : int, optional 20 | The number of samples to generate in the dataset (default is 100). 21 | num_nodes_per_sample : int, optional 22 | The number of nodes per sample (default is 10). 23 | min_links_per_node : int, optional 24 | The minimum number of links per node (default is 1). 25 | max_links_per_node : int, optional 26 | The maximum number of links per node (default is 3). 27 | depth : int, optional 28 | The depth of the DAG (default is 3). 29 | num_children_per_node : int, optional 30 | The number of children per node (default is 2). 31 | extra_links_per_node : int, optional 32 | The number of extra links per node (default is 1). 33 | add_rand_desc : int, optional 34 | The number of random descriptions to add (default is 0). 35 | delete_desc : int, optional 36 | The number of descriptions to delete (default is 0). 37 | add_cycles : int, optional 38 | The number of cycles to add to the DAG (default is 0). 39 | num_dags : int, optional 40 | The number of DAGs to generate for linear equations (default is 1). 41 | 42 | Methods: 43 | -------- 44 | __len__() 45 | Returns the number of samples in the dataset. 46 | __getitem__(key) 47 | Retrieves a specific sample from the dataset. 48 | create_dataset() 49 | Generates the dataset based on the specified parameters. 50 | get_fewshot_examples(shots) 51 | Generates few-shot examples for the dataset. 52 | _generate_sample(**kwargs) 53 | Generates a single sample for the dataset. 
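    Example:
    --------
    An illustrative sketch (the task name is taken from the supported DyVal tasks;
    keyword values are arbitrary):

    >>> ds = DyValDataset("reachability", num_samples=5, num_nodes_per_sample=8)
    >>> len(ds)
    5
    >>> sample = ds["random"][0]              # samples are grouped by description order
    >>> keys = list(sample.keys())            # e.g. ["answers", "descriptions"]
    >>> fewshot = ds.get_fewshot_examples(shots=2)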
54 | """ 55 | 56 | def __init__(self, 57 | dataset_type, 58 | is_trainset=False, 59 | num_samples=100, 60 | num_nodes_per_sample=10, 61 | min_links_per_node=1, 62 | max_links_per_node=3, 63 | depth=3, 64 | num_children_per_node=2, 65 | extra_links_per_node=1, 66 | add_rand_desc=0, 67 | delete_desc=0, 68 | add_cycles=0, 69 | num_dags=1, 70 | ): 71 | 72 | self.dataset_type = dataset_type 73 | self.is_trainset = is_trainset 74 | self.num_samples = num_samples 75 | 76 | self.num_nodes_per_sample = num_nodes_per_sample 77 | self.min_links_per_node = min_links_per_node 78 | self.max_links_per_node = max_links_per_node 79 | 80 | self.depth = depth 81 | self.num_children_per_node = num_children_per_node 82 | self.extra_links_per_node = extra_links_per_node 83 | 84 | self.add_rand_desc = add_rand_desc 85 | self.delete_desc = delete_desc 86 | self.add_cycles = add_cycles 87 | 88 | self.num_dags = num_dags 89 | 90 | self.data = self.create_dataset() 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | def __getitem__(self, key): 96 | return self.data[key] 97 | 98 | def create_dataset(self): 99 | data = {} 100 | # data = {"descriptions": {}, "answers": []} 101 | for _ in tqdm(range(self.num_samples)): 102 | sample = self._generate_sample() 103 | processed = {} 104 | # first record all the keys except "descriptions" 105 | for key in sample.keys(): 106 | if key != "descriptions": 107 | processed[key] = sample[key] 108 | 109 | for key in sample.keys(): 110 | if key == "descriptions": 111 | for order, desc in sample[key].items(): 112 | if order not in data: 113 | data[order] = [] 114 | new = processed.copy() 115 | new["descriptions"] = desc 116 | data[order].append(new) 117 | 118 | return data 119 | 120 | def get_fewshot_examples(self, shots): 121 | if shots == 1: 122 | examples = f"\n\nHere is an example of problem related to {self.dataset_type} task and their corresponding inference steps." 123 | else: 124 | examples = f"\n\nHere are {shots} examples of problems related to {self.dataset_type} task and their corresponding inference steps." 
125 | if self.dataset_type in ["linear_equation"]: 126 | shots = 1 127 | if self.dataset_type in ["max_sum_path"]: 128 | shots = 2 129 | for _ in range(shots): 130 | if self.dataset_type in ["linear_equation"]: 131 | depth = 2 132 | else: 133 | depth = 3 134 | sample = self._generate_sample(is_trainset=True, num_nodes_per_sample=7, min_links_per_node=1, max_links_per_node=3, depth=depth, num_children_per_node=2, extra_links_per_node=0, add_rand_desc=0, delete_desc=0, add_cycles=0) 135 | examples += "\n\nQ:\n" + sample["descriptions"]["random"] 136 | examples += "\n\nA:\n" + sample["inferences"] 137 | 138 | return examples 139 | 140 | def _generate_sample(self, **kwargs): 141 | sample = {} 142 | dataset_type = kwargs.get("dataset_type", self.dataset_type) 143 | is_trainset = kwargs.get("is_trainset", self.is_trainset) 144 | num_nodes_per_sample = kwargs.get("num_nodes_per_sample", self.num_nodes_per_sample) 145 | min_links_per_node = kwargs.get("min_links_per_node", self.min_links_per_node) 146 | max_links_per_node = kwargs.get("max_links_per_node", self.max_links_per_node) 147 | depth = kwargs.get("depth", self.depth) 148 | num_children_per_node = kwargs.get("num_children_per_node", self.num_children_per_node) 149 | extra_links_per_node = kwargs.get("extra_links_per_node", self.extra_links_per_node) 150 | add_rand_desc = kwargs.get("add_rand_desc", self.add_rand_desc) 151 | delete_desc = kwargs.get("delete_desc", self.delete_desc) 152 | add_cycles = kwargs.get("add_cycles", self.add_cycles) 153 | num_dags = kwargs.get("num_dags", self.num_dags) 154 | 155 | if dataset_type in ["arithmetic", "bool_logic"]: 156 | from .DAG.math_dag import ArithmeticDAG, ArithmeticDAGDescriber 157 | from .DAG.logic_dag import BoolDAG, BoolDAGDescriber 158 | 159 | if dataset_type == "arithmetic": 160 | ops = ["+", "-", "*", "/", "sqrt", "**"] 161 | uni_ops = ["sqrt", "**"] 162 | DAGType, DAGDescriber = ArithmeticDAG, ArithmeticDAGDescriber 163 | 164 | elif dataset_type == "bool_logic": 165 | ops = ['and', 'or', 'not'] 166 | uni_ops = ['not'] 167 | DAGType, DAGDescriber = BoolDAG, BoolDAGDescriber 168 | 169 | dag = DAGType(ops, uni_ops, depth, num_children_per_node, extra_links_per_node, add_cycles) 170 | describer = DAGDescriber(dag, ops, uni_ops, add_rand_desc, delete_desc) 171 | 172 | elif dataset_type == "linear_equation": 173 | from .DAG.math_dag import LinearEq 174 | ops = ["+", "-", "*", "/", "sqrt", "**"] 175 | uni_ops = ["sqrt", "**"] 176 | describer = LinearEq(ops, uni_ops, depth, num_dags, num_children_per_node, extra_links_per_node, add_rand_desc) 177 | 178 | elif dataset_type in ["deductive_logic", "abductive_logic"]: 179 | from .DAG.logic_dag import DeductionDAG, DeductionDAGDescriber, AbductionDAG, AbductionDAGDescriber 180 | if dataset_type == "deductive_logic": 181 | DAGType, DAGDescriber = DeductionDAG, DeductionDAGDescriber 182 | probs = [0.2, 0.6, 0.2] 183 | elif dataset_type == "abductive_logic": 184 | DAGType, DAGDescriber = AbductionDAG, AbductionDAGDescriber 185 | probs = [0.07, 0.66, 0.27] 186 | 187 | ops = ['and', 'or', 'not'] 188 | uni_ops = ['not'] 189 | dag = DAGType(ops, uni_ops, depth, probs, num_children_per_node) 190 | describer = DAGDescriber(dag, ops, uni_ops, add_rand_desc) 191 | 192 | elif dataset_type in ["reachability", "max_sum_path"]: 193 | from .DAG.code_dag import CodeDAG, CodeDAGDescriber 194 | # CodeDAG is not allowed to add cycles and delete descriptions 195 | dag = CodeDAG(num_nodes_per_sample, min_links_per_node, max_links_per_node) 196 | describer = 
CodeDAGDescriber(dag, dataset_type, add_rand_desc) 197 | 198 | question = describer.describe_question() 199 | answer = describer.describe_answer() 200 | sample["descriptions"] = question 201 | sample["answers"] = answer 202 | if dataset_type in ["arithmetic", "bool_logic", "deductive_logic"]: 203 | sample["vars"] = dag.root.name 204 | 205 | if is_trainset: 206 | inference_desc = describer.describe_inference_steps() 207 | sample["inferences"] = inference_desc 208 | sample = process_dyval_training_sample(sample, dataset_type) 209 | 210 | return sample -------------------------------------------------------------------------------- /promptbench/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .eval import Eval -------------------------------------------------------------------------------- /promptbench/metrics/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is copied from Huggingface evaluate library. 3 | """ 4 | 5 | """ BLEU metric. """ 6 | 7 | import datasets 8 | 9 | import evaluate 10 | 11 | from .bleu_ import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py 12 | from .tokenizer_13a import Tokenizer13a 13 | 14 | _CITATION = """\ 15 | @INPROCEEDINGS{Papineni02bleu:a, 16 | author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu}, 17 | title = {BLEU: a Method for Automatic Evaluation of Machine Translation}, 18 | booktitle = {}, 19 | year = {2002}, 20 | pages = {311--318} 21 | } 22 | @inproceedings{lin-och-2004-orange, 23 | title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation", 24 | author = "Lin, Chin-Yew and 25 | Och, Franz Josef", 26 | booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics", 27 | month = "aug 23{--}aug 27", 28 | year = "2004", 29 | address = "Geneva, Switzerland", 30 | publisher = "COLING", 31 | url = "https://www.aclweb.org/anthology/C04-1072", 32 | pages = "501--507", 33 | } 34 | """ 35 | 36 | _DESCRIPTION = """\ 37 | BLEU (Bilingual Evaluation Understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another. 38 | Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation, the better it is" 39 | – this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and remains one of the most popular automated and inexpensive metrics. 40 | Scores are calculated for individual translated segments—generally sentences—by comparing them with a set of good quality reference translations. 41 | Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. 42 | Neither intelligibility nor grammatical correctness are not taken into account. 43 | """ 44 | 45 | _KWARGS_DESCRIPTION = """ 46 | Computes BLEU score of translated segments against one or more references. 47 | Args: 48 | predictions: list of translations to score. 49 | references: list of lists of or just a list of references for each translation. 50 | tokenizer : approach used for tokenizing `predictions` and `references`. 
51 | The default tokenizer is `tokenizer_13a`, a minimal tokenization approach that is equivalent to `mteval-v13a`, used by WMT. 52 | This can be replaced by any function that takes a string as input and returns a list of tokens as output. 53 | max_order: Maximum n-gram order to use when computing BLEU score. 54 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 55 | Returns: 56 | 'bleu': bleu score, 57 | 'precisions': geometric mean of n-gram precisions, 58 | 'brevity_penalty': brevity penalty, 59 | 'length_ratio': ratio of lengths, 60 | 'translation_length': translation_length, 61 | 'reference_length': reference_length 62 | Examples: 63 | >>> predictions = ["hello there general kenobi", "foo bar foobar"] 64 | >>> references = [ 65 | ... ["hello there general kenobi", "hello there!"], 66 | ... ["foo bar foobar"] 67 | ... ] 68 | >>> bleu = evaluate.load("bleu") 69 | >>> results = bleu.compute(predictions=predictions, references=references) 70 | >>> print(results["bleu"]) 71 | 1.0 72 | """ 73 | 74 | 75 | class Bleu(evaluate.Metric): 76 | def _info(self): 77 | return evaluate.MetricInfo( 78 | description=_DESCRIPTION, 79 | citation=_CITATION, 80 | inputs_description=_KWARGS_DESCRIPTION, 81 | features=[ 82 | datasets.Features( 83 | { 84 | "predictions": datasets.Value("string", id="sequence"), 85 | "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"), 86 | } 87 | ), 88 | datasets.Features( 89 | { 90 | "predictions": datasets.Value("string", id="sequence"), 91 | "references": datasets.Value("string", id="sequence"), 92 | } 93 | ), 94 | ], 95 | codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"], 96 | reference_urls=[ 97 | "https://en.wikipedia.org/wiki/BLEU", 98 | "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213", 99 | ], 100 | ) 101 | 102 | def _compute(self, predictions, references, tokenizer=Tokenizer13a(), max_order=4, smooth=False): 103 | # if only one reference is provided make sure we still use list of lists 104 | if isinstance(references[0], str): 105 | references = [[ref] for ref in references] 106 | 107 | references = [[tokenizer(r) for r in ref] for ref in references] 108 | predictions = [tokenizer(p) for p in predictions] 109 | score = compute_bleu( 110 | reference_corpus=references, translation_corpus=predictions, max_order=max_order, smooth=smooth 111 | ) 112 | (bleu, precisions, bp, ratio, translation_length, reference_length) = score 113 | return { 114 | "bleu": bleu, 115 | "precisions": precisions, 116 | "brevity_penalty": bp, 117 | "length_ratio": ratio, 118 | "translation_length": translation_length, 119 | "reference_length": reference_length, 120 | } -------------------------------------------------------------------------------- /promptbench/metrics/bleu/bleu_.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is copied from Huggingface evaluate library. 3 | """ 4 | # Copyright 2017 Google Inc. All Rights Reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
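The `Bleu` class above is the vendored Hugging Face `evaluate` BLEU metric; calling it directly looks like the sketch below (this mirrors how `Eval.compute_bleu` in `metrics/eval.py` uses it, and assumes the `evaluate` and `datasets` packages are installed).

```python
# Sketch: calling the vendored Bleu metric directly.
from promptbench.metrics.bleu.bleu import Bleu

predictions = ["hello there general kenobi", "foo bar foobar"]
references = [
    ["hello there general kenobi", "hello there!"],  # several references allowed per prediction
    ["foo bar foobar"],
]

metric = Bleu()
results = metric.compute(predictions=predictions, references=references)
print(results["bleu"])  # 1.0 for these exact matches
```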
15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # ============================================================================== 18 | 19 | """Python implementation of BLEU and smooth-BLEU. 20 | This module provides a Python implementation of BLEU and smooth-BLEU. 21 | Smooth BLEU is computed following the method outlined in the paper: 22 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 23 | evaluation metrics for machine translation. COLING 2004. 24 | """ 25 | 26 | import collections 27 | import math 28 | 29 | 30 | def _get_ngrams(segment, max_order): 31 | """Extracts all n-grams upto a given maximum order from an input segment. 32 | Args: 33 | segment: text segment from which n-grams will be extracted. 34 | max_order: maximum length in tokens of the n-grams returned by this 35 | methods. 36 | Returns: 37 | The Counter containing all n-grams upto max_order in segment 38 | with a count of how many times each n-gram occurred. 39 | """ 40 | ngram_counts = collections.Counter() 41 | for order in range(1, max_order + 1): 42 | for i in range(0, len(segment) - order + 1): 43 | ngram = tuple(segment[i:i+order]) 44 | ngram_counts[ngram] += 1 45 | return ngram_counts 46 | 47 | 48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 49 | smooth=False): 50 | """Computes BLEU score of translated segments against one or more references. 51 | Args: 52 | reference_corpus: list of lists of references for each translation. Each 53 | reference should be tokenized into a list of tokens. 54 | translation_corpus: list of translations to score. Each translation 55 | should be tokenized into a list of tokens. 56 | max_order: Maximum n-gram order to use when computing BLEU score. 57 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 58 | Returns: 59 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 60 | precisions and brevity penalty. 61 | """ 62 | matches_by_order = [0] * max_order 63 | possible_matches_by_order = [0] * max_order 64 | reference_length = 0 65 | translation_length = 0 66 | for (references, translation) in zip(reference_corpus, translation_corpus): 67 | reference_length += min(len(r) for r in references) 68 | translation_length += len(translation) 69 | 70 | merged_ref_ngram_counts = collections.Counter() 71 | for reference in references: 72 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 73 | translation_ngram_counts = _get_ngrams(translation, max_order) 74 | overlap = translation_ngram_counts & merged_ref_ngram_counts 75 | for ngram in overlap: 76 | matches_by_order[len(ngram)-1] += overlap[ngram] 77 | for order in range(1, max_order+1): 78 | possible_matches = len(translation) - order + 1 79 | if possible_matches > 0: 80 | possible_matches_by_order[order-1] += possible_matches 81 | 82 | precisions = [0] * max_order 83 | for i in range(0, max_order): 84 | if smooth: 85 | precisions[i] = ((matches_by_order[i] + 1.) / 86 | (possible_matches_by_order[i] + 1.)) 87 | else: 88 | if possible_matches_by_order[i] > 0: 89 | precisions[i] = (float(matches_by_order[i]) / 90 | possible_matches_by_order[i]) 91 | else: 92 | precisions[i] = 0.0 93 | 94 | if min(precisions) > 0: 95 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 96 | geo_mean = math.exp(p_log_sum) 97 | else: 98 | geo_mean = 0 99 | 100 | ratio = float(translation_length) / reference_length 101 | 102 | if ratio > 1.0: 103 | bp = 1. 
104 | else: 105 | bp = math.exp(1 - 1. / ratio) 106 | 107 | bleu = geo_mean * bp 108 | 109 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 110 | -------------------------------------------------------------------------------- /promptbench/metrics/bleu/tokenizer_13a.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is copied from Huggingface evaluate library. 3 | """ 4 | 5 | # Source: https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_13a.py 6 | # Copyright 2020 SacreBLEU Authors. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | import re 21 | from functools import lru_cache 22 | 23 | 24 | class BaseTokenizer: 25 | """A base dummy tokenizer to derive from.""" 26 | 27 | def signature(self): 28 | """ 29 | Returns a signature for the tokenizer. 30 | :return: signature string 31 | """ 32 | return "none" 33 | 34 | def __call__(self, line): 35 | """ 36 | Tokenizes an input line with the tokenizer. 37 | :param line: a segment to tokenize 38 | :return: the tokenized line 39 | """ 40 | return line 41 | 42 | 43 | class TokenizerRegexp(BaseTokenizer): 44 | def signature(self): 45 | return "re" 46 | 47 | def __init__(self): 48 | self._re = [ 49 | # language-dependent part (assuming Western languages) 50 | (re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "), 51 | # tokenize period and comma unless preceded by a digit 52 | (re.compile(r"([^0-9])([\.,])"), r"\1 \2 "), 53 | # tokenize period and comma unless followed by a digit 54 | (re.compile(r"([\.,])([^0-9])"), r" \1 \2"), 55 | # tokenize dash when preceded by a digit 56 | (re.compile(r"([0-9])(-)"), r"\1 \2 "), 57 | # one space only between words 58 | # NOTE: Doing this in Python (below) is faster 59 | # (re.compile(r'\s+'), r' '), 60 | ] 61 | 62 | @lru_cache(maxsize=2**16) 63 | def __call__(self, line): 64 | """Common post-processing tokenizer for `13a` and `zh` tokenizers. 65 | :param line: a segment to tokenize 66 | :return: the tokenized line 67 | """ 68 | for (_re, repl) in self._re: 69 | line = _re.sub(repl, line) 70 | 71 | # no leading or trailing spaces, single space within words 72 | # return ' '.join(line.split()) 73 | # This line is changed with regards to the original tokenizer (seen above) to return individual words 74 | return line.split() 75 | 76 | 77 | class Tokenizer13a(BaseTokenizer): 78 | def signature(self): 79 | return "13a" 80 | 81 | def __init__(self): 82 | self._post_tokenizer = TokenizerRegexp() 83 | 84 | @lru_cache(maxsize=2**16) 85 | def __call__(self, line): 86 | """Tokenizes an input line using a relatively minimal tokenization 87 | that is however equivalent to mteval-v13a, used by WMT. 
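Unlike the `Bleu` wrapper, the `compute_bleu` helper above operates on corpora that are already tokenized and returns the full 6-tuple seen in its return statement; a small sketch:

```python
# Sketch: calling compute_bleu directly with pre-tokenized segments.
from promptbench.metrics.bleu.bleu_ import compute_bleu

references = [[["the", "cat", "sat", "on", "the", "mat"]]]  # per segment: a list of tokenized references
translations = [["the", "cat", "sat", "on", "the", "mat"]]  # per segment: one tokenized hypothesis

bleu, precisions, bp, ratio, trans_len, ref_len = compute_bleu(
    reference_corpus=references,
    translation_corpus=translations,
    max_order=4,
    smooth=True,  # apply the Lin & Och (2004) add-one smoothing
)
print(bleu, precisions, bp)
```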
88 | :param line: a segment to tokenize 89 | :return: the tokenized line 90 | """ 91 | 92 | # language-independent part: 93 | line = line.replace("<skipped>", "") 94 | line = line.replace("-\n", "") 95 | line = line.replace("\n", " ") 96 | 97 | if "&" in line: 98 | line = line.replace("&quot;", '"') 99 | line = line.replace("&amp;", "&") 100 | line = line.replace("&lt;", "<") 101 | line = line.replace("&gt;", ">") 102 | 103 | return self._post_tokenizer(f" {line} ") -------------------------------------------------------------------------------- /promptbench/metrics/cider/cider.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is copied from https://github.com/tylin/coco-caption/tree/master/pycocoevalcap/cider. 3 | """ 4 | # Filename: cider.py 5 | # 6 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 7 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 8 | # 9 | # Creation Date: Sun Feb 8 14:16:54 2015 10 | # 11 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 12 | 13 | from .cider_scorer import CiderScorer 14 | import pdb 15 | 16 | class Cider: 17 | """ 18 | Main Class to compute the CIDEr metric 19 | 20 | """ 21 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 22 | # set cider to sum over 1 to 4-grams 23 | self._n = n 24 | # set the standard deviation parameter for gaussian penalty 25 | self._sigma = sigma 26 | 27 | def compute_score(self, gts, res): 28 | """ 29 | Main function to compute CIDEr score 30 | :param hypo_for_image (dict) : dictionary with key and value 31 | ref_for_image (dict) : dictionary with key and value 32 | :return: cider (float) : computed CIDEr score for the corpus 33 | """ 34 | 35 | assert(gts.keys() == res.keys()) 36 | imgIds = gts.keys() 37 | 38 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 39 | 40 | for id in imgIds: 41 | hypo = res[id] 42 | ref = gts[id] 43 | 44 | # Sanity check. 45 | assert(type(hypo) is list) 46 | assert(len(hypo) == 1) 47 | assert(type(ref) is list) 48 | assert(len(ref) > 0) 49 | 50 | cider_scorer += (hypo[0], ref) 51 | 52 | (score, scores) = cider_scorer.compute_score() 53 | 54 | return score, scores 55 | 56 | def method(self): 57 | return "CIDEr" -------------------------------------------------------------------------------- /promptbench/metrics/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is copied from https://github.com/tylin/coco-caption/tree/master/pycocoevalcap/cider. 3 | """ 4 | #!/usr/bin/env python 5 | # Tsung-Yi Lin 6 | # Ramakrishna Vedantam 7 | 8 | import copy 9 | from collections import defaultdict 10 | import numpy as np 11 | import pdb 12 | import math 13 | 14 | def precook(s, n=4, out=False): 15 | """ 16 | Takes a string as input and returns an object that can be given to 17 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 18 | can take string arguments as well.
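A usage sketch for the vendored `Cider` metric above. The input format (the same ids in both dicts, references as a list of captions, the hypothesis as a single-element list) follows the assertions in `compute_score` and matches how `Eval.compute_cider` calls it later in `metrics/eval.py`.

```python
# Sketch: CIDEr over a toy two-image corpus.
from promptbench.metrics.cider.cider import Cider

gts = {  # reference captions per image id
    0: ["a cat sits on the mat", "there is a cat on a mat"],
    1: ["a dog runs in the park"],
}
res = {  # exactly one hypothesis caption per image id
    0: ["a cat is sitting on the mat"],
    1: ["a dog is running in a park"],
}

cider = Cider(n=4, sigma=6.0)
corpus_score, per_image_scores = cider.compute_score(gts=gts, res=res)
print(corpus_score, per_image_scores)
```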
19 | :param s: string : sentence to be converted into ngrams 20 | :param n: int : number of ngrams for which representation is calculated 21 | :return: term frequency vector for occuring ngrams 22 | """ 23 | words = s.split() 24 | counts = defaultdict(int) 25 | for k in range(1,n+1): 26 | for i in range(len(words)-k+1): 27 | ngram = tuple(words[i:i+k]) 28 | counts[ngram] += 1 29 | return counts 30 | 31 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 32 | '''Takes a list of reference sentences for a single segment 33 | and returns an object that encapsulates everything that BLEU 34 | needs to know about them. 35 | :param refs: list of string : reference sentences for some image 36 | :param n: int : number of ngrams for which (ngram) representation is calculated 37 | :return: result (list of dict) 38 | ''' 39 | return [precook(ref, n) for ref in refs] 40 | 41 | def cook_test(test, n=4): 42 | '''Takes a test sentence and returns an object that 43 | encapsulates everything that BLEU needs to know about it. 44 | :param test: list of string : hypothesis sentence for some image 45 | :param n: int : number of ngrams for which (ngram) representation is calculated 46 | :return: result (dict) 47 | ''' 48 | return precook(test, n, True) 49 | 50 | class CiderScorer(object): 51 | """CIDEr scorer. 52 | """ 53 | 54 | def copy(self): 55 | ''' copy the refs.''' 56 | new = CiderScorer(n=self.n) 57 | new.ctest = copy.copy(self.ctest) 58 | new.crefs = copy.copy(self.crefs) 59 | return new 60 | 61 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 62 | ''' singular instance ''' 63 | self.n = n 64 | self.sigma = sigma 65 | self.crefs = [] 66 | self.ctest = [] 67 | self.document_frequency = defaultdict(float) 68 | self.cook_append(test, refs) 69 | self.ref_len = None 70 | 71 | def cook_append(self, test, refs): 72 | '''called by constructor and __iadd__ to avoid creating new instances.''' 73 | 74 | if refs is not None: 75 | self.crefs.append(cook_refs(refs)) 76 | if test is not None: 77 | self.ctest.append(cook_test(test)) ## N.B.: -1 78 | else: 79 | self.ctest.append(None) # lens of crefs and ctest have to match 80 | 81 | def size(self): 82 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 83 | return len(self.crefs) 84 | 85 | def __iadd__(self, other): 86 | '''add an instance (e.g., from another sentence).''' 87 | 88 | if type(other) is tuple: 89 | ## avoid creating new CiderScorer instances 90 | self.cook_append(other[0], other[1]) 91 | else: 92 | self.ctest.extend(other.ctest) 93 | self.crefs.extend(other.crefs) 94 | 95 | return self 96 | def compute_doc_freq(self): 97 | ''' 98 | Compute term frequency for reference data. 99 | This will be used to compute idf (inverse document frequency later) 100 | The term frequency is stored in the object 101 | :return: None 102 | ''' 103 | for refs in self.crefs: 104 | # refs, k ref captions of one image 105 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 106 | self.document_frequency[ngram] += 1 107 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 108 | 109 | def compute_cider(self): 110 | def counts2vec(cnts): 111 | """ 112 | Function maps counts of ngram to vector of tfidf weights. 113 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 114 | The n-th entry of array denotes length of n-grams. 
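To make the scorer internals above concrete: `precook` turns a sentence into a counter over all 1..n-grams, and `cook_refs` simply applies it to every reference. A tiny sketch:

```python
# Sketch of the n-gram counting used inside CiderScorer.
from promptbench.metrics.cider.cider_scorer import precook, cook_refs

counts = precook("the cat sat on the mat", n=2)
print(counts[("the",)])        # 2, the unigram "the" occurs twice
print(counts[("the", "cat")])  # 1, the bigram occurs once

refs = cook_refs(["a cat on a mat", "the cat sat"], n=2)
print(len(refs))               # one counter per reference sentence
```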
115 | :param cnts: 116 | :return: vec (array of dict), norm (array of float), length (int) 117 | """ 118 | vec = [defaultdict(float) for _ in range(self.n)] 119 | length = 0 120 | norm = [0.0 for _ in range(self.n)] 121 | for (ngram, term_freq) in cnts.items(): 122 | # give word count 1 if it doesn't appear in reference corpus 123 | df = np.log(max(1.0, self.document_frequency[ngram])) 124 | # ngram index 125 | n = len(ngram)-1 126 | # tf (term_freq) * idf (precomputed idf) for n-grams 127 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 128 | # compute norm for the vector. the norm will be used for computing similarity 129 | norm[n] += pow(vec[n][ngram], 2) 130 | 131 | if n == 1: 132 | length += term_freq 133 | norm = [np.sqrt(n) for n in norm] 134 | return vec, norm, length 135 | 136 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 137 | ''' 138 | Compute the cosine similarity of two vectors. 139 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 140 | :param vec_ref: array of dictionary for vector corresponding to reference 141 | :param norm_hyp: array of float for vector corresponding to hypothesis 142 | :param norm_ref: array of float for vector corresponding to reference 143 | :param length_hyp: int containing length of hypothesis 144 | :param length_ref: int containing length of reference 145 | :return: array of score for each n-grams cosine similarity 146 | ''' 147 | delta = float(length_hyp - length_ref) 148 | # measure consine similarity 149 | val = np.array([0.0 for _ in range(self.n)]) 150 | for n in range(self.n): 151 | # ngram 152 | for (ngram,count) in vec_hyp[n].items(): 153 | # vrama91 : added clipping 154 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 155 | 156 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 157 | val[n] /= (norm_hyp[n]*norm_ref[n]) 158 | 159 | assert(not math.isnan(val[n])) 160 | # vrama91: added a length based gaussian penalty 161 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 162 | return val 163 | 164 | # compute log reference length 165 | self.ref_len = np.log(float(len(self.crefs))) 166 | if len(self.crefs) == 1: 167 | self.ref_len = 1 168 | scores = [] 169 | for test, refs in zip(self.ctest, self.crefs): 170 | # compute vector for test captions 171 | vec, norm, length = counts2vec(test) 172 | # compute vector for ref captions 173 | score = np.array([0.0 for _ in range(self.n)]) 174 | for ref in refs: 175 | vec_ref, norm_ref, length_ref = counts2vec(ref) 176 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 177 | # change by vrama91 - mean of ngram scores, instead of sum 178 | score_avg = np.mean(score) 179 | # divide by number of references 180 | score_avg /= len(refs) 181 | # multiply score by 10 182 | score_avg *= 10.0 183 | # append score of an image to the score list 184 | scores.append(score_avg) 185 | return scores 186 | 187 | def compute_score(self, option=None, verbose=0): 188 | # compute idf 189 | self.compute_doc_freq() 190 | # assert to check document frequency 191 | assert(len(self.ctest) >= max(self.document_frequency.values())) 192 | # compute cider score 193 | score = self.compute_cider() 194 | # debug 195 | # print score 196 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /promptbench/metrics/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT License. 3 | 4 | class Eval: 5 | """ 6 | A utility class for computing various evaluation metrics. 7 | 8 | This class provides static methods to compute metrics such as classification accuracy, SQuAD V2 F1 score, BLEU score, and math accuracy. 9 | 10 | Methods: 11 | -------- 12 | compute_cls_accuracy(preds, gts) 13 | Computes classification accuracy. 14 | compute_squad_v2_f1(preds, gts, dataset) 15 | Computes the F1 score for the SQuAD V2 dataset. 16 | compute_bleu(preds, gts) 17 | Computes the BLEU score for translation tasks. 18 | compute_math_accuracy(dataset, preds, gts) 19 | Computes accuracy for math dataset. 20 | """ 21 | 22 | @staticmethod 23 | def compute_cls_accuracy(preds, gts): 24 | """ 25 | Computes classification accuracy based on predictions and ground truths. 26 | 27 | Parameters: 28 | ----------- 29 | preds : list 30 | A list of predictions. 31 | gts : list 32 | A list of ground truths. 33 | 34 | Returns: 35 | -------- 36 | float 37 | The classification accuracy. 38 | """ 39 | try: 40 | preds = [str(pred).lower() for pred in preds] 41 | gts = [str(gt).lower() for gt in gts] 42 | except AttributeError: 43 | print("Something in either preds or gts can not be convert to a string.") 44 | 45 | if not isinstance(preds, list): 46 | preds = [preds] 47 | gts = [gts] 48 | 49 | return sum(a == b for a, b in zip(preds, gts)) / len(preds) 50 | 51 | @staticmethod 52 | def compute_squad_v2_f1(preds, gts, dataset): 53 | """ 54 | Computes the F1 score for the SQuAD V2 dataset. 55 | 56 | Parameters: 57 | ----------- 58 | preds : list 59 | A list of predictions. 60 | gts : list 61 | A list of ground truth IDs. 62 | dataset : list 63 | The dataset containing the SQuAD V2 data. 64 | 65 | Returns: 66 | -------- 67 | float 68 | The F1 score for the SQuAD V2 dataset. 69 | """ 70 | from .squad_v2.squad_v2 import SquadV2 71 | metric = SquadV2() 72 | 73 | model_output = [] 74 | for id, pred in zip(gts, preds): 75 | no_ans_prob = 1 if pred == "unanswerable" else 0 76 | pred = "" if pred == "unanswerable" else pred 77 | model_output.append({"id": id, "prediction_text": pred, "no_answer_probability": no_ans_prob}) 78 | 79 | references = [{"answers": data["answers"], "id": data["id"]} for data in dataset] 80 | 81 | score = metric.compute(predictions=model_output, references=references) 82 | 83 | return score["f1"] 84 | 85 | @staticmethod 86 | def compute_bleu(preds, gts): 87 | """ 88 | Computes the BLEU score for translation tasks. 89 | 90 | Parameters: 91 | ----------- 92 | preds : list 93 | A list of predictions. 94 | gts : list 95 | A list of ground truth translations. 96 | 97 | Returns: 98 | -------- 99 | float 100 | The BLEU score. 101 | """ 102 | from .bleu.bleu import Bleu 103 | metric = Bleu() 104 | results = metric.compute(predictions=preds, references=gts) 105 | return results['bleu'] 106 | 107 | @staticmethod 108 | def compute_math_accuracy(preds, gts): 109 | """ 110 | Computes accuracy for the 'math' dataset. 111 | 112 | Parameters: 113 | ----------- 114 | dataset : list 115 | The dataset containing math data. 116 | preds : list 117 | A list of predictions. 118 | gts : list 119 | A list of ground truths. 120 | 121 | Returns: 122 | -------- 123 | float 124 | The math accuracy. 
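A quick sketch of the static `Eval` helpers above: `compute_cls_accuracy` lower-cases both sides and compares element-wise, and `compute_math_accuracy` (whose body follows) first maps yes/no predictions onto True/False.

```python
# Sketch: the Eval helpers are plain static methods, no instantiation needed.
from promptbench.metrics import Eval

preds = ["Positive", "negative", "positive"]
labels = ["positive", "negative", "negative"]
print(Eval.compute_cls_accuracy(preds, labels))          # 0.666..., case-insensitive match

math_preds = ["yes", "no", "False"]
math_gts = [True, True, False]
print(Eval.compute_math_accuracy(math_preds, math_gts))  # "yes" -> true, "no" -> false
```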
125 | """ 126 | processed_preds = [] 127 | processed_gts = [] 128 | 129 | for pred, gt in zip(preds, gts): 130 | pred = "True" if pred.lower() == "yes" else ("False" if pred.lower() == "no" else pred) 131 | gt = str(gt).lower() 132 | processed_preds.append(pred.lower()) 133 | processed_gts.append(gt) 134 | 135 | return sum(a == b for a, b in zip(processed_preds, processed_gts)) / len(processed_gts) 136 | 137 | @staticmethod 138 | def compute_vqa_accuracy(preds, gts): 139 | """ 140 | Computes vqa accuracy for the VQAv2 dataset. 141 | 142 | Parameters: 143 | ----------- 144 | preds : list 145 | A list of predictions. 146 | gts : list 147 | A list of answers. 148 | 149 | Returns: 150 | -------- 151 | float 152 | The vqa accuracy. 153 | """ 154 | from .vqa.eval_vqa import VQAEval 155 | metric = VQAEval(n=3) 156 | dict_gts = {i: {"answers": val} for i, val in enumerate(gts)} 157 | dict_preds = {i: {"answer": val} for i, val in enumerate(preds)} 158 | score = metric.evaluate(dict_gts, dict_preds, list(range(len(preds)))) 159 | return score 160 | 161 | @staticmethod 162 | def compute_cider(preds, gts): 163 | """ 164 | Computes the CIDEr score for image captioning tasks. 165 | 166 | Parameters: 167 | ----------- 168 | preds : list 169 | A list of predictions. 170 | gts : list 171 | A list of ground truth captions. 172 | 173 | Returns: 174 | -------- 175 | float 176 | The CIDEr score. 177 | """ 178 | from .cider.cider import Cider 179 | metric = Cider() 180 | dict_gts = {i: val for i, val in enumerate(gts)} 181 | dict_preds = {i: [val] for i, val in enumerate(preds)} 182 | score, _ = metric.compute_score(gts=dict_gts, res=dict_preds) 183 | return score -------------------------------------------------------------------------------- /promptbench/metrics/squad_v2/squad_v2.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is copied from Huggingface evaluate library. 3 | """ 4 | 5 | # Copyright 2020 The HuggingFace Evaluate Authors. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | """ SQuAD v2 metric. """ 19 | 20 | import datasets 21 | 22 | import evaluate 23 | 24 | from .compute_score import ( 25 | apply_no_ans_threshold, 26 | find_all_best_thresh, 27 | get_raw_scores, 28 | make_eval_dict, 29 | make_qid_to_has_ans, 30 | merge_eval, 31 | ) 32 | 33 | 34 | _CITATION = """\ 35 | @inproceedings{Rajpurkar2016SQuAD10, 36 | title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text}, 37 | author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang}, 38 | booktitle={EMNLP}, 39 | year={2016} 40 | } 41 | """ 42 | 43 | _DESCRIPTION = """ 44 | This metric wrap the official scoring script for version 2 of the Stanford Question 45 | Answering Dataset (SQuAD). 
46 | 47 | Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by 48 | crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, 49 | from the corresponding reading passage, or the question might be unanswerable. 50 | 51 | SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions 52 | written adversarially by crowdworkers to look similar to answerable ones. 53 | To do well on SQuAD2.0, systems must not only answer questions when possible, but also 54 | determine when no answer is supported by the paragraph and abstain from answering. 55 | """ 56 | 57 | _KWARGS_DESCRIPTION = """ 58 | Computes SQuAD v2 scores (F1 and EM). 59 | Args: 60 | predictions: List of triple for question-answers to score with the following elements: 61 | - the question-answer 'id' field as given in the references (see below) 62 | - the text of the answer 63 | - the probability that the question has no answer 64 | references: List of question-answers dictionaries with the following key-values: 65 | - 'id': id of the question-answer pair (see above), 66 | - 'answers': a list of Dict {'text': text of the answer as a string} 67 | no_answer_threshold: float 68 | Probability threshold to decide that a question has no answer. 69 | Returns: 70 | 'exact': Exact match (the normalized answer exactly match the gold answer) 71 | 'f1': The F-score of predicted tokens versus the gold answer 72 | 'total': Number of score considered 73 | 'HasAns_exact': Exact match (the normalized answer exactly match the gold answer) 74 | 'HasAns_f1': The F-score of predicted tokens versus the gold answer 75 | 'HasAns_total': Number of score considered 76 | 'NoAns_exact': Exact match (the normalized answer exactly match the gold answer) 77 | 'NoAns_f1': The F-score of predicted tokens versus the gold answer 78 | 'NoAns_total': Number of score considered 79 | 'best_exact': Best exact match (with varying threshold) 80 | 'best_exact_thresh': No-answer probability threshold associated to the best exact match 81 | 'best_f1': Best F1 (with varying threshold) 82 | 'best_f1_thresh': No-answer probability threshold associated to the best F1 83 | Examples: 84 | 85 | >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22', 'no_answer_probability': 0.}] 86 | >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}] 87 | >>> squad_v2_metric = evaluate.load("squad_v2") 88 | >>> results = squad_v2_metric.compute(predictions=predictions, references=references) 89 | >>> print(results) 90 | {'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0} 91 | """ 92 | 93 | 94 | @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 95 | class SquadV2(evaluate.Metric): 96 | def _info(self): 97 | return evaluate.MetricInfo( 98 | description=_DESCRIPTION, 99 | citation=_CITATION, 100 | inputs_description=_KWARGS_DESCRIPTION, 101 | features=datasets.Features( 102 | { 103 | "predictions": { 104 | "id": datasets.Value("string"), 105 | "prediction_text": datasets.Value("string"), 106 | "no_answer_probability": datasets.Value("float32"), 107 | }, 108 | "references": { 109 | "id": datasets.Value("string"), 110 | "answers": datasets.features.Sequence( 111 | {"text": datasets.Value("string"), "answer_start": 
datasets.Value("int32")} 112 | ), 113 | }, 114 | } 115 | ), 116 | codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], 117 | reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], 118 | ) 119 | 120 | def _compute(self, predictions, references, no_answer_threshold=1.0): 121 | no_answer_probabilities = {p["id"]: p["no_answer_probability"] for p in predictions} 122 | dataset = [{"paragraphs": [{"qas": references}]}] 123 | predictions = {p["id"]: p["prediction_text"] for p in predictions} 124 | 125 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 126 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 127 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 128 | 129 | exact_raw, f1_raw = get_raw_scores(dataset, predictions) 130 | exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) 131 | f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) 132 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 133 | 134 | if has_ans_qids: 135 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 136 | merge_eval(out_eval, has_ans_eval, "HasAns") 137 | if no_ans_qids: 138 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 139 | merge_eval(out_eval, no_ans_eval, "NoAns") 140 | find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans) 141 | return dict(out_eval) -------------------------------------------------------------------------------- /promptbench/mpa/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/promptbench/fcda538bd779ad11612818e0645a387a462b5c3b/promptbench/mpa/.DS_Store -------------------------------------------------------------------------------- /promptbench/mpa/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .agent import * 5 | from .dataprocess import * 6 | from .mpa_prompts import * 7 | 8 | 9 | MPA_DATASETS = [ 10 | "mmlu", 11 | "arc-challenge", 12 | "gsm8k", 13 | "formal_fallacies", 14 | "temporal_sequences", 15 | "object_counting", 16 | ] -------------------------------------------------------------------------------- /promptbench/mpa/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | from ..models import LLMModel 5 | import copy 6 | from .dataprocess import ParaphraserInputProcess, ParaphraserOutputProcess, EvaluatorInputProcess, EvaluatorOutputProcess 7 | 8 | 9 | class ParaphraserAgent(object): 10 | def __init__(self, model: LLMModel, prompt: str, input_func: ParaphraserInputProcess, output_func: ParaphraserOutputProcess): 11 | self.model = model 12 | self.prompt = prompt 13 | self.input_func = input_func 14 | self.output_func = output_func 15 | 16 | def __call__(self, data: dict) -> str: 17 | input_text = self.input_func(self.prompt, data) 18 | 19 | while True: 20 | try_count = 1 21 | try: 22 | output_text = self.model(input_text) 23 | except Exception as e: 24 | print("ParaphraserAgent output error!") 25 | print(e) 26 | import time 27 | time.sleep(30*try_count) 28 | try_count += 1 29 | continue 30 | 31 | try: 32 | paraphrased_data = self.output_func(output_text, data) 33 | break 34 | except Exception as e: 35 | print("ParaphraserAgent process output error!") 36 | print(e) 37 | continue 38 | 39 | return paraphrased_data 40 | 41 | 42 | class EvaluatorAgent(object): 43 | def __init__(self, model: LLMModel, prompt: str, input_func: EvaluatorInputProcess, output_func: EvaluatorOutputProcess): 44 | self.model = model 45 | self.prompt = prompt 46 | self.input_func = input_func 47 | self.output_func = output_func 48 | 49 | def __call__(self, original_data, paraphraserd_data): 50 | input_text = self.input_func(self.prompt, original_data, paraphraserd_data) 51 | 52 | while True: 53 | try_count = 1 54 | try: 55 | output_text = self.model(input_text) 56 | except Exception as e: 57 | print("EvaluatorAgent output error!") 58 | print(e) 59 | import time 60 | time.sleep(30*try_count) 61 | try_count += 1 62 | continue 63 | 64 | try: 65 | valid = self.output_func(output_text) 66 | break 67 | except Exception as e: 68 | print("EvaluatorAgent process output error!") 69 | print(e) 70 | continue 71 | 72 | return valid 73 | 74 | 75 | class Pipeline: 76 | def __init__(self, paraphraser_agent, eval_agent, iters=1, retry=5): 77 | """ 78 | Initializes the Pipeline class with paraphraser and evaluator agents, post evaluation action, and number of iterations. 79 | :param paraphraser_agent: The agent responsible for rephrasing the data. 80 | :param eval_agent: The agent responsible for evaluating the paraphrased data. 81 | :param iters: The number of iterations to paraphrase and evaluate. 82 | """ 83 | self.paraphraser_agent = paraphraser_agent 84 | self.eval_agent = eval_agent 85 | self.iters = iters 86 | self.retry = retry 87 | 88 | def __call__(self, original_data): 89 | """ 90 | Processes the data through the pipeline, paraphrasing and evaluating according to the specified logic. 91 | :param data: The data to be paraphrased. 92 | :return: The paraphrased data list. 
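A wiring sketch for the MPA loop above: a `ParaphraserAgent` rewrites a sample, an `EvaluatorAgent` checks the rewrite, and `Pipeline` retries until the check passes. The model identifier and the two prompt strings are placeholders (the shipped prompts live in `mpa_prompts.py`), the processor classes come from `mpa/dataprocess.py` shown next, and the `{question}`/`{paraphrased}` placeholder style is an assumption about `InputProcess.basic_format`.

```python
# Hypothetical wiring; prompts and model name are placeholders, not the shipped MPA prompts.
from promptbench.models import LLMModel
from promptbench.mpa.agent import ParaphraserAgent, EvaluatorAgent, Pipeline
from promptbench.mpa.dataprocess import (
    ParaphraserBasicInputProcess,
    ParaphraserQuestionOutputProcess,
    EvaluatorMMLUQuestionInputProcess,
    EvaluatorBasicOutputProcess,
)

model = LLMModel("gpt-3.5-turbo")  # assumed constructor call

paraphraser = ParaphraserAgent(
    model,
    prompt="Paraphrase the question below and wrap the result in <<<...>>>.\n{question}",
    input_func=ParaphraserBasicInputProcess(),
    output_func=ParaphraserQuestionOutputProcess(),  # expects the <<<...>>> wrapper
)
evaluator = EvaluatorAgent(
    model,
    prompt="Original: {question}\nParaphrase: {paraphrased}\nIs the meaning preserved? Answer <<<yes>>> or <<<no>>>.",
    input_func=EvaluatorMMLUQuestionInputProcess(),
    output_func=EvaluatorBasicOutputProcess(),  # maps <<<yes>>>/<<<no>>> to True/False
)

pipeline = Pipeline(paraphraser, evaluator, iters=1, retry=5)
versions = pipeline({"question": "What is the capital of France?"})
# versions[0] is the original sample, versions[1] the accepted paraphrase.
```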
93 | """ 94 | data = copy.deepcopy(original_data) 95 | paraphrased_data_list = [] 96 | paraphrased_data_list.append(copy.deepcopy(original_data)) 97 | 98 | for _ in range(self.iters): 99 | paraphrased = False 100 | retry = self.retry 101 | while not paraphrased and retry > 0: 102 | copied_data = copy.deepcopy(data) 103 | retry -= 1 104 | self.paraphraser_agent(data) 105 | valid = self.eval_agent(original_data, data) 106 | 107 | if not valid: 108 | data = copied_data 109 | continue 110 | else: 111 | paraphrased = True 112 | 113 | if not paraphrased: 114 | print("Paraphraser failed to paraphrase the data!") 115 | paraphrased_data_list.append(copy.deepcopy(paraphrased_data_list[-1])) 116 | else: 117 | print("Paraphraser successfully paraphrased the data!") 118 | paraphrased_data_list.append(copy.deepcopy(data)) 119 | 120 | return paraphrased_data_list 121 | -------------------------------------------------------------------------------- /promptbench/mpa/dataprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import re 5 | import random 6 | from ..utils.dataprocess import InputProcess, OutputProcess 7 | 8 | class ParaphraserInputProcess: 9 | def __call__(self, prompt: str, data: dict) -> str: 10 | raise NotImplementedError 11 | 12 | 13 | class ParaphraserOutputProcess(): 14 | def __call__(self, text: str, data: dict) -> dict: 15 | raise NotImplementedError 16 | 17 | 18 | class EvaluatorInputProcess: 19 | def __call__(self, prompt: str, original_data: dict, paraphraserd_data: dict) -> str: 20 | raise NotImplementedError 21 | 22 | 23 | class EvaluatorOutputProcess(): 24 | def __call__(self, text: str) -> bool: 25 | raise NotImplementedError 26 | 27 | 28 | class ChoicePermuter: 29 | @staticmethod 30 | def permute(choices_str: str, correct_answer: str) -> tuple: 31 | # Use regular expression to split the string into individual choices 32 | choices = re.split(r'\n(?=[A-E]:)', choices_str) 33 | 34 | # Extract the label and content of each choice 35 | labels = [choice.split(':', 1)[0] for choice in choices] 36 | contents = [choice.split(':', 1)[1].strip() for choice in choices] 37 | 38 | # Shuffle the contents 39 | random.shuffle(contents) 40 | 41 | # Reassign the shuffled contents to the original labels 42 | permuted_choices = [f"{labels[i]}: {contents[i]}" for i in range(len(contents))] 43 | 44 | # Find the new label for the correct answer 45 | correct_content = choices[ord(correct_answer) - ord('A')].split(':', 1)[1].strip() 46 | new_correct_answer = labels[contents.index(correct_content)] 47 | 48 | return '\n'.join(permuted_choices), new_correct_answer 49 | 50 | 51 | class ParaphraserBasicInputProcess(ParaphraserInputProcess): 52 | def __call__(self, prompt: str, data: dict) -> str: 53 | return InputProcess.basic_format(prompt, data) 54 | 55 | 56 | class ParaphraserQuestionOutputProcess(ParaphraserOutputProcess): 57 | def __call__(self, text: str, data: dict) -> dict: 58 | output = re.findall("<<<(.*?)>>>", text, re.DOTALL) 59 | if len(output) == 1: 60 | data["question"] = output[0] 61 | return data 62 | else: 63 | raise ValueError("Invalid output format") 64 | 65 | 66 | class ParaphraserChoicesOutputProcess(ParaphraserOutputProcess): 67 | def __call__(self, text: str, data: dict) -> dict: 68 | output = re.findall("<<<(.*?)>>>", text, re.DOTALL) 69 | if len(output) == 1: 70 | choices = output[0] 71 | data["choices"] = choices 72 | return data 73 | else: 74 | raise 
ValueError("Invalid output format") 75 | 76 | 77 | class EvaluatorMMLUQuestionInputProcess(EvaluatorInputProcess): 78 | def __call__(self, prompt: str, original_data: dict, paraphraserd_data: dict) -> str: 79 | data = {} 80 | data["question"] = original_data["question"] 81 | data["paraphrased"] = paraphraserd_data["question"] 82 | return InputProcess.basic_format(prompt, data) 83 | 84 | 85 | class EvaluatorGSM8KQuestionInputProcess(EvaluatorInputProcess): 86 | def __call__(self, prompt: str, original_data: dict, paraphraserd_data: dict) -> str: 87 | data = {} 88 | data["question"] = original_data["question"] 89 | data["paraphrased"] = paraphraserd_data["question"] 90 | data["answer"] = original_data["answer"] 91 | return InputProcess.basic_format(prompt, data) 92 | 93 | 94 | class EvaluatorMMLUParaphrasedChoicesInputProcess(EvaluatorInputProcess): 95 | def __call__(self, prompt: str, original_data: dict, paraphraserd_data: dict) -> str: 96 | data = {} 97 | data["question"] = original_data["question"] 98 | data["choices"] = original_data["choices"] 99 | data["paraphrased"] = paraphraserd_data["choices"] 100 | data["answer"] = original_data["answer"] 101 | return InputProcess.basic_format(prompt, data) 102 | 103 | 104 | class EvaluatorMMLUNewChoiceInputProcess(EvaluatorInputProcess): 105 | def __call__(self, prompt: str, original_data: dict, paraphraserd_data: dict) -> str: 106 | data = {} 107 | data["question"] = original_data["question"] 108 | data["choices"] = original_data["choices"] 109 | data["new_choice"] = paraphraserd_data["choices"] 110 | data["answer"] = original_data["answer"] 111 | return InputProcess.basic_format(prompt, data) 112 | 113 | 114 | class EvaluatorBasicOutputProcess(EvaluatorOutputProcess): 115 | def __call__(self, text: str) -> bool: 116 | output = re.findall("<<<(.*?)>>>", text, re.DOTALL) 117 | if len(output) == 1: 118 | if output[0].lower() == "yes": 119 | return True 120 | elif output[0].lower() == "no": 121 | return False 122 | else: 123 | raise ValueError("Invalid output format") 124 | else: 125 | raise ValueError("Invalid output format") -------------------------------------------------------------------------------- /promptbench/prompt_attack/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .attack import * 5 | 6 | MNLI_LABEL = ['entailment', 'neutral', 'contradiction', 7 | 'entailment\'', 'neutral\'', 'contradiction\''] 8 | EQ_LABEL = ['equivalent', 'not_equivalent', 'equivalent\'', 'not_equivalent\''] 9 | ENTAIL_LABEL = ['entailment', 'not_entailment', 'entailment\'', 10 | 'not_entailment\'', '0', '1', '0\'', '1\''] 11 | 12 | LABEL_SET = { 13 | # 'positive\'', 'negative\'' is used for label constraint due to a bug of TextAttack repo. 14 | 'sst2': ['positive', 'negative', 'positive\'', 'negative\'', '0', '1', '0\'', '1\''], 15 | 'mnli': MNLI_LABEL, 16 | 'mnli_mismatched': MNLI_LABEL, 17 | 'mnli_matched': MNLI_LABEL, 18 | 'qqp': EQ_LABEL, 19 | 'qnli': ENTAIL_LABEL, 20 | 'rte': ENTAIL_LABEL, 21 | 'cola': ['unacceptable', 'acceptable', 'unacceptable\'', 'acceptable\''], 22 | 'mrpc': EQ_LABEL, 23 | 'wnli': ENTAIL_LABEL, 24 | 'mmlu': ['A', 'B', 'C', 'D', 'A\'', 'B\'', 'C\'', 'D\'', 'a', 'b', 'c', 'd', 'a\'', 'b\'', 'c\'', 'd\''], 25 | # do not change the word 'nothing' in prompts. 
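A small sketch of the `ChoicePermuter` helper defined in `mpa/dataprocess.py` above: it shuffles the contents of the lettered choices while keeping the labels in place, and reports the new label of the correct answer.

```python
# Sketch: shuffling multiple-choice contents while tracking the gold label.
import random
from promptbench.mpa.dataprocess import ChoicePermuter

random.seed(0)  # only to make this example reproducible

choices = "A: Paris\nB: London\nC: Berlin\nD: Madrid"
new_choices, new_answer = ChoicePermuter.permute(choices, correct_answer="A")
print(new_choices)
print("gold label is now:", new_answer)
```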
26 | 'squad_v2': ['unanswerable', 'unanswerable\''], 27 | 'iwslt': ['translate', 'translate\''], 28 | 'un_multi': ['translate', 'translate\''], 29 | 'math': ['math', 'math\''], 30 | 'bool_logic': ['True', 'False', 'True\'', 'False\'', "bool", "boolean", "bool\'", "boolean\'"], 31 | 'valid_parentheses': ['Valid', 'Invalid', 'Valid\'', 'Invalid\'', 'matched', 'matched\'', 'valid', 'invalid', 'valid\'', 'invalid\''], 32 | } 33 | 34 | attack_config = { 35 | "goal_function": { 36 | "query_budget": float("inf"), 37 | }, 38 | 39 | "textfooler": { 40 | "max_candidates": 50, 41 | "min_word_cos_sim": 0.6, 42 | "min_sentence_cos_sim": 0.840845057, 43 | }, 44 | 45 | "textbugger": { 46 | "max_candidates": 5, 47 | "min_sentence_cos_sim": 0.8, 48 | }, 49 | 50 | "deepwordbug": { 51 | "levenshtein_edit_distance" : 30, 52 | }, 53 | 54 | "bertattack": { 55 | "max_candidates": 48, 56 | "max_word_perturbed_percent": 1, 57 | "min_sentence_cos_sim": 0.8, 58 | }, 59 | 60 | "checklist": { 61 | "max_candidates": 5, 62 | }, 63 | 64 | "stresstest": { 65 | "max_candidates": 5, 66 | } 67 | 68 | } 69 | 70 | 71 | -------------------------------------------------------------------------------- /promptbench/prompt_attack/label_constraint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from textattack.constraints import PreTransformationConstraint 5 | 6 | class LabelConstraint(PreTransformationConstraint): 7 | """ 8 | A constraint that does not allow to attack the labels (or any words that is important for tasks) in the prompt. 9 | """ 10 | 11 | def __init__(self, labels=[]): 12 | self.labels = [label.lower() for label in labels] 13 | 14 | def _get_modifiable_indices(self, current_text): 15 | modifiable_indices = set() 16 | modifiable_words = [] 17 | for i, word in enumerate(current_text.words): 18 | if str(word).lower() not in self.labels: 19 | modifiable_words.append(word) 20 | modifiable_indices.add(i) 21 | print("--------------------------------------------------") 22 | print("Modifiable words: ", modifiable_words) 23 | print("--------------------------------------------------\n") 24 | return modifiable_indices 25 | 26 | def check_compatibility(self, transformation): 27 | """ 28 | It is always true. 29 | """ 30 | return True -------------------------------------------------------------------------------- /promptbench/prompt_attack/search.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from textattack.search_methods import SearchMethod 5 | 6 | class BruteForceSearch(SearchMethod): 7 | def perform_search(self, initial_result): 8 | text = initial_result.attacked_text 9 | transformed_text_candidates = self.get_transformations(text, original_text=text) 10 | results, _ = self.get_goal_results(transformed_text_candidates) 11 | results = sorted(results, key=lambda x: -x.score) 12 | 13 | return results[0] 14 | 15 | @property 16 | def is_black_box(self): 17 | return True -------------------------------------------------------------------------------- /promptbench/prompt_attack/transformations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
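To illustrate the `LabelConstraint` defined above: it is a TextAttack pre-transformation constraint that removes protected words (typically the label names from `LABEL_SET`) from the set of attackable word indices. A standalone sketch, assuming `textattack` is installed:

```python
# Sketch: LabelConstraint keeps label words out of the attack surface.
from textattack.shared import AttackedText
from promptbench.prompt_attack.label_constraint import LabelConstraint

prompt = AttackedText("Classify the sentence as positive or negative.")
constraint = LabelConstraint(labels=["positive", "negative", "positive'", "negative'"])

# Returns the indices of words the attack is still allowed to modify
# (and prints them, as implemented above).
modifiable = constraint._get_modifiable_indices(prompt)
print(sorted(modifiable))
```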
3 | 4 | from textattack.transformations import Transformation 5 | 6 | class CheckListTransformation(Transformation): 7 | 8 | def generate_random_sequences(num, len): 9 | seqs = [] 10 | import random 11 | import string 12 | 13 | for _ in range(num): 14 | seq = ''.join(random.choices(string.ascii_uppercase + string.ascii_lowercase + string.digits, k=len)) 15 | seqs.append(seq) 16 | 17 | return seqs 18 | 19 | def _get_transformations(self, current_text, indices_to_modify): 20 | 21 | # rand_seqs = self.generate_random_sequences(50, 10) 22 | 23 | rand_seqs = ['d6ZQ3u0GBQ', 'vTAjHynoIG', 'OB4KVJzIft', 'LkF0FZxMZ4', 'iia2yL9Uzm', 'CuVpbbkC4c', 24 | 'w52rwgo0Av', 'Vq3aBzuZcD', 'hXLpw3bbiw', 'RcRneWo6Iv', 'S6oUV5E54P', 'xikCjkMydH', 25 | 'MQnugHcaoy', 'Q47Jmd4lMV', '9vGXgnbWB8', 'IhuBIhoPGc', '5yWbBXztUY', 'AMsRIKZniY', 26 | 'EAB4KP2NVY', '9Q3S1F94fE', 'b74X5IVXQY', 'SFPCqpiClT', 'bhrRSokrfa', 'YHQiDfFhw4', 27 | 'BWmsLx8xOA', 'PDCGfz6DL9', 'yh912BU8T0', 'ofOQXLxiW4', 'Cev0s49fbe', 'rzu98kF2KF', 28 | 'zexKUTWJYG', '5XeflW1ZJc', 'is77sOXAu8', 'XStsD2pUzu', 'fwhUZUQzXW', 'Pbl3tYuoRo', 29 | 'MSAddJ4D2a', 'mzjVogkRhM', 'Kw6nrs57gH', 'ItGDrrA1Sc', 'KjPJJ2a7RB', 'mOJ9nEwT4f', 30 | 'ofw9fEkN5R', 'njCuciQJpB', '6a4Yn3RGVc', 'SvAp8RlOFn', 'g0vBZf3tQC', 'zq0DcZ5dnI', 31 | 'lf8wBa2yEm', 'lWJoGGRdjv'] 32 | 33 | transformed_texts = [] 34 | for rand_seq in rand_seqs: 35 | transformed_texts.append(current_text.insert_text_after_word_index(index=len(current_text.words)-1, text=rand_seq)) 36 | return transformed_texts 37 | 38 | 39 | class StressTestTransformation(Transformation): 40 | def _get_transformations(self, current_text, indices_to_modify): 41 | texts = [" and true is true ", " and false is not true ", " and true is true "*5] 42 | transformed_texts = [] 43 | for text in texts: 44 | transformed_texts.append(current_text.insert_text_after_word_index(index=len(current_text.words)-1, text=text)) 45 | # transformed_texts.append(current_text.insert_text_after_word_index(index=0, text=text)) 46 | 47 | return transformed_texts -------------------------------------------------------------------------------- /promptbench/prompt_engineering/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
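A quick sketch of what the transformations above produce when applied to a prompt: `StressTestTransformation` appends fixed distractor clauses such as " and true is true ", while `CheckListTransformation` appends one of its 50 precomputed random strings. The direct call to `_get_transformations` is only for illustration; inside PromptBench these are wired into a TextAttack attack.

```python
# Sketch: applying StressTestTransformation directly to an AttackedText.
from textattack.shared import AttackedText
from promptbench.prompt_attack.transformations import StressTestTransformation

text = AttackedText("Determine whether the statement is true or false")
transformation = StressTestTransformation()

# One transformed AttackedText per distractor suffix.
for variant in transformation._get_transformations(text, indices_to_modify=None):
    print(variant.text)
```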
3 | 4 | from tqdm import tqdm 5 | import re 6 | 7 | from .base import Base 8 | from .least_to_most import LeastToMost 9 | from .generated_knowledge import GeneratedKnowledge 10 | from .chain_of_thought import ZSCoT, CoT 11 | from .expert_prompting import ExpertPrompting 12 | from .emotion_prompt import EmotionPrompt 13 | 14 | from ..metrics import Eval 15 | 16 | SUPPORTED_METHODS = ['CoT', 'ZSCoT', 'least_to_most', 'generated_knowledge', 'expert_prompting', 'emotion_prompt', 'baseline'] 17 | 18 | METHOD_MAP = { 19 | 'CoT': CoT, 20 | 'ZSCoT': ZSCoT, 21 | 'least_to_most': LeastToMost, 22 | 'generated_knowledge': GeneratedKnowledge, 23 | 'expert_prompting': ExpertPrompting, 24 | 'emotion_prompt': EmotionPrompt, 25 | 'baseline': Base, 26 | } 27 | 28 | METHOD_SUPPORT_DATASET = { 29 | 'CoT': ['gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking'], 30 | 'ZSCoT': ['gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking'], 31 | 'expert_prompting': ['gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking'], 32 | 'emotion_prompt': ['gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking'], 33 | 'least_to_most': ['gsm8k', 'last_letter_concat'], 34 | 'generated_knowledge': ['csqa', 'numersense', 'qasc'], 35 | 'baseline': ['gsm8k', 'csqa', 'bigbench_date', 'bigbench_object_tracking', 'last_letter_concat', 'numersense', 'qasc'], 36 | } 37 | 38 | # Model: GPT3.5, GPT4, Llama7b-chat, Llama13b-chat, llama70b-chat 39 | 40 | class PEMethod(object): 41 | """ 42 | A class that provides an interface for various methods in prompt engineering. 43 | It supports method creation and inference based on the method name. 44 | """ 45 | 46 | def __init__(self, **kwargs): 47 | self.method = kwargs.get('method') 48 | self.infer_method = self.create_method(**kwargs) 49 | 50 | def create_method(self, **kwargs): 51 | """Creates and returns the appropriate method based on the method name.""" 52 | 53 | # Get the method class based on the method name and instantiate it 54 | method_class = METHOD_MAP.get(self.method) 55 | if method_class: 56 | return method_class(**kwargs) 57 | else: 58 | raise ValueError("The method is not supported!") 59 | 60 | @staticmethod 61 | def method_list(): 62 | """Returns a list of supported methods.""" 63 | return METHOD_MAP.keys() 64 | 65 | def test(self, dataset, model, num_samples=None): 66 | """Tests the method on the given dataset and returns the accuracy.""" 67 | preds = [] 68 | labels = [] 69 | for i, data in enumerate(tqdm(dataset)): 70 | if num_samples and i >= num_samples: 71 | break 72 | 73 | label = data['label'] 74 | labels.append(label) 75 | 76 | input_text = data['content'] 77 | output = self.infer_method.query(input_text, model) 78 | res = re.findall(r'##(.*)', output) 79 | pred = res[0] if res else output 80 | pred = dataset.extract_answer(pred) # FIXME: the dataset class information is lost after this slicing operation 81 | preds.append(pred) 82 | 83 | score = Eval.compute_cls_accuracy(preds, labels) 84 | return score 85 | 86 | def __call__(self, input_text, model): 87 | """Calls the method to perform inference.""" 88 | return self.infer_method.query(input_text, model) 89 | -------------------------------------------------------------------------------- /promptbench/prompt_engineering/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
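A usage sketch for the `PEMethod` interface above. The dataset and model loading calls (`pb.DatasetLoader.load_dataset`, `pb.LLMModel`) and the package-level export of `PEMethod` are assumptions based on the rest of the library; the keyword arguments are the ones consumed by the individual methods (`method`, `dataset`, `verbose`, plus `prompt_id` for `emotion_prompt`).

```python
# Hypothetical end-to-end sketch; loader and model calls are assumed.
import promptbench as pb

dataset = pb.DatasetLoader.load_dataset("gsm8k")          # assumed loader API
model = pb.LLMModel("gpt-3.5-turbo", max_new_tokens=256)  # assumed model API

method = pb.PEMethod(
    method="ZSCoT",    # zero-shot chain of thought
    dataset="gsm8k",   # lets the method pick its output-format hint
    verbose=False,
)
accuracy = method.test(dataset, model, num_samples=20)
print(accuracy)
```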
3 | 4 | class Base: 5 | def __init__(self, **kwargs): 6 | self.dataset_name = kwargs.get('dataset') 7 | self.verbose = kwargs.get('verbose', False) 8 | 9 | if self.dataset_name == "gsm8k": 10 | self.output_range = "arabic numerals" 11 | elif self.dataset_name == "csqa": 12 | self.output_range = "among A through E" 13 | elif self.dataset_name == "bigbench_date": 14 | self.output_range = "among A through F" 15 | elif self.dataset_name == "bigbench_object_tracking": 16 | self.output_range = "among A through C" 17 | elif self.dataset_name == "qasc": 18 | self.output_range = "among A through H" 19 | elif self.dataset_name == "numersense": 20 | self.output_range = "numbers expressed in English words, e.g. 'one', 'two', 'three', ..." 21 | elif self.dataset_name == "last_letter_concat": 22 | self.output_range = "English letter combinations, e.g. 'afsa', 'abgsa', ..." 23 | else: 24 | self.output_range = "No format restrictions" 25 | 26 | 27 | def query(self, input_text, model): 28 | instr_get_answer = input_text + '\n' + \ 29 | f'Please output your answer at the end as ##' 30 | prompt_get_answer = model.convert_text_to_prompt(instr_get_answer, 'user') 31 | 32 | answer = model(prompt_get_answer) 33 | 34 | # print(instr_get_answer) 35 | # print(answer) 36 | return answer 37 | -------------------------------------------------------------------------------- /promptbench/prompt_engineering/chain_of_thought.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from abc import ABC, abstractmethod 5 | 6 | from .base import Base 7 | from ..prompts.method_oriented import get_prompt 8 | 9 | class BaseCoT(Base): 10 | """ 11 | A base class for implementing Chain of Thought (CoT) reasoning in models. 12 | 13 | This class serves as a foundational component for models that employ the CoT approach to process and answer queries. It sets up the basic structure and methods that are common across different CoT implementations. 14 | 15 | Attributes: 16 | ----------- 17 | cot_trigger : str 18 | A string prompt that activates the Chain of Thought reasoning process in the model. 19 | 20 | Methods: 21 | -------- 22 | __init__(**kwargs) 23 | Initializes the BaseCoT instance with specified keyword arguments. 24 | query(input_text, model) 25 | An abstract method to be implemented by subclasses, defining how a query is processed and answered by the model. 26 | """ 27 | def __init__(self, **kwargs): 28 | super().__init__(**kwargs) 29 | self.cot_trigger= get_prompt(['chain_of_thought', 'cot_trigger']) 30 | 31 | @abstractmethod 32 | def query(self, input_text, model): 33 | pass 34 | 35 | 36 | class ZSCoT(BaseCoT): 37 | """ 38 | A class for implementing Zero-Shot Chain of Thought (ZSCoT) reasoning. 39 | 40 | This class is designed for situations where no prior examples (zero-shot) are provided to the model. It utilizes the base CoT approach and extends it to work in a zero-shot learning environment. 41 | 42 | Methods: 43 | -------- 44 | __init__(**kwargs) 45 | Initializes the ZSCoT instance with specified keyword arguments. 46 | query(input_text, model) 47 | Processes the input text and uses the model to generate a response using zero-shot CoT reasoning. The method constructs a prompt sequence, queries the model, and returns the model's answer. 
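All of the prompt-engineering methods above ask the model to put its final answer after a `##` marker, and `PEMethod.test` recovers it with a regular expression; the convention boils down to the following:

```python
# Sketch of the "##" answer-extraction convention used by PEMethod.test.
import re

model_output = "Let's think step by step. 4 apples plus 3 apples is 7 apples. ##7"

matches = re.findall(r"##(.*)", model_output)
answer = matches[0] if matches else model_output
print(answer)  # "7"
```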
48 | 49 | Paper Link: https://arxiv.org/pdf/2205.11916.pdf 50 | """ 51 | def __init__(self, **kwargs): 52 | super().__init__(**kwargs) 53 | 54 | def query(self, input_text, model): 55 | prompt_question = model.convert_text_to_prompt(input_text, 'user') 56 | 57 | instr_get_answer = self.cot_trigger + '\n' + \ 58 | f'Please output your answer at the end as ##' 59 | prompt_get_answer = model.convert_text_to_prompt(instr_get_answer, 'assistant') 60 | 61 | prompt_get_answer = model.concat_prompts([prompt_question, prompt_get_answer]) 62 | 63 | answer = model(prompt_get_answer) 64 | 65 | if self.verbose: 66 | print(prompt_get_answer) 67 | print(answer) 68 | 69 | return answer 70 | 71 | 72 | class CoT(BaseCoT): 73 | """ 74 | 75 | A class for implementing Chain of Thought (CoT) reasoning with few-shot examples. 76 | 77 | This class enhances the base CoT approach by incorporating few-shot learning, where a small number of example cases are used to guide the model's reasoning process. 78 | 79 | Attributes: 80 | ----------- 81 | few_shot_examples : str 82 | A string containing few-shot examples relevant to the dataset_name, aiding the model in understanding and responding to queries. 83 | 84 | Methods: 85 | -------- 86 | __init__(**kwargs) 87 | Initializes the CoT instance with specified keyword arguments and loads few-shot examples. 88 | query(input_text, model) 89 | Processes the input text and uses the model to generate a response. The method constructs a sequence of prompts, including few-shot examples, to guide the model's reasoning before querying it for an answer. 90 | 91 | Paper Link: https://arxiv.org/pdf/2201.11903.pdf 92 | """ 93 | def __init__(self, **kwargs): 94 | super().__init__(**kwargs) 95 | self.few_shot_examples = get_prompt(['chain_of_thought', self.dataset_name]) 96 | 97 | def query(self, input_text, model): 98 | instr_question = self.few_shot_examples + '\n' + 'Q: ' + input_text + '\n' + 'A:' 99 | prompt_question = model.convert_text_to_prompt(instr_question, 'user') 100 | instr_get_answer = self.cot_trigger + '\n' + \ 101 | f'Please output your answer at the end as ##' 102 | prompt_get_answer = model.convert_text_to_prompt(instr_get_answer, 'assistant') 103 | prompt_get_answer = model.concat_prompts([prompt_question, prompt_get_answer]) 104 | 105 | answer = model(prompt_get_answer) 106 | 107 | if self.verbose: 108 | print(prompt_get_answer) 109 | print(answer) 110 | 111 | return answer 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /promptbench/prompt_engineering/emotion_prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .base import Base 5 | from ..prompts.method_oriented import get_prompt 6 | 7 | 8 | class EmotionPrompt(Base): 9 | """ 10 | A class for handling emotion-based prompts in a conversational AI model. 11 | 12 | This class specializes in generating responses for prompts that add emotional stimuli. It leverages a predefined emotion prompt, tailored to the specific needs of the query, to guide the model's response generation. 13 | 14 | Attributes: 15 | ----------- 16 | prompt_id : int 17 | An identifier for selecting a specific emotion prompt from a collection of predefined prompts. 18 | emotion_prompt : str 19 | The actual emotion-based prompt text, retrieved based on the prompt_id, used to assist the model in generating emotion-aware responses.
20 | 21 | Methods: 22 | -------- 23 | __init__(**kwargs) 24 | Initializes the EmotionPrompt instance with specific keyword arguments, including setting the prompt_id and retrieving the corresponding emotion prompt. 25 | For example: "This is very important to my career." 26 | query(input_text, model) 27 | Processes the input text by appending it with the emotion prompt. It then instructs the model to generate a response that is cognizant of the emotional context provided. The response is formatted with a specific answer notation. 28 | 29 | Paper Link: https://arxiv.org/pdf/2307.11760v3.pdf 30 | """ 31 | def __init__(self, **kwargs): 32 | super().__init__(**kwargs) 33 | 34 | self.prompt_id = int(kwargs.get('prompt_id')) 35 | 36 | self.emotion_prompts = get_prompt(['emotion_prompt', 'prompts']) 37 | self.emotion_prompt = self.emotion_prompts[self.prompt_id] 38 | 39 | def query(self, input_text, model): 40 | instr_get_answer = input_text + '\n' + self.emotion_prompt + '\n' + \ 41 | f'Please output your answer at the end as ##' 42 | prompt_get_answer = model.convert_text_to_prompt(instr_get_answer, 'user') 43 | 44 | answer = model(prompt_get_answer) 45 | 46 | if self.verbose: 47 | print(prompt_get_answer) 48 | print(answer) 49 | 50 | return answer -------------------------------------------------------------------------------- /promptbench/prompt_engineering/expert_prompting.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .base import Base 5 | from ..prompts.method_oriented import get_prompt 6 | 7 | 8 | class ExpertPrompting(Base): 9 | """ 10 | A class designed for generating and utilizing expert-level prompts in AI models. 11 | 12 | This class focuses on creating advanced, expert-level prompts based on input text. It uses few-shot examples to guide the model in generating contextually rich and specialized prompts. These prompts are then used to extract expert-level answers from the model. 13 | 14 | Attributes: 15 | ----------- 16 | few_shot_examples : str 17 | Predefined expert-level few-shot examples that aid in guiding the model to generate context-specific, expert prompts. 18 | 19 | Methods: 20 | -------- 21 | __init__(**kwargs) 22 | Initializes the ExpertPrompting instance, setting up few-shot examples for expert prompting. 23 | generate_expert_prompt(input_text, model) 24 | Generates an expert prompt based on the input text. This method appends few-shot examples to the input text to create a prompt that guides the model in generating an expert-level response. 25 | query(input_text, model) 26 | Processes the input text by first generating an expert prompt and then using this prompt to guide the model's response. The method formats the model's response with a specific answer notation, ensuring clarity and precision in the answer provided. 
27 | 28 | Paper Link: https://arxiv.org/pdf/2305.14688.pdf 29 | """ 30 | def __init__(self, **kwargs): 31 | super().__init__(**kwargs) 32 | 33 | self.few_shot_examples = get_prompt(['expert_prompt']) 34 | 35 | def generate_expert_prompt(self, input_text, model): 36 | instr_gen_expert_prompt = self.few_shot_examples + '\n' + '[Instruction]: ' + input_text + '\n' + '[Agent Description]:' 37 | prompt_gen_expert_prompt = model.convert_text_to_prompt(instr_gen_expert_prompt, 'user') 38 | 39 | expert_prompt = model(prompt_gen_expert_prompt) 40 | 41 | return expert_prompt 42 | 43 | def query(self, input_text, model): 44 | expert_prompt = self.generate_expert_prompt(input_text, model) 45 | 46 | instr_get_answer = expert_prompt + '\n\n' + \ 47 | 'Now given the above identity background, please answer the following instruction:' + \ 48 | '\n\n' + input_text + '\n' + f'Please output your answer at the end as ##' 49 | prompt_get_answer = model.convert_text_to_prompt(instr_get_answer, 'user') 50 | 51 | answer = model(prompt_get_answer) 52 | 53 | if self.verbose: 54 | print(prompt_get_answer) 55 | print(answer) 56 | 57 | return answer 58 | -------------------------------------------------------------------------------- /promptbench/prompt_engineering/generated_knowledge.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from ..prompts.method_oriented import get_prompt 5 | from .base import Base 6 | 7 | 8 | class GeneratedKnowledge(Base): 9 | """ 10 | A class for generating and utilizing synthesized knowledge in AI model queries. 11 | 12 | This class specializes in creating synthesized knowledge based on input text and few-shot examples. It generates context-specific knowledge snippets, which are then used to enhance the model's ability to provide informed and relevant answers to queries. 13 | 14 | Attributes: 15 | ----------- 16 | few_shot_examples : str 17 | Custom few-shot examples specific to the dataset, used to guide the model in generating relevant synthesized knowledge. 18 | 19 | Methods: 20 | -------- 21 | __init__(**kwargs) 22 | Initializes the GeneratedKnowledge instance, loading few-shot examples tailored to the dataset name provided in the keyword arguments. 23 | generate_knowledge(input_text, model) 24 | Generates synthesized knowledge based on the input text. This method utilizes few-shot examples and modifies them with the input question to create a prompt that instructs the model to generate knowledge snippets. 25 | query(input_text, model) 26 | Processes the input text by first generating synthesized knowledge and then using this knowledge to enhance the query's context. The method then prompts the model to generate an answer, ensuring that the response is informed by the newly generated knowledge. The final answer is formatted with a specific notation for clarity. 
27 | 28 | Paper Link: https://arxiv.org/pdf/2110.08387.pdf 29 | """ 30 | 31 | def __init__(self, **kwargs): 32 | super().__init__(**kwargs) 33 | 34 | self.few_shot_examples = get_prompt(['generated_knowledge', self.dataset_name]) 35 | 36 | def generate_knowledge(self, input_text, model): 37 | instr_gen_knowledge = self.few_shot_examples.replace('{question}', input_text) 38 | 39 | prompt_gen_knowledge = model.convert_text_to_prompt( 40 | instr_gen_knowledge, 'user' 41 | ) 42 | 43 | # TODO: receive more settings for query 44 | knowledges = model(prompt_gen_knowledge, temperature=1.0, n=1, max_tokens=60) 45 | 46 | knowledges = list(set([_ for _ in knowledges if _ != ''])) 47 | knowledge = '\n'.join(knowledges) 48 | 49 | if self.verbose: 50 | print(prompt_gen_knowledge) 51 | print(knowledge) 52 | 53 | return knowledge 54 | 55 | def query(self, input_text, model): 56 | if "Answer Choices" in input_text: 57 | raw_question = input_text.split("Answer Choices", 1)[0] 58 | else: 59 | raw_question = input_text 60 | 61 | knowledge = self.generate_knowledge(raw_question, model) 62 | 63 | instr_get_answer = knowledge + '\n\n' + input_text + '\n' + \ 64 | f'Please output your answer at the end as ##' 65 | prompt_get_answer = model.convert_text_to_prompt( 66 | instr_get_answer, 'user' 67 | ) 68 | answer = model(prompt_get_answer) 69 | 70 | if self.verbose: 71 | print(prompt_get_answer) 72 | print(answer) 73 | 74 | return answer -------------------------------------------------------------------------------- /promptbench/prompt_engineering/least_to_most.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from ..prompts.method_oriented import get_prompt 5 | from .base import Base 6 | 7 | class LeastToMost(Base): 8 | """ 9 | A class designed to implement the 'least-to-most' prompting strategy in AI models. 10 | 11 | This class utilizes a step-by-step approach, breaking down complex problems into simpler sub-problems. It leverages few-shot examples tailored to specific datasets to guide the model in sequentially addressing each part of a problem, ultimately leading to a comprehensive solution. 12 | 13 | Attributes: 14 | ----------- 15 | few_shot : str 16 | Custom few-shot examples specific to the dataset name provided in the keyword arguments. These examples serve as templates to guide the model in breaking down and solving problems step-by-step. 17 | 18 | Methods: 19 | -------- 20 | __init__(**kwargs) 21 | Initializes the LeastToMost instance, setting up few-shot examples based on the provided dataset name. 22 | query(input_text, model) 23 | Processes the input text using the 'least-to-most' approach. It first breaks down the problem into sub-problems using the few-shot examples and then sequentially solves each sub-problem. The method combines these steps to construct a comprehensive answer to the original problem. The final answer is formatted with a specific notation for clarity and consistency. 24 | 25 | Paper Link: https://arxiv.org/pdf/2205.10625.pdf 26 | """ 27 | def __init__(self, **kwargs): 28 | super().__init__(**kwargs) 29 | 30 | self.few_shot = get_prompt(['least_to_most', self.dataset_name]) 31 | 32 | def query(self, input_text, model): 33 | instr_breakdown = self.few_shot + '\n\n' + 'Q: ' + input_text + '\n' + 'A: Let’s break down this problem, then solve it one by one.' 
34 | prompt_breakdown = model.convert_text_to_prompt(instr_breakdown, 'user') 35 | 36 | # break down problem to sub-problems and solve them one by one 37 | ans_subproblem = model(prompt_breakdown) 38 | answer_trigger = f'Please output your answer at the end as ##' 39 | 40 | prompt_get_answer = model.concat_prompts([prompt_breakdown, 41 | model.convert_text_to_prompt(ans_subproblem, 'assistant'), 42 | model.convert_text_to_prompt(answer_trigger, 'user') 43 | ]) 44 | answer = model(prompt_get_answer) 45 | 46 | if self.verbose: 47 | print(prompt_get_answer) 48 | print(answer) 49 | 50 | return answer 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /promptbench/prompteval/__init__.py: -------------------------------------------------------------------------------- 1 | from .efficient_eval import * -------------------------------------------------------------------------------- /promptbench/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .prompt import Prompt -------------------------------------------------------------------------------- /promptbench/prompts/adv_prompts/Readme.md: -------------------------------------------------------------------------------- 1 | The following are raw md files that collect the adversarial prompts for each model. 2 | 3 | Please visit https://huggingface.co/spaces/March07/PromptBench for a more user-friendly experience: 4 | 5 | ![](https://wjdcloud.blob.core.windows.net/tools/fig-streamlit.png) 6 | -------------------------------------------------------------------------------- /promptbench/prompts/prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
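A short usage sketch for the Prompt class defined just below. The dataset name assumes default task- and role-oriented prompts exist for it, and the "{content}" placeholder convention is an assumption borrowed from the input-formatting utilities.

from promptbench.prompts import Prompt

# Caller-supplied prompt strings.
prompts = Prompt(["Classify the following sentence as positive or negative: {content}"])

# Or load the default task- and role-oriented prompts for a dataset, and
# append the built-in few-shot examples if few_shot_examples.yaml defines them
# for this dataset.
default_prompts = Prompt(dataset_name="sst2")
default_prompts.add_few_shot_examples()
for p in default_prompts:
    print(p)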
3 | 4 | import re 5 | import os 6 | from typing import Union, List 7 | 8 | 9 | class Prompt: 10 | def __init__(self, prompt_content: Union[List[str], str, None] = None, dataset_name: str = None): 11 | 12 | if prompt_content: 13 | if isinstance(prompt_content, list): 14 | self.prompts = prompt_content 15 | else: 16 | self.prompts = [prompt_content] 17 | elif dataset_name: 18 | self.prompts = self._load_default_prompt(dataset_name) 19 | self.dataset_name = dataset_name 20 | else: 21 | raise ValueError("Either provide prompt_content or specify dataset_name.") 22 | 23 | def __len__(self): 24 | return len(self.prompts) 25 | 26 | def __getitem__(self, index): 27 | return self.prompts[index] 28 | 29 | def _load_default_prompt(self, dataset_name): 30 | from .task_oriented import TASK_ORIENTED_PROMPTS 31 | from .role_oriented import ROLE_ORIENTED_PROMPTS 32 | return TASK_ORIENTED_PROMPTS[dataset_name] + ROLE_ORIENTED_PROMPTS[dataset_name] 33 | 34 | def load_adv_prompt(self, model_name, dataset_name, attack_name, prompt_type): 35 | prompts_dict = retrieve(model_name, dataset_name, attack_name, prompt_type) 36 | self.prompts = [adv_prompt["attack prompt"] for adv_prompt in prompts_dict] 37 | 38 | def add_few_shot_examples(self, few_shot_examples=None): 39 | if few_shot_examples is None: # fall back to the default examples for this dataset 40 | few_shot_examples = self._load_default_few_shot_examples() 41 | 42 | self.prompts = [prompt + "\n" + few_shot_examples for prompt in self.prompts] 43 | 44 | def _load_default_few_shot_examples(self): 45 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 46 | few_shot_path = os.path.join(cur_dir, "few_shot_examples.yaml") 47 | import yaml 48 | few_shot_examples = yaml.load(open(few_shot_path, "r"), Loader=yaml.FullLoader) 49 | 50 | return few_shot_examples[self.dataset_name] 51 | 52 | 53 | """ 54 | The following are helper functions for retrieving adversarial prompts from the markdown file.
55 | """ 56 | def split_markdown_by_title(markdown_file): 57 | with open(markdown_file, 'r', encoding='utf-8') as f: 58 | content = f.read() 59 | 60 | re_str = "# cola|# mnli|# mrpc|# qnli|# qqp|# rte|# sst2|# wnli|# mmlu|# squad_v2|# iwslt|# un_multi|# math" 61 | 62 | datasets = ["# cola", "# mnli", "# mrpc", "# qnli", "# qqp", "# rte", "# sst2", "# wnli", 63 | "# mmlu", "# squad_v2", "# iwslt", "# un_multi", "# math"] 64 | 65 | primary_sections = re.split(re_str, content)[1:] 66 | assert len(primary_sections) == len(datasets) 67 | 68 | all_sections_dict = {} 69 | 70 | for dataset, primary_section in zip(datasets, primary_sections): 71 | re_str = "## " 72 | results = re.split(re_str, primary_section) 73 | keywords = ["10 prompts", "bertattack", "checklist", "deepwordbug", "stresstest", 74 | "textfooler", "textbugger", "translation"] 75 | 76 | secondary_sections_dict = {} 77 | for res in results: 78 | for keyword in keywords: 79 | if keyword in res.lower(): 80 | secondary_sections_dict[keyword] = res 81 | break 82 | 83 | all_sections_dict[dataset] = secondary_sections_dict 84 | 85 | return all_sections_dict 86 | 87 | 88 | def list_files(directory): 89 | files = [os.path.join(directory, d) for d in os.listdir(directory) if not os.path.isdir(os.path.join(directory, d))] 90 | return files 91 | 92 | 93 | def convert_model_name(attack): 94 | attack_name = { 95 | "google/flan-t5-large": "t5", 96 | "google/flan-ul2": "ul2", 97 | "vicuna-13b": "vicuna", 98 | "llama2-13b-chat": "llama2", 99 | "chatgpt": "chatgpt", 100 | } 101 | return attack_name[attack] 102 | 103 | 104 | def convert_dataset_name(dataset): 105 | return "# " + dataset 106 | 107 | 108 | def retrieve(model_name, dataset_name, attack_name, prompt_type): 109 | model_name = convert_model_name(model_name) 110 | dataset_name = convert_dataset_name(dataset_name) 111 | 112 | if "zero" in prompt_type: 113 | shot = "zeroshot" 114 | else: 115 | shot = "fewshot" 116 | 117 | if "task" in prompt_type: 118 | prompt_type = "task" 119 | else: 120 | prompt_type = "role" 121 | 122 | directory_path = "./adv_prompts" 123 | md_dir = os.path.join(directory_path, model_name + "_" + shot + ".md") 124 | sections_dict = split_markdown_by_title(md_dir) 125 | results = {} 126 | for cur_dataset in sections_dict.keys(): 127 | if cur_dataset == dataset_name: 128 | dataset_dict = sections_dict[cur_dataset] 129 | best_acc = 0 130 | best_prompt = "" 131 | for cur_attack in dataset_dict.keys(): 132 | if cur_attack == "10 prompts": 133 | prompts_dict = dataset_dict[cur_attack].split("\n") 134 | num = 0 135 | for prompt_summary in prompts_dict: 136 | if "Acc: " not in prompt_summary: 137 | continue 138 | else: 139 | import re 140 | num += 1 141 | match = re.search(r'Acc: (\d+\.\d+)%', prompt_summary) 142 | if match: 143 | number = float(match.group(1)) 144 | if number > best_acc: 145 | best_acc = number 146 | best_prompt = prompt_summary.split("prompt: ")[1] 147 | 148 | for cur_attack in dataset_dict.keys(): 149 | 150 | if cur_attack == attack_name: 151 | 152 | if attack_name == "translation": 153 | prompts_dict = dataset_dict[attack_name].split("\n") 154 | 155 | for prompt_summary in prompts_dict: 156 | if "acc: " not in prompt_summary: 157 | continue 158 | 159 | prompt = prompt_summary.split("prompt: ")[1] 160 | 161 | import re 162 | 163 | match_atk = re.search(r'acc: (\d+\.\d+)%', prompt_summary) 164 | number_atk = float(match_atk.group(1)) 165 | results[prompt] = number_atk 166 | 167 | sorted_results = sorted(results.items(), key=lambda item: item[1])[:6] 168 | 169 | 
returned_results = [] 170 | for result in sorted_results: 171 | returned_results.append({"origin prompt": best_prompt, "origin acc": best_acc, "attack prompt": result[0], "attack acc": result[1]}) 172 | 173 | return returned_results 174 | 175 | elif attack_name in ["bertattack", "checklist", "deepwordbug", "stresstest", "textfooler", "textbugger"]: 176 | 177 | prompts_dict = dataset_dict[attack_name].split("Original prompt: ") 178 | num = 0 179 | 180 | returned_results = [] 181 | for prompt_summary in prompts_dict: 182 | if "Attacked prompt: " not in prompt_summary: 183 | continue 184 | 185 | origin_prompt = prompt_summary.split("\n")[0] 186 | attack_prompt = prompt_summary.split("Attacked prompt: ")[1].split("Original acc: ")[0] 187 | attack_prompt = bytes(attack_prompt[2:-1], "utf-8").decode("unicode_escape").encode("latin1").decode("utf-8") 188 | 189 | print(origin_prompt) 190 | print(attack_prompt) 191 | 192 | num += 1 193 | import re 194 | match_origin = re.search(r'Original acc: (\d+\.\d+)%', prompt_summary) 195 | match_atk = re.search(r'attacked acc: (\d+\.\d+)%', prompt_summary) 196 | if match_origin and match_atk: 197 | if prompt_type == "task": 198 | if num > 3: 199 | break 200 | else: 201 | if num < 3: 202 | continue 203 | number_origin = float(match_origin.group(1)) 204 | number_atk = float(match_atk.group(1)) 205 | returned_results.append({"origin prompt": origin_prompt, "origin acc": number_origin, "attack prompt": attack_prompt, "attack acc": number_atk}) 206 | 207 | return returned_results 208 | 209 | -------------------------------------------------------------------------------- /promptbench/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .dataprocess import InputProcess, OutputProcess 5 | from .visualize import Visualizer 6 | from .defense import Defense -------------------------------------------------------------------------------- /promptbench/utils/dataprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | class InputProcess: 5 | """ 6 | A utility class for processing input data for language models. 7 | 8 | This class provides static methods to format input data based on given prompt templates and input data dictionaries. 9 | 10 | Methods: 11 | -------- 12 | basic_format(prompt_template, input_data_dict) 13 | Combines a prompt template and input data to create a formatted model input. 14 | """ 15 | @staticmethod 16 | def basic_format(prompt_template, input_data_dict): 17 | """ 18 | Combine the prompt and input to create an input for the model. 19 | 20 | Parameters: 21 | - prompt_template (str): The template for the prompt with placeholders. 22 | - input_data_dict (dict): Dictionary containing data to fill in the template. 23 | 24 | Returns: 25 | - str: The combined model input. 26 | """ 27 | return prompt_template.format(**input_data_dict) 28 | 29 | # Additional input processing methods can be added here. 30 | # ... 31 | 32 | class OutputProcess: 33 | """ 34 | A utility class for processing raw predictions from language models. 35 | 36 | This class provides static methods for various ways to process and clean up raw prediction text. 37 | 38 | Methods: 39 | -------- 40 | general(raw_pred, proj_func=None) 41 | Performs general processing on the raw prediction text. 
42 | cls(raw_pred, proj_func=None) 43 | Processes the raw prediction text for classification tasks. 44 | pattern_split(raw_pred, pattern, proj_func=None) 45 | Splits the raw prediction text based on a pattern. 46 | pattern_re(raw_pred, pattern, proj_func=None) 47 | Uses regular expressions to process the raw prediction text. 48 | """ 49 | 50 | @staticmethod 51 | def _base_pred_process(pred): 52 | """ 53 | Basic processing for predictions which involves lowercasing, 54 | removing special tokens and stripping unwanted characters. 55 | 56 | Parameters: 57 | - pred (str): The raw prediction text. 58 | 59 | Returns: 60 | - str: The processed prediction text. 61 | """ 62 | pred = pred.lower().replace("<pad>", "").replace("</s>", "").strip(",._\"\'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ") 63 | return pred 64 | 65 | @staticmethod 66 | def general(raw_pred, proj_func=None): 67 | """ 68 | General processing for predictions using the base prediction process. 69 | 70 | Parameters: 71 | - raw_pred (str): The raw prediction text. 72 | 73 | Returns: 74 | - str: The processed prediction text. 75 | """ 76 | pred = OutputProcess._base_pred_process(raw_pred) 77 | if proj_func: 78 | pred = proj_func(pred) 79 | return pred 80 | 81 | @staticmethod 82 | def cls(raw_pred, proj_func=None): 83 | """ 84 | Processes the prediction by taking the last word after basic processing. 85 | 86 | Parameters: 87 | - raw_pred (str): The raw prediction text. 88 | 89 | Returns: 90 | - str: The last word from the processed prediction text. 91 | """ 92 | pred = OutputProcess._base_pred_process(raw_pred).split(" ")[-1] 93 | if proj_func: 94 | pred = proj_func(pred) 95 | return pred 96 | 97 | @staticmethod 98 | def pattern_split(raw_pred, pattern, proj_func=None): 99 | """ 100 | Processes the prediction by splitting it based on a provided pattern 101 | and taking the last part. 102 | 103 | Parameters: 104 | - raw_pred (str): The raw prediction text. 105 | - pattern (str): The pattern to split the prediction text on. 106 | 107 | Returns: 108 | - str: The last part of the prediction text after splitting. 109 | """ 110 | pred = OutputProcess._base_pred_process(raw_pred.split(pattern)[-1]) 111 | if proj_func: 112 | pred = proj_func(pred) 113 | 114 | return pred 115 | 116 | @staticmethod 117 | def pattern_re(raw_pred, pattern, proj_func=None): 118 | """ 119 | Processes the prediction using regular expressions to extract a specific pattern. 120 | 121 | Parameters: 122 | - raw_pred (str): The raw prediction text. 123 | - pattern (str): The regular expression pattern to search for. 124 | 125 | Returns: 126 | - str: The matched pattern from the prediction text, or the original text if no match. 127 | """ 128 | import re 129 | match = re.search(pattern, raw_pred) 130 | if match: 131 | pred = OutputProcess._base_pred_process(match.group(1)) 132 | if proj_func: 133 | pred = proj_func(pred) 134 | else: 135 | pred = raw_pred 136 | 137 | return pred 138 | -------------------------------------------------------------------------------- /promptbench/utils/defense.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
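A brief usage sketch for the InputProcess and OutputProcess helpers shown above; the template and the raw prediction strings are purely illustrative.

from promptbench.utils import InputProcess, OutputProcess

# Fill a prompt template with the fields of one data item.
template = "Question: {content}\nAnswer:"
model_input = InputProcess.basic_format(template, {"content": "What is 2 + 3?"})

# Clean up raw model output in different ways.
last_word = OutputProcess.cls("The answer is 5")                                  # -> "5"
extracted = OutputProcess.pattern_re("## 5 is the answer", r"##(.*)")             # -> "5 is the answer"
split_out = OutputProcess.pattern_split("Reasoning ... Answer: yes", "Answer:")   # -> "yes"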
3 | 4 | from autocorrect import Speller 5 | 6 | 7 | class Defense(object): 8 | def __init__(self, defense_method='spellcorrect', lang='en'): 9 | self.defense_method = defense_method 10 | if self.defense_method == 'spellcorrect': 11 | self.spell = Speller(lang=lang) 12 | 13 | def __call__(self, text): 14 | if self.defense_method == 'spellcorrect': 15 | return self.spell(text) 16 | else: 17 | raise NotImplementedError 18 | 19 | 20 | if __name__ == '__main__': 21 | defense = Defense() 22 | prompt = 'I am a student at the Univrsity of California, Berkeey.' 23 | print(defense(prompt)) -------------------------------------------------------------------------------- /promptbench/utils/visualize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from ..models import LLMModel 5 | import torch 6 | import numpy as np 7 | import copy 8 | 9 | class Visualizer: 10 | def __init__(self, llm: LLMModel) -> None: 11 | """ 12 | Initialize the Visualizer class. 13 | 14 | Parameters: 15 | - llm (LLMModel): The llm to visualize. 16 | 17 | Attributes: 18 | - model: The inference pipeline of the provided model. 19 | - tokenizer (Tokenizer): Tokenizer associated with the model. 20 | """ 21 | self.model = llm.model.model 22 | self.tokenizer = llm.model.tokenizer 23 | 24 | def _map_subwords_to_words(self, sentence: str): 25 | """ 26 | Convert a sentence into tokens and map subword tokens to their corresponding words. 27 | 28 | Parameters: 29 | - sentence (str): The input sentence. 30 | 31 | Returns: 32 | - mapping (list): List mapping subword tokens to word indices. 33 | - tokens (list): Tokenized version of the input sentence. 34 | """ 35 | tokens = self.tokenizer.tokenize(sentence) 36 | mapping = [] 37 | word_idx = 0 38 | for token in tokens: 39 | if token.startswith("▁"): 40 | mapping.append(word_idx) 41 | word_idx += 1 42 | else: 43 | mapping.append(word_idx - 1) 44 | return mapping, tokens 45 | 46 | def _normalize_importance(self, word_importance): 47 | """ 48 | Normalize importance values of words in a sentence using min-max scaling. 49 | 50 | Parameters: 51 | - word_importance (list): List of importance values for each word. 52 | 53 | Returns: 54 | - list: Normalized importance values for each word. 55 | """ 56 | min_importance = np.min(word_importance) 57 | max_importance = np.max(word_importance) 58 | return (word_importance - min_importance) / (max_importance - min_importance) 59 | 60 | def vis_by_grad(self, input_sentence: str, label: str) -> dict: 61 | """ 62 | Visualize word importance in an input sentence based on gradient information. 63 | 64 | This method uses the gradients of the model's outputs with respect to its 65 | input embeddings to estimate word importance. 66 | 67 | Parameters: 68 | - input_sentence (str): The input sentence. 69 | - label (str): The target label. 70 | 71 | Returns: 72 | - dict: Dictionary with words as keys and their normalized importance as values. 
73 | """ 74 | self.model.eval() 75 | 76 | mapping, tokens = self._map_subwords_to_words(input_sentence) 77 | words = "".join(tokens).replace("▁", " ").split() 78 | 79 | inputs = self.tokenizer(input_sentence, return_tensors="pt") 80 | embeddings = self.model.get_input_embeddings()(inputs['input_ids']) 81 | embeddings.requires_grad_() 82 | embeddings.retain_grad() 83 | 84 | labels = self.tokenizer(label, return_tensors="pt")["input_ids"] 85 | outputs = self.model(inputs_embeds=embeddings, attention_mask=inputs['attention_mask'], labels=labels) 86 | outputs.loss.backward() 87 | 88 | grads = embeddings.grad 89 | word_grads = [torch.zeros_like(grads[0][0]) for _ in range(len(words))] # Initialize gradient vectors for each word 90 | 91 | # Aggregate gradients for each word 92 | for idx, grad in enumerate(grads[0][:len(mapping)]): 93 | word_grads[mapping[idx]] += grad 94 | 95 | words_importance = [grad.norm().item() for grad in word_grads] 96 | normalized_importance = self._normalize_importance(words_importance) 97 | 98 | return dict(zip(words, normalized_importance)) 99 | 100 | def vis_by_delete(self, input_sentence: str, label: str) -> dict: 101 | """ 102 | Visualize word importance in an input sentence by deletion method. 103 | 104 | For each word in the sentence, the method deletes it and measures the 105 | change in the model's output. A higher change indicates higher importance. 106 | 107 | Parameters: 108 | - input_sentence (str): The input sentence. 109 | - label (str): The target label. 110 | 111 | Returns: 112 | - dict: Dictionary with words as keys and their normalized importance as values. 113 | """ 114 | words = input_sentence.split() 115 | encoded_label = self.tokenizer(label, return_tensors="pt")["input_ids"] 116 | inputs = self.tokenizer(input_sentence, return_tensors="pt") 117 | original_loss = self.model(**inputs, labels=encoded_label).loss.item() 118 | 119 | word_importance = [] 120 | for i in range(len(words)): 121 | new_words = copy.deepcopy(words) 122 | del new_words[i] 123 | new_sentence = ' '.join(new_words) 124 | inputs = self.tokenizer(new_sentence, return_tensors="pt") 125 | new_loss = self.model(**inputs, labels=encoded_label).loss.item() 126 | 127 | importance = abs(new_loss - original_loss) 128 | word_importance.append(importance) 129 | 130 | normalized_importance = self._normalize_importance(word_importance) 131 | 132 | return dict(zip(words, normalized_importance)) 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autocorrect==2.6.1 2 | accelerate==0.25.0 3 | datasets>=2.15.0 4 | nltk==3.8.1 5 | openai==1.3.7 6 | sentencepiece==0.1.99 7 | tokenizers==0.15.0 8 | torch>=2.1.1 9 | tqdm==4.66.1 10 | transformers==4.38.0 11 | Pillow==10.3.0 12 | google-generativeai==0.4.0 13 | dashscope==1.14.1 14 | einops==0.7.0 15 | transformers_stream_generator==0.0.5 16 | torchvision==0.17.0 17 | matplotlib==3.8.3 18 | tiktoken==0.6.0 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | """ Setup 5 | """ 6 | from setuptools import setup, find_packages 7 | from codecs import open 8 | from os import path 9 | import pathlib 10 | 11 | import pkg_resources 12 | 13 | here = path.abspath(path.dirname(__file__)) 14 | 15 | # Get the long description from the README file 16 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 17 | long_description = f.read() 18 | 19 | # requirements 20 | with pathlib.Path('requirements.txt').open() as requirements_txt: 21 | install_requires = [ 22 | str(requirement) 23 | for requirement 24 | in pkg_resources.parse_requirements(requirements_txt) 25 | ] 26 | 27 | setup( 28 | name='promptbench', 29 | version='0.0.4', 30 | description='PromptBench is a powerful tool designed to scrutinize and analyze the interaction of large language models with various prompts. It provides a convenient infrastructure to simulate **black-box** adversarial **prompt attacks** on the models and evaluate their performances.', 31 | long_description=long_description, 32 | long_description_content_type='text/markdown', 33 | url='https://github.com/microsoft/promptbench', 34 | author='', 35 | author_email='', 36 | 37 | # Note that this is a string of words separated by whitespace, not a list. 38 | keywords='pytorch, large language models, prompt tuning, dyval, evaluation', 39 | packages=find_packages(exclude=['examples', 'imgs', 'docs']), 40 | include_package_data=True, 41 | # install_requires=['torch >= 1.8', 'torchvision', 'torchaudio', 'transformers', 'timm', 'progress', 'ruamel.yaml', 'scikit-image', 'scikit-learn', 'tensorflow', ''], 42 | install_requires=install_requires, 43 | python_requires='>=3.9', 44 | ) 45 | --------------------------------------------------------------------------------