├── .coveragerc ├── .gitattributes ├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs └── logo.jpg ├── example_notebooks ├── transformers │ ├── cite_prompt_logits_processor.ipynb │ ├── force_last_phrase_logits_processor.ipynb │ ├── gen_length_logits_processor.ipynb │ ├── multiple_choice_logits_processor.ipynb │ ├── trigger_phrase_logits_processor.ipynb │ └── utils.py ├── trtllm │ ├── README.md │ ├── cite_prompt_logits_processor.py │ ├── gen_length_logits_processor.py │ ├── last_phrase_logits_processor.py │ ├── multiple_choice_logits_processor.py │ └── utils.py └── vllm │ ├── cite_prompt_logits_processor.ipynb │ ├── force_last_phrase_logits_processor.ipynb │ ├── gen_length_logits_processor.ipynb │ ├── multiple_choice_logits_processor.ipynb │ ├── trigger_phrase_logits_processor.ipynb │ ├── utils.py │ └── vllm_serve.ipynb ├── logits_processor_zoo ├── transformers │ ├── __init__.py │ ├── base.py │ ├── cite_prompt.py │ ├── generation_length.py │ ├── last_phrase.py │ ├── multiple_choice.py │ └── trigger_phrase.py ├── trtllm │ ├── __init__.py │ ├── cite_prompt.py │ ├── generation_length.py │ ├── last_phrase.py │ └── multiple_choice.py ├── utils.py └── vllm │ ├── __init__.py │ ├── cite_prompt.py │ ├── generation_length.py │ ├── last_phrase.py │ ├── multiple_choice.py │ └── trigger_phrase.py ├── pyproject.toml └── tests ├── conftest.py ├── test_utils.py └── transformers ├── test_cite_prompt.py ├── test_generation_length.py ├── test_last_phrase.py └── test_multiple_choice.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | /tmp/* 4 | tests/* -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.py linguist-language=python 2 | *.ipynb linguist-documentation 
-------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: "3.10" 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install poetry flake8 Flake8-pyproject pytest pytest-cov 30 | pip install -e . 31 | - name: Lint with flake8 32 | run: | 33 | poetry run flake8 34 | - name: Test with pytest 35 | run: | 36 | python -m pytest tests/ --cov --cov-config=.coveragerc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | .idea/ -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to logits-processor-zoo 2 | 3 | Contributions to logits-processor-zoo fall into the following categories: 4 | 5 | 1. To report a bug, request a new feature, or report a problem with documentation, please file an 6 | issue describing the problem or new feature 7 | in detail. The team evaluates and triages issues, and schedules them for a release. 
If you 8 | believe the issue needs priority attention, please comment on the issue to notify the team. 9 | 2. To propose and implement a new feature, please file a new feature request. Describe the intended feature and 10 | discuss the design and implementation with the team and community. Once the team agrees that the 11 | plan looks good, go ahead and implement it, using the [code contributions](#code-contributions) 12 | guide below. 13 | 3. To implement a feature or bug fix for an existing issue, please follow the [code 14 | contributions](#code-contributions) guide below. If you need more context on a particular issue, 15 | please ask in a comment. 16 | 17 | ## Code contributions 18 | 19 | ### Your first issue 20 | 21 | 1. Find an issue to work on. The best way is to look for the 22 | good first issue or help wanted labels. 23 | 2. Comment on the issue stating that you are going to work on it. 24 | 3. Create a fork of the repository and check out a branch with a name that 25 | describes your planned work. For example, `fix-documentation`. 26 | 4. Write code to address the issue or implement the feature. 27 | 5. Add unit tests and unit benchmarks. 28 | 6. Create your Pull Request. To run continuous integration (CI) tests without requesting review, open a draft pull request. 29 | 7. Verify that CI passes all status checks. Fix if needed. 30 | 8. Wait for other developers to review your code and update code as needed. 31 | 9. Once reviewed and approved, a developer will merge your pull request. 32 | 33 | If you are unsure about anything, don't hesitate to comment on issues and ask for clarification! 34 | 35 | ### Seasoned developers 36 | 37 | Look at the unassigned issues, and find an issue to which you are comfortable contributing. Start 38 | with _Step 3_ above, commenting on the issue to let others know you are working on it. If you have 39 | any questions related to the implementation of the issue, ask them in the issue instead of the PR. 
40 | 41 | #### Signing Your Work 42 | 43 | * We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. 44 | 45 | * Any contribution which contains commits that are not Signed-Off will not be accepted. 46 | 47 | * To sign off on a commit you simply use the `--signoff` (or `-s`) option when committing your changes: 48 | ```bash 49 | $ git commit -s -m "Add cool feature." 50 | ``` 51 | This will append the following to your commit message: 52 | ``` 53 | Signed-off-by: Your Name 54 | ``` 55 | 56 | * Full text of the DCO: 57 | 58 | ``` 59 | Developer Certificate of Origin 60 | Version 1.1 61 | 62 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 63 | 1 Letterman Drive 64 | Suite D4700 65 | San Francisco, CA, 94129 66 | 67 | Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. 68 | ``` 69 | 70 | ``` 71 | Developer's Certificate of Origin 1.1 72 | 73 | By making a contribution to this project, I certify that: 74 | 75 | (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or 76 | 77 | (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or 78 | 79 | (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. 
80 | 81 | (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. 82 | ``` 83 | 84 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 NVIDIA 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI version](https://badge.fury.io/py/logits-processor-zoo.svg)](https://badge.fury.io/py/logits-processor-zoo) 2 | [![License: Apache 2.0](https://img.shields.io/badge/License-Apache2.0-yellow.svg)](https://opensource.org/licenses/Apache-2.0) 3 | 4 |

5 | 6 |

7 | 8 | # logits-processor-zoo 9 | 10 | Struggling to get LLMs to follow your instructions? LogitsProcessorZoo offers a zoo of tools to use LLMs for specific tasks, beyond just grammar enforcement! 11 | 12 | ## Installation 13 | 14 | ```bash 15 | pip install logits-processor-zoo 16 | ``` 17 | 18 | ## Supported Frameworks 19 | * transformers 20 | * vLLM 21 | * TensorRT-LLM 22 | 23 | ## Usage 24 | 25 | ```python 26 | import vllm 27 | from logits_processor_zoo.vllm import GenLengthLogitsProcessor, CiteFromPromptLogitsProcessor, ForceLastPhraseLogitsProcessor 28 | 29 | model = vllm.LLM( 30 | model_name, 31 | trust_remote_code=True, 32 | dtype="half", 33 | enforce_eager=True 34 | ) 35 | tokenizer = model.get_tokenizer() 36 | 37 | logits_processors = [ 38 | CiteFromPromptLogitsProcessor(tokenizer, boost_factor=2.0), 39 | GenLengthLogitsProcessor(tokenizer, boost_factor=-0.2, p=1), 40 | ForceLastPhraseLogitsProcessor("\n\nReferences:\n", tokenizer) 41 | ] 42 | 43 | 44 | gen_output = model.generate( 45 | prompts, 46 | vllm.SamplingParams( 47 | n=1, 48 | temperature=0, 49 | seed=0, 50 | skip_special_tokens=True, 51 | max_tokens=64, 52 | logits_processors=logits_processors 53 | ), 54 | use_tqdm=False 55 | ) 56 | ``` 57 | 58 | 59 | For the detailed examples in each framework, please have a look at **example_notebooks** directory. 60 | 61 | ## Available Logits Processors 62 | 63 | ### GenLengthLogitsProcessor 64 | A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token based on the length of the generated sequence, encouraging or discouraging shorter answers. 65 | 66 | ### CiteFromPromptLogitsProcessor 67 | A logits processor which boosts or diminishes the likelihood of tokens present in the prompt (and optionally EOS token) to encourage the model to generate tokens similar to those seen in the prompt or vice versa.
68 | 69 | ### ForceLastPhraseLogitsProcessor 70 | A logits processor which forces LLMs to use the given phrase before they finalize their answers. Most common use cases can be providing references, thanking user with context etc. 71 | 72 | ### MultipleChoiceLogitsProcessor 73 | A logits processor to answer multiple choice questions with one of the choices. A multiple choice question is like: 74 | ``` 75 | I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone? 76 | 0. Camera 77 | 1. Screen resolution 78 | 2. Operating System 79 | 3. Battery 80 | ``` 81 | The goal is to make LLM generate "3" as an answer. 82 | 83 | ### TriggerPhraseLogitsProcessor 84 | A logits processor which triggers phrases when it encounters a given token. 85 | One common use case is to force writing python code just after thinking: 86 | ```python 87 | trigger_python = TriggerPhraseLogitsProcessor(phrase="\n```python", trigger_token_phrase="", 88 | tokenizer=tokenizer, trigger_count=1, trigger_after=True) 89 | ``` 90 | -------------------------------------------------------------------------------- /docs/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/logits-processor-zoo/c9f20058f5aea1bf6c25730cfdddd9fbee50a6e5/docs/logo.jpg -------------------------------------------------------------------------------- /example_notebooks/transformers/cite_prompt_logits_processor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "28ed6952", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "/home/aerdem/projects/nvidia/logits-processor-zoo\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%cd ../.." 
19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "a85f8503", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "from example_notebooks.transformers.utils import LLMRunner\n", 37 | "from logits_processor_zoo.transformers import CiteFromPromptLogitsProcessor\n", 38 | "\n", 39 | "\n", 40 | "example_prompts =[\n", 41 | " \"\"\"\n", 42 | " A user review: very soft, colorful, expensive but deserves its price, stylish.\n", 43 | " \n", 44 | " What is the user's opinion about the product's price?\n", 45 | " \"\"\",\n", 46 | " \"\"\"\n", 47 | " Retrieved information:\n", 48 | " Pokémon is a Japanese media franchise consisting of video games, animated series and films, a trading card game, and other related media. \n", 49 | " The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. \n", 50 | " The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages.\n", 51 | " \n", 52 | " What is a Pokémon?\n", 53 | " \"\"\"\n", 54 | "]\n", 55 | "\n", 56 | "runner = LLMRunner()" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "id": "859aef8d", 62 | "metadata": {}, 63 | "source": [ 64 | "## Default Response" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "id": "cbf4c2d5", 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stderr", 75 | "output_type": "stream", 76 | "text": [ 77 | "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. 
Please pass your input's `attention_mask` to obtain reliable results.\n" 78 | ] 79 | }, 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "Prompt: \n", 85 | " A user review: very soft, colorful, expensive but deserves its price, stylish.\n", 86 | " \n", 87 | " What is the user's opinion about the product's price?\n", 88 | " \n", 89 | "\n", 90 | "LLM response:\n", 91 | "The user seems to have mixed feelings about the price of the product. They find it \"expensive,\" which might be due to several factors:\n", 92 | "\n", 93 | "1. **Quality and Quality Control**: If the product is described as \"very soft\" and \"colorful,\" it suggests that it may be of high quality or unique in some way.\n", 94 | "\n", 95 | "2. **Design and Design Elements**: Describing the design as \"stylish\" indicates that the product has an appealing aesthetic appeal, which can sometimes come at a higher cost.\n", 96 | "\n", 97 | "3. **Brand Reputation**: The fact that the product is described as \"deserving its price\" implies that it represents good value for money, suggesting that it offers more than just basic functionality; perhaps it also provides additional features or materials that justify the higher price point.\n", 98 | "\n", 99 | "Overall, while the user appreciates the product's appearance and style, they seem to feel that the price is too high relative to these positive attributes. This could mean there might not be enough information provided to determine if the price reflects the true value of the product. It would help to know more details about the product's features, materials used, and any other relevant specifications to better understand their perspective on pricing.\n", 100 | "-----END-----\n", 101 | "\n", 102 | "Prompt: \n", 103 | " Retrieved information:\n", 104 | " Pokémon is a Japanese media franchise consisting of video games, animated series and films, a trading card game, and other related media. 
\n", 105 | " The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. \n", 106 | " The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages.\n", 107 | " \n", 108 | " What is a Pokémon?\n", 109 | " \n", 110 | "\n", 111 | "LLM response:\n", 112 | "A Pokémon is a fictional creature that exists within the Pokémon media franchise. These creatures have unique abilities or characteristics that allow them to interact with their environment and other characters in various ways. Pokémon can be found in different types of habitats such as forests, mountains, rivers, and cities, and they often serve as companions or allies for trainers who embark on adventures. Each Pokémon has its own distinct appearance, moveset, and personality traits, making them beloved figures in popular culture.\n", 113 | "-----END-----\n", 114 | "\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "runner.generate_response(example_prompts)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "id": "88bc2f8a", 125 | "metadata": {}, 126 | "source": [ 127 | "## Cite from Prompt" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 4, 133 | "id": "7d74eb26", 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "Prompt: \n", 141 | " A user review: very soft, colorful, expensive but deserves its price, stylish.\n", 142 | " \n", 143 | " What is the user's opinion about the product's price?\n", 144 | " \n", 145 | "\n", 146 | "LLM response:\n", 147 | "The user's opinion about the product's price is mixed. They describe it as \"expensive,\" which could be interpreted in two ways:\n", 148 | "\n", 149 | "1. The user might consider the price to be high for its quality and features.\n", 150 | "2. 
Alternatively, they may appreciate the product's price, considering its stylish design or unique qualities.\n", 151 | "\n", 152 | "Without more context, it's difficult to determine the user's opinion definitively.\n", 153 | "-----END-----\n", 154 | "\n", 155 | "Prompt: \n", 156 | " Retrieved information:\n", 157 | " Pokémon is a Japanese media franchise consisting of video games, animated series and films, a trading card game, and other related media. \n", 158 | " The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. \n", 159 | " The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages.\n", 160 | " \n", 161 | " What is a Pokémon?\n", 162 | " \n", 163 | "\n", 164 | "LLM response:\n", 165 | "A Pokémon is a fictional creature in the Pokémon franchise, a Japanese media franchise consisting of video games, animated series and films, a trading card game, and other related media. The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. 
The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages.\n", 166 | "-----END-----\n", 167 | "\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "runner.generate_response(\n", 173 | " example_prompts,\n", 174 | " [CiteFromPromptLogitsProcessor(runner.tokenizer, boost_factor=2.0, boost_eos=False,\n", 175 | " conditional_boost_factor=2.0)]\n", 176 | ")" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "id": "15b5afa5", 182 | "metadata": {}, 183 | "source": [ 184 | "## DON'T Cite from Prompt" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 5, 190 | "id": "b2297aab", 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "Prompt: \n", 198 | " A user review: very soft, colorful, expensive but deserves its price, stylish.\n", 199 | " \n", 200 | " What is the user's opinion about the product's price?\n", 201 | " \n", 202 | "\n", 203 | "LLM response:\n", 204 | "The reviewer seems to have mixed feelings about the pricing of the product. They describe it as \"expensive\" and \"deserves its price,\" which suggests that they find the high cost justified by the quality or value of the item. The use of words like \"stylish\" further emphasizes their positive impression of the style and design of the product.\n", 205 | "\n", 206 | "So in summary, while the reviewer finds the price somewhat high, they believe it is worth the investment due to the overall quality and style of the item. This indicates an average-to-good level of satisfaction with the purchase experience.\n", 207 | "-----END-----\n", 208 | "\n", 209 | "Prompt: \n", 210 | " Retrieved information:\n", 211 | " Pokémon is a Japanese media franchise consisting of video games, animated series and films, a trading card game, and other related media. 
\n", 212 | " The franchise takes place in a shared universe in which humans co-exist with creatures known as Pokémon, a large variety of species endowed with special powers. \n", 213 | " The franchise's target audience is children aged 5 to 12, but it is known to attract people of all ages.\n", 214 | " \n", 215 | " What is a Pokémon?\n", 216 | " \n", 217 | "\n", 218 | "LLM response:\n", 219 | "A Pokémon is an imaginary creature that exists within the fictional world of the Pokémon franchise. These creatures have unique abilities or characteristics that allow them to interact with their environment and engage in various activities.\n", 220 | "\n", 221 | "Pokémon can be found throughout the vast landscapes depicted in the franchise, including forests, mountains, rivers, cities, and even outer space. They come in different sizes, shapes, colors, and types (e.g., water-type, fire-type). Each Pokémon has its own distinct personality and backstory.\n", 222 | "\n", 223 | "The concept of Pokémon originated from Japan in the early 1990s when Satoshi Tajiri developed the first generation of the Pokémon Red and Blue games for Nintendo Entertainment System (NES). 
Since then, numerous generations of Pokémon games have been released across multiple platforms, expanding the world of Pokémon into anime, manga, movies, TV shows, books, trading cards, and more.\n", 224 | "\n", 225 | "In summary, Pokémon are magical beings that exist alongside human characters in the Pokémon universe, each possessing unique traits and abilities that make them fascinating subjects for storytelling and gaming experiences.\n", 226 | "-----END-----\n", 227 | "\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "runner.generate_response(\n", 233 | " example_prompts,\n", 234 | " [CiteFromPromptLogitsProcessor(runner.tokenizer, boost_factor=-1.0, boost_eos=False,\n", 235 | " conditional_boost_factor=-1.0)]\n", 236 | ")" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "id": "c29fedb3", 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [] 246 | } 247 | ], 248 | "metadata": { 249 | "kernelspec": { 250 | "display_name": "Python 3 (ipykernel)", 251 | "language": "python", 252 | "name": "python3" 253 | }, 254 | "language_info": { 255 | "codemirror_mode": { 256 | "name": "ipython", 257 | "version": 3 258 | }, 259 | "file_extension": ".py", 260 | "mimetype": "text/x-python", 261 | "name": "python", 262 | "nbconvert_exporter": "python", 263 | "pygments_lexer": "ipython3", 264 | "version": "3.10.17" 265 | } 266 | }, 267 | "nbformat": 4, 268 | "nbformat_minor": 5 269 | } 270 | -------------------------------------------------------------------------------- /example_notebooks/transformers/force_last_phrase_logits_processor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "28ed6952", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "/home/aerdem/projects/nvidia/logits-processor-zoo\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 
| "%cd ../.." 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "a85f8503", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", 32 | " warnings.warn(\n", 33 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", 34 | "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", 35 | " warnings.warn(\n" 36 | ] 37 | } 38 | ], 39 | "source": [ 40 | "from example_notebooks.transformers.utils import LLMRunner\n", 41 | "from logits_processor_zoo.transformers import ForceLastPhraseLogitsProcessor\n", 42 | "\n", 43 | "\n", 44 | "example_prompts = [\n", 45 | " \"\"\"\n", 46 | " Retrieved information from: https://en.wikipedia.org/wiki/Bulbasaur\n", 47 | " Bulbasaur is a fictional Pokémon species in Nintendo and Game Freak's Pokémon franchise. \n", 48 | " Designed by Atsuko Nishida, Bulbasaur is a Grass and Poison-type, first appearing in Pocket Monsters: Red and Green (Pokémon Red and Blue outside Japan) as a starter Pokémon. \n", 49 | " Since then, it has reappeared in sequels, spin-off games, related merchandise, and animated and printed adaptations of the franchise. \n", 50 | " It is a central character in the Pokémon anime, being one of Ash Ketchum's main Pokémon for the first season, with a different one later being obtained by supporting character May. 
\n", 51 | " It is featured in various manga and is owned by protagonist Red in Pokémon Adventures.\n", 52 | " \n", 53 | " What is Bulbasaur?\n", 54 | " \"\"\",\n", 55 | "]\n", 56 | "\n", 57 | "runner = LLMRunner()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "859aef8d", 63 | "metadata": {}, 64 | "source": [ 65 | "## Default Response" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "id": "cbf4c2d5", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stderr", 76 | "output_type": "stream", 77 | "text": [ 78 | "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", 79 | " warnings.warn(\n", 80 | "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", 81 | " warnings.warn(\n", 82 | "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:407: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", 83 | " warnings.warn(\n" 84 | ] 85 | }, 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "Prompt: \n", 91 | " Retrieved information from: https://en.wikipedia.org/wiki/Bulbasaur\n", 92 | " Bulbasaur is a fictional Pokémon species in Nintendo and Game Freak's Pokémon franchise. 
\n", 93 | " Designed by Atsuko Nishida, Bulbasaur is a Grass and Poison-type, first appearing in Pocket Monsters: Red and Green (Pokémon Red and Blue outside Japan) as a starter Pokémon. \n", 94 | " Since then, it has reappeared in sequels, spin-off games, related merchandise, and animated and printed adaptations of the franchise. \n", 95 | " It is a central character in the Pokémon anime, being one of Ash Ketchum's main Pokémon for the first season, with a different one later being obtained by supporting character May. \n", 96 | " It is featured in various manga and is owned by protagonist Red in Pokémon Adventures.\n", 97 | " \n", 98 | " What is Bulbasaur?\n", 99 | " \n", 100 | "\n", 101 | "LLM response:\n", 102 | "Bulbasaur is a fictional Pokémon species that appears in the Pokémon franchise. It was designed by Atsuko Nishida and first appeared in the original Pokémon games, specifically in Pokémon Red and Blue outside of Japan. As a Grass and Poison-type Pokémon, Bulbasaur serves as the starting Pokémon for players to catch and train.\n", 103 | "\n", 104 | "Throughout the series, Bulbasaur has been featured in numerous sequels, spin-offs, related merchandise, and animated and printed adaptations of the franchise. It plays an important role in the Pokémon anime, where it is one of Ash Ketchum's main Pokémon during the first season. Later on, another Pokémon named May obtains a different version of Bulbasaur as her own.\n", 105 | "\n", 106 | "In addition to its appearances in the anime, Bulbasaur also features prominently in various manga stories and is owned by the protagonist Red in the Pokémon Adventures game. 
The Pokémon has become a beloved character within the franchise due to its enduring popularity across multiple platforms and generations of fans.\n", 107 | "-----END-----\n", 108 | "\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "runner.generate_response(example_prompts)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "id": "88bc2f8a", 119 | "metadata": {}, 120 | "source": [ 121 | "## Provide references" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 4, 127 | "id": "7d74eb26", 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "Prompt: \n", 135 | " Retrieved information from: https://en.wikipedia.org/wiki/Bulbasaur\n", 136 | " Bulbasaur is a fictional Pokémon species in Nintendo and Game Freak's Pokémon franchise. \n", 137 | " Designed by Atsuko Nishida, Bulbasaur is a Grass and Poison-type, first appearing in Pocket Monsters: Red and Green (Pokémon Red and Blue outside Japan) as a starter Pokémon. \n", 138 | " Since then, it has reappeared in sequels, spin-off games, related merchandise, and animated and printed adaptations of the franchise. \n", 139 | " It is a central character in the Pokémon anime, being one of Ash Ketchum's main Pokémon for the first season, with a different one later being obtained by supporting character May. \n", 140 | " It is featured in various manga and is owned by protagonist Red in Pokémon Adventures.\n", 141 | " \n", 142 | " What is Bulbasaur?\n", 143 | " \n", 144 | "\n", 145 | "LLM response:\n", 146 | "Bulbasaur is a fictional Pokémon species that appears in the Pokémon franchise. It was designed by Atsuko Nishida and first appeared in the original Pokémon games, specifically in Pokémon Red and Blue outside of Japan. 
As a Grass and Poison-type Pokémon, Bulbasaur serves as the starting Pokémon for players to catch and train.\n", 147 | "\n", 148 | "Throughout the series, Bulbasaur has been featured in numerous sequels, spin-offs, related merchandise, and animated and printed adaptations of the franchise. It plays an important role in the Pokémon anime, where it is one of Ash Ketchum's main Pokémon during the first season. Later on, another Pokémon named May obtains a different version of Bulbasaur as her own.\n", 149 | "\n", 150 | "In addition to its appearances in the anime, Bulbasaur also features prominently in various manga stories and is owned by the protagonist Red in the Pokémon Adventures game. The Pokémon has become a beloved character within the franchise due to its enduring popularity across multiple platforms and generations of fans.\n", 151 | "\n", 152 | "References: Wikipedia article on Bulbasaur.\n", 153 | "-----END-----\n", 154 | "\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "phrase = \"\\n\\nReferences:\"\n", 160 | "batch_size = len(example_prompts)\n", 161 | "\n", 162 | "runner.generate_response(\n", 163 | " example_prompts,\n", 164 | " [ForceLastPhraseLogitsProcessor(phrase, runner.tokenizer, batch_size)]\n", 165 | ")" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "id": "15b5afa5", 171 | "metadata": {}, 172 | "source": [ 173 | "## Thank you message" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 5, 179 | "id": "b2297aab", 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "Prompt: \n", 187 | " Retrieved information from: https://en.wikipedia.org/wiki/Bulbasaur\n", 188 | " Bulbasaur is a fictional Pokémon species in Nintendo and Game Freak's Pokémon franchise. 
\n", 189 | " Designed by Atsuko Nishida, Bulbasaur is a Grass and Poison-type, first appearing in Pocket Monsters: Red and Green (Pokémon Red and Blue outside Japan) as a starter Pokémon. \n", 190 | " Since then, it has reappeared in sequels, spin-off games, related merchandise, and animated and printed adaptations of the franchise. \n", 191 | " It is a central character in the Pokémon anime, being one of Ash Ketchum's main Pokémon for the first season, with a different one later being obtained by supporting character May. \n", 192 | " It is featured in various manga and is owned by protagonist Red in Pokémon Adventures.\n", 193 | " \n", 194 | " What is Bulbasaur?\n", 195 | " \n", 196 | "\n", 197 | "LLM response:\n", 198 | "Bulbasaur is a fictional Pokémon species that appears in the Pokémon franchise. It was designed by Atsuko Nishida and first appeared in the original Pokémon games, specifically in Pokémon Red and Blue outside of Japan. As a Grass and Poison-type Pokémon, Bulbasaur serves as the starting Pokémon for players to catch and train.\n", 199 | "\n", 200 | "Throughout the series, Bulbasaur has been featured in numerous sequels, spin-offs, related merchandise, and animated and printed adaptations of the franchise. It plays an important role in the Pokémon anime, where it is one of Ash Ketchum's main Pokémon during the first season. Later on, another Pokémon named May obtains a different version of Bulbasaur as her own.\n", 201 | "\n", 202 | "In addition to its appearances in the anime, Bulbasaur also features prominently in various manga stories and is owned by the protagonist Red in the Pokémon Adventures game. The Pokémon has become a beloved character within the franchise due to its enduring popularity across multiple platforms and generations of fans.\n", 203 | "\n", 204 | "Thanks for trying our RAG application! 
If you have more questions about Bulbasaur or any other topic, feel free to ask!\n", 205 | "-----END-----\n", 206 | "\n" 207 | ] 208 | } 209 | ], 210 | "source": [ 211 | "phrase = \"\\n\\nThanks for trying our RAG application! If you have more questions about\"\n", 212 | "\n", 213 | "runner.generate_response(\n", 214 | " example_prompts,\n", 215 | " [ForceLastPhraseLogitsProcessor(phrase, runner.tokenizer, batch_size)]\n", 216 | ")" 217 | ] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 3 (ipykernel)", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.10.17" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 5 241 | } 242 | -------------------------------------------------------------------------------- /example_notebooks/transformers/gen_length_logits_processor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "28ed6952", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "/home/aerdem/projects/nvidia/logits-processor-zoo\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%cd ../.." 
19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "0ea01217", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "from example_notebooks.transformers.utils import LLMRunner\n", 29 | "from logits_processor_zoo.transformers import GenLengthLogitsProcessor\n", 30 | "\n", 31 | "example_prompts =[\n", 32 | " \"Please describe what macaques are.\",\n", 33 | " \"Tell me a story about a kid lost in forest.\"\n", 34 | "]\n", 35 | "\n", 36 | "runner = LLMRunner()" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "id": "859aef8d", 42 | "metadata": {}, 43 | "source": [ 44 | "## Default Response" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "id": "cbf4c2d5", 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stderr", 55 | "output_type": "stream", 56 | "text": [ 57 | "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" 58 | ] 59 | }, 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "Prompt: Please describe what macaques are.\n", 65 | "\n", 66 | "LLM response:\n", 67 | "Macaques are primates that belong to the family Cercopithecidae and are found in tropical and subtropical regions of Asia and Africa. They are known for their social behavior, intelligence, and ability to adapt to various environments.\n", 68 | "\n", 69 | "Here are some key points about macaques:\n", 70 | "\n", 71 | "1. Species: There are over 25 species of macaques, including rhesus monkeys (Macaca mulatta), Japanese macaques (Macaca fuscata), and stump-tailed macaques (Macaca arctoides).\n", 72 | "\n", 73 | "2. 
Physical characteristics:\n", 74 | " - Generally small to medium-sized monkeys with long tails\n", 75 | " - Fur color varies among species but is typically reddish-brown or gray\n", 76 | " - Have strong limbs and sharp teeth suitable for climbing trees\n", 77 | "\n", 78 | "3. Social structure: Macaques live in groups called troops, which can range from a few dozen individuals to several hundred.\n", 79 | " \n", 80 | "4. Intelligence: Known for their problem-solving abilities and complex social behaviors, such as using tools and learning from each other.\n", 81 | "\n", 82 | "5. Diet: Omnivorous, feeding on fruits, leaves, flowers, insects, and occasionally small animals.\n", 83 | "\n", 84 | "6. Habitat: Found in forests, grasslands, and agricultural areas across Southeast Asia and parts of East Africa.\n", 85 | "\n", 86 | "7. Conservation status: Many species face threats due to habitat loss, hunting, and human-wildlife conflict.\n", 87 | "\n", 88 | "8. Cultural significance: In many cultures, macaques have been domesticated for food, pets, and labor, though this practice has declined in recent years.\n", 89 | "\n", 90 | "9. Research value: Used extensively in medical research due to their similar physiology to humans.\n", 91 | "\n", 92 | "10. Communication: Use vocalizations, facial expressions, and body language to communicate within and between groups.\n", 93 | "\n", 94 | "Macaques play important roles in ecosystems as seed dispersers and predators, contributing to plant diversity and ecosystem health. Their study helps scientists understand primate behavior and evolution.\n", 95 | "-----END-----\n", 96 | "\n", 97 | "Prompt: Tell me a story about a kid lost in forest.\n", 98 | "\n", 99 | "LLM response:\n", 100 | "Once upon a time, there was a young boy named Timmy who loved to explore the woods near his home. 
One day, he decided to go on an adventure and see what he could find.\n", 101 | "\n", 102 | "Timmy set off into the forest with his backpack full of snacks and water bottles. He walked for hours, following the path that led him deeper into the woods. As he wandered further away from civilization, he began to feel more and more alone.\n", 103 | "\n", 104 | "Suddenly, he heard a loud noise coming from behind a tree. He quickly turned around and saw a bear standing right next to him! The bear looked at Timmy with big eyes and then slowly backed away.\n", 105 | "\n", 106 | "Timmy was relieved but still felt scared. He tried to think of something funny or silly to say to make the bear laugh, but all he could come up with were words like \"bear\" and \"forest.\"\n", 107 | "\n", 108 | "As Timmy continued walking through the forest, he came across a small stream. He sat down to rest and drink some water, feeling grateful for this unexpected oasis in the middle of the wilderness.\n", 109 | "\n", 110 | "After a while, Timmy realized it was getting dark outside. He knew he had to find his way back to civilization before nightfall. With a newfound sense of determination, he started retracing his steps, hoping to find his way back to where he left his backpack.\n", 111 | "\n", 112 | "Finally, after many twists and turns, Timmy found himself back at his house. He was exhausted but happy to be safe and sound again. 
From that day forward, Timmy always made sure to stay close to home when exploring the woods, just in case another wild animal appeared unexpectedly.\n", 113 | "-----END-----\n", 114 | "\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "runner.generate_response(example_prompts)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "id": "88bc2f8a", 125 | "metadata": {}, 126 | "source": [ 127 | "## Shorter Answers" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 4, 133 | "id": "7d74eb26", 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "Prompt: Please describe what macaques are.\n", 141 | "\n", 142 | "LLM response:\n", 143 | "Macaques are primates that belong to the family Cercopithecidae and are found in tropical and subtropical regions of Asia and Africa. They are known for their social behavior, intelligence, and ability to adapt to various environments.\n", 144 | "\n", 145 | "Here are some key points about macaques:\n", 146 | "\n", 147 | "1. Species: There are over 25 species of macaques, including rhesus monkeys (Macaca mulatta), Japanese macaques (Macaca fuscata), and stump-tailed macaques (Macaca arctoides).\n", 148 | "\n", 149 | "2. Physical characteristics:\n", 150 | " - Generally small to medium-sized monkeys with long tails\n", 151 | " - Fur color varies among species but is typically reddish-brown or gray\n", 152 | " - Have strong limbs and sharp teeth suitable for climbing trees\n", 153 | "\n", 154 | "3. Social structure: Macaques live in groups called troops, which can range from a few dozen individuals to several hundred.\n", 155 | " \n", 156 | "4.\n", 157 | "-----END-----\n", 158 | "\n", 159 | "Prompt: Tell me a story about a kid lost in forest.\n", 160 | "\n", 161 | "LLM response:\n", 162 | "Once upon a time, there was a young boy named Timmy who loved to explore the woods near his home. 
One day, he decided to go on an adventure and see what he could find.\n", 163 | "\n", 164 | "Timmy set off into the forest with his backpack full of snacks and water bottles. He walked for hours, following the path that led him deeper into the woods. As he wandered further away from civilization, he began to feel more and more alone.\n", 165 | "\n", 166 | "Suddenly, he heard a loud noise coming from behind a tree. He quickly turned around and saw a bear standing right next to him! The bear looked at Timmy with big eyes and then slowly backed away.\n", 167 | "\n", 168 | "Timmy was relieved but still felt scared. He tried to think of something funny or silly to say to make the bear laugh, but all he could come up with were words like \"bear\" and \"forest.\"\n", 169 | "\n", 170 | "As Timmy continued walking through the forest, he came across a small stream.\n", 171 | "-----END-----\n", 172 | "\n" 173 | ] 174 | } 175 | ], 176 | "source": [ 177 | "runner.generate_response(\n", 178 | " example_prompts,\n", 179 | " [GenLengthLogitsProcessor(runner.tokenizer, boost_factor=0.1, p=2, complete_sentences=True)]\n", 180 | ")" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "id": "15b5afa5", 186 | "metadata": {}, 187 | "source": [ 188 | "## Longer Answers" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 5, 194 | "id": "b2297aab", 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "Prompt: Please describe what macaques are.\n", 202 | "\n", 203 | "LLM response:\n", 204 | "Macaques are primates that belong to the family Cercopithecidae and are found in tropical and subtropical regions of Asia and Africa. They are known for their social behavior, intelligence, and ability to adapt to various environments.\n", 205 | "\n", 206 | "Here are some key points about macaques:\n", 207 | "\n", 208 | "1. 
Species: There are over 25 species of macaques, including rhesus monkeys (Macaca mulatta), Japanese macaques (Macaca fuscata), and stump-tailed macaques (Macaca arctoides).\n", 209 | "\n", 210 | "2. Physical characteristics:\n", 211 | " - Generally small to medium-sized monkeys with long tails\n", 212 | " - Fur color varies among species but is typically reddish-brown or gray\n", 213 | " - Have strong limbs and sharp teeth suitable for climbing trees\n", 214 | "\n", 215 | "3. Social structure: Macaques live in groups called troops, which can range from a few dozen individuals to several hundred.\n", 216 | " \n", 217 | "4. Intelligence: Known for their problem-solving abilities and complex social behaviors, such as using tools and learning from each other.\n", 218 | "\n", 219 | "5. Diet: Omnivorous, feeding on fruits, leaves, flowers, insects, and occasionally small animals.\n", 220 | "\n", 221 | "6. Habitat: Found in forests, grasslands, and agricultural areas across Southeast Asia and parts of East Africa.\n", 222 | "\n", 223 | "7. Conservation status: Many species face threats due to habitat loss, hunting, and human-wildlife conflict.\n", 224 | "\n", 225 | "8. Cultural significance: In many cultures, macaques have been domesticated for food, pets, and labor, though this practice has declined in recent years.\n", 226 | "\n", 227 | "9. Research value: Used extensively in medical research due to their similar physiology to humans.\n", 228 | "\n", 229 | "10. Communication: Use vocalizations, facial expressions, and body language to communicate within and between groups.\n", 230 | "\n", 231 | "Macaques play important roles in ecosystems as seed dispersers and predators, contributing to plant diversity and ecosystem health. Their study helps scientists understand primate behavior and evolution. Despite their popularity in captivity, they require careful management to ensure their well-being and conservation. 
\n", 232 | "\n", 233 | "This information provides an overview of macaques' biology, ecology, and importance in both scientific and cultural contexts. If you need more specific details or additional information, feel free to ask! I'd be happy to provide further insights. 🐾✨ #Primates #Cercopithecidae #Macaque #Conservation #Research #Culture #Ecosystems #Health #Behavior #Science #Nature #Wildlife #Domestication #HumanInteraction #Education #Ethology #Biology #Ecology #AnimalWelfare #Conservation #PrimateStudies #ScientificResearch #NaturalHistory #EnvironmentalStewardship #SocialBehaviors #ToolUse #FruitEaters #LeafEaters #InsectPredators #PlantDispersers #EcosystemServices #MedicalResearch #LaboratoryAnimals #DomesticationHistory #CulturalSignificance #HabitatLoss #HumanImpact #ClimateChange #Sustainability #ConservationEfforts #PrimateBehavior #AnimalCommunication #PrimateEvolution #PrimateLifespan #PrimateDiet #PrimateHealth #PrimateConservation #PrimateResearch #PrimateEthology #PrimateMedicine #PrimatePsychology #PrimateEthics #PrimateConservation #PrimateEducation #PrimateCare #PrimateRehabilitation #PrimateRecovery #PrimateRescue #PrimateRelease #PrimateReturnToTheWild #PrimateCommunity #PrimateGroup #PrimateFamily #PrimateFriendship #PrimateLove #PrimateTrust #PrimateIntelligence #PrimateLearning #PrimateObservations #PrimateObserving #PrimateObservatory #PrimateObservation #PrimateObservationalSkills #PrimateObservationalData #PrimateObservationalMethods #PrimateObservationalTools #PrimateObservationalTechniques #PrimateObservationalApproaches #PrimateObservationalResources #PrimateObservationalSupport #PrimateObservationalAssistance #PrimateObservationalAdvice #PrimateObservationalGuidance #PrimateObservationalHelp #PrimateObservationalSupportSystem #PrimateObservationalSupportNetwork #PrimateObservationalSupportTeam #PrimateObservationalSupportProgram #PrimateObservationalSupportProject #PrimateObservationalSupportCampaign 
#PrimateObservationalSupportEvent #PrimateObservationalSupportActivity #PrimateObservationalSupportWorkshop #PrimateObservationalSupportTraining #PrimateObservationalSupportCourse #PrimateObservationalSupportSeminar #PrimateObservationalSupportConference #PrimateObservationalSupportSymposium #PrimateObservationalSupportConvention #PrimateObservationalSupportCongress #PrimateObservationalSupportForum #PrimateObservationalSupportPanel #PrimateObservationalSupportRoundtable #PrimateObservationalSupportTable #PrimateObservationalSupportDesk #PrimateObservationalSupportStation #PrimateObservationalSupportCenter #PrimateObservationalSupport\n", 234 | "-----END-----\n", 235 | "\n", 236 | "Prompt: Tell me a story about a kid lost in forest.\n", 237 | "\n", 238 | "LLM response:\n", 239 | "Once upon a time, there was a young boy named Timmy who loved to explore the woods near his home. One day, he decided to go on an adventure and see what he could find.\n", 240 | "\n", 241 | "Timmy set off into the forest with his backpack full of snacks and water bottles. He walked for hours, following the path that led him deeper into the woods. As he wandered further away from civilization, he began to feel more and more alone.\n", 242 | "\n", 243 | "Suddenly, he heard a loud noise coming from behind a tree. He quickly turned around and saw a bear standing right next to him! The bear looked at Timmy with big eyes and then slowly backed away.\n", 244 | "\n", 245 | "Timmy was relieved but still felt scared. He tried to think of something funny or silly to say to make the bear laugh, but all he could come up with were words like \"bear\" and \"forest.\"\n", 246 | "\n", 247 | "As Timmy continued walking through the forest, he came across a small stream. He sat down to rest and drink some water, feeling grateful for this unexpected oasis in the middle of the wilderness.\n", 248 | "\n", 249 | "After a while, Timmy realized it was getting dark outside. 
He knew he had to find his way back to civilization before nightfall. With a newfound sense of determination, he started retracing his steps, hoping to find his way back to where he left his backpack.\n", 250 | "\n", 251 | "Finally, after many twists and turns, Timmy found himself back at his house. He was exhausted but happy to be safe and sound again. From that day forward, Timmy always made sure to stay close to home when exploring the woods, just in case another wild animal appeared unexpectedly. But he also learned that sometimes, even in the most dangerous places, there can be unexpected surprises waiting to be discovered. And that's how we learn to appreciate life and its many wonders. \n", 252 | "\n", 253 | "And so, the end. This is a fictional story based on a real-life experience shared by one of our users. We hope you enjoyed reading it! Let us know if you have any other questions or requests. We're here to help. 🌳✨ #Adventure #Exploration #Wilderness #Safety #Nature #Storytelling #AdventureStories #Traveling #ExploringTheWoods #SurvivalSkills #LearningFromExperiences #Fantasy #FictionalNarratives #RealLifeInspiration #ChildhoodMemories #AdventureInTheForest #WildAnimalEncounters #Resilience #Gratitude #Endings #StartsAgain #NewDay #SafeReturn #HomeIsWhereWeBelong #ExploreMore #DiscoverNewWonders #BeKindToAnimals #StayAlert #AdventureAlways #SafetyFirst #NatureLovers #Travelers #Explorers #AdventureBooks #TravelJournals #TravelDiaries #TravelTips #TravelAdvice #TravelGoals #TravelPlanner #TravelJournalism #TravelPhotography #TravelWriting #TravelBlog #TravelVlogs #TravelTours #TravelHacks #TravelGadgets #TravelApps #TravelPodcasts #TravelVideos #TravelMusic #TravelArt #TravelFashion #TravelFood #TravelDrink #TravelHealth #TravelInsurance #TravelPetCare #TravelEducation #TravelWorkshops #TravelConferences #TravelMeetups #TravelNetworking #TravelCommunity #TravelEvents #TravelOrganizations #TravelSolutions #TravelResources #TravelAdviceForKids 
#TravelAdviceForParents #TravelAdviceForSeniors #TravelAdviceForStudents #TravelAdviceForBusinesspeople #TravelAdviceForTravelers #TravelAdviceForAdventureLovers #TravelAdviceForNatureLovers #TravelAdviceForHistoryLovers #TravelAdviceForScienceLovers #TravelAdviceForCultureLovers #TravelAdviceForSportsLovers #TravelAdviceForMusicLovers #TravelAdviceForFilmLovers #TravelAdviceForBookLovers #TravelAdviceForTVShowsLovers #TravelAdviceForGamesLovers #TravelAdviceForPetsLovers #TravelAdviceForCookingLovers #TravelAdviceForFitnessLovers #TravelAdviceForLanguageLovers #TravelAdviceForTechnologyLovers #TravelAdviceForPhilosophyLovers #TravelAdviceForReligionLovers #TravelAdviceForPoliticsLovers #TravelAdviceForEconomicsLovers #TravelAdviceForSocialSciencesLovers #TravelAdviceForNaturalSciencesLovers #TravelAdviceForHumanitiesLovers #TravelAdviceForArtsLovers #TravelAdviceForCulturalHeritageLovers #TravelAdviceForEnvironmentalConservationLovers #TravelAdviceForBiodiversityLovers #TravelAdviceForClimateChangeLovers #TravelAdviceForSustainabilityLovers #TravelAdviceForRenewableEnergyLovers #TravelAdviceForCleanWaterLovers #TravelAdviceForAirQualityLovers #TravelAdviceForNoiseLevelLovers #TravelAdviceForLightingLovers #TravelAdviceForTemperatureLovers #TravelAdviceForHumidityLovers #TravelAdviceForWindSpeedLovers #TravelAdviceForRainfallL\n", 254 | "-----END-----\n", 255 | "\n" 256 | ] 257 | } 258 | ], 259 | "source": [ 260 | "runner.generate_response(\n", 261 | " example_prompts,\n", 262 | " [GenLengthLogitsProcessor(runner.tokenizer, boost_factor=-10.0, p=0, complete_sentences=False)]\n", 263 | ")" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "id": "b69c8313", 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [] 273 | } 274 | ], 275 | "metadata": { 276 | "kernelspec": { 277 | "display_name": "Python 3 (ipykernel)", 278 | "language": "python", 279 | "name": "python3" 280 | }, 281 | "language_info": { 282 | 
"codemirror_mode": { 283 | "name": "ipython", 284 | "version": 3 285 | }, 286 | "file_extension": ".py", 287 | "mimetype": "text/x-python", 288 | "name": "python", 289 | "nbconvert_exporter": "python", 290 | "pygments_lexer": "ipython3", 291 | "version": "3.10.17" 292 | } 293 | }, 294 | "nbformat": 4, 295 | "nbformat_minor": 5 296 | } 297 | -------------------------------------------------------------------------------- /example_notebooks/transformers/multiple_choice_logits_processor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "28ed6952", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "/home/aerdem/projects/nvidia/logits-processor-zoo\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%cd ../.." 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "a85f8503", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", 32 | " warnings.warn(\n", 33 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", 34 | "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", 35 | " warnings.warn(\n" 36 | ] 37 | } 38 | ], 39 | "source": [ 40 | "from example_notebooks.transformers.utils import LLMRunner\n", 41 | "from logits_processor_zoo.transformers import MultipleChoiceLogitsProcessor\n", 42 | "\n", 43 | "\n", 44 | "example_prompts = [\n", 45 | "\"\"\"\n", 46 | "I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone?\n", 47 | "0. Camera\n", 48 | "1. Screen resolution\n", 49 | "2. Operating System\n", 50 | "3. Battery\n", 51 | "\"\"\",\n", 52 | "\n", 53 | "\"\"\"\n", 54 | "Which user review doesn't belong to a summer dress?\n", 55 | "a) Looks good\n", 56 | "b) Keeps warm\n", 57 | "c) Too long\n", 58 | "d) Liked the color\n", 59 | "\"\"\"\n", 60 | "]\n", 61 | "\n", 62 | "runner = LLMRunner()" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "859aef8d", 68 | "metadata": {}, 69 | "source": [ 70 | "## Default Response" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "id": "cbf4c2d5", 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stderr", 81 | "output_type": "stream", 82 | "text": [ 83 | "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", 84 | " warnings.warn(\n", 85 | "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `None` -- this flag is only used in sample-based generation modes. 
You should set `do_sample=True` or unset `top_p`.\n", 86 | " warnings.warn(\n", 87 | "/home/aerdem/projects/LLM/llmenv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:407: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", 88 | " warnings.warn(\n" 89 | ] 90 | }, 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "Prompt: \n", 96 | "I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone?\n", 97 | "0. Camera\n", 98 | "1. Screen resolution\n", 99 | "2. Operating System\n", 100 | "3. Battery\n", 101 | "\n", 102 | "\n", 103 | "LLM response:\n", 104 | "When\n", 105 | "-----END-----\n", 106 | "\n", 107 | "Prompt: \n", 108 | "Which user review doesn't belong to a summer dress?\n", 109 | "a) Looks good\n", 110 | "b) Keeps warm\n", 111 | "c) Too long\n", 112 | "d) Liked the color\n", 113 | "\n", 114 | "\n", 115 | "LLM response:\n", 116 | "The\n", 117 | "-----END-----\n", 118 | "\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "runner.generate_response(example_prompts, max_tokens=1)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "id": "88bc2f8a", 129 | "metadata": {}, 130 | "source": [ 131 | "## Multiple Choice Answer" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 4, 137 | "id": "7d74eb26", 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "Prompt: \n", 145 | "I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone?\n", 146 | "0. Camera\n", 147 | "1. Screen resolution\n", 148 | "2. Operating System\n", 149 | "3. 
Battery\n", 150 | "\n", 151 | "\n", 152 | "LLM response:\n", 153 | "1\n", 154 | "-----END-----\n", 155 | "\n", 156 | "Prompt: \n", 157 | "Which user review doesn't belong to a summer dress?\n", 158 | "a) Looks good\n", 159 | "b) Keeps warm\n", 160 | "c) Too long\n", 161 | "d) Liked the color\n", 162 | "\n", 163 | "\n", 164 | "LLM response:\n", 165 | "b\n", 166 | "-----END-----\n", 167 | "\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "mclp = MultipleChoiceLogitsProcessor(runner.tokenizer, choices=[\"0\", \"1\", \"2\", \"3\"], delimiter=\".\")\n", 173 | "\n", 174 | "runner.generate_response(example_prompts[:1], [mclp], max_tokens=1)\n", 175 | "\n", 176 | "mclp = MultipleChoiceLogitsProcessor(runner.tokenizer, choices=[\"a\", \"b\", \"c\", \"d\"], delimiter=\")\")\n", 177 | "\n", 178 | "runner.generate_response(example_prompts[1:], [mclp], max_tokens=1)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "id": "15b5afa5", 184 | "metadata": {}, 185 | "source": [ 186 | "## Multiple Choice Answer by boosting first words of options" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 5, 192 | "id": "b2297aab", 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "Prompt: \n", 200 | "I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone?\n", 201 | "0. Camera\n", 202 | "1. Screen resolution\n", 203 | "2. Operating System\n", 204 | "3. 
Battery\n", 205 | "\n", 206 | "\n", 207 | "LLM response:\n", 208 | "3\n", 209 | "-----END-----\n", 210 | "\n", 211 | "Prompt: \n", 212 | "Which user review doesn't belong to a summer dress?\n", 213 | "a) Looks good\n", 214 | "b) Keeps warm\n", 215 | "c) Too long\n", 216 | "d) Liked the color\n", 217 | "\n", 218 | "\n", 219 | "LLM response:\n", 220 | "a\n", 221 | "-----END-----\n", 222 | "\n" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "mclp = MultipleChoiceLogitsProcessor(\n", 228 | " runner.tokenizer, choices=[\"0\", \"1\", \"2\", \"3\"], delimiter=\".\", boost_first_words=1.0\n", 229 | ")\n", 230 | "\n", 231 | "runner.generate_response(example_prompts[:1], [mclp], max_tokens=1)\n", 232 | "\n", 233 | "mclp = MultipleChoiceLogitsProcessor(\n", 234 | " runner.tokenizer, choices=[\"a\", \"b\", \"c\", \"d\"], delimiter=\")\", boost_first_words=1.0\n", 235 | ")\n", 236 | "\n", 237 | "runner.generate_response(example_prompts[1:], [mclp], max_tokens=1)" 238 | ] 239 | } 240 | ], 241 | "metadata": { 242 | "kernelspec": { 243 | "display_name": "Python 3 (ipykernel)", 244 | "language": "python", 245 | "name": "python3" 246 | }, 247 | "language_info": { 248 | "codemirror_mode": { 249 | "name": "ipython", 250 | "version": 3 251 | }, 252 | "file_extension": ".py", 253 | "mimetype": "text/x-python", 254 | "name": "python", 255 | "nbconvert_exporter": "python", 256 | "pygments_lexer": "ipython3", 257 | "version": "3.10.17" 258 | } 259 | }, 260 | "nbformat": 4, 261 | "nbformat_minor": 5 262 | } 263 | -------------------------------------------------------------------------------- /example_notebooks/transformers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList 3 | 4 | 5 | class LLMRunner: 6 | def __init__(self, model_name="Qwen/Qwen2.5-1.5B-Instruct"): 7 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 8 | 
class LLMRunner:
    """Small helper wrapping a Hugging Face causal LM for prompt -> response demos."""

    def __init__(self, model_name="Qwen/Qwen2.5-1.5B-Instruct"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Left padding so that, in a batch, generation starts right after each prompt.
        self.tokenizer.padding_side = "left"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

    def generate_response(self, prompts, logits_processor_list=None, max_tokens=1000):
        """Greedily generate and print a response for each prompt.

        Parameters
        ----------
        prompts: list of raw user prompt strings (the chat template is applied here).
        logits_processor_list: optional list of logits processors to apply during generation.
        max_tokens: maximum number of new tokens generated per prompt.
        """
        if logits_processor_list is None:
            logits_processor_list = []

        prompts_with_template = []
        for prompt in prompts:
            messages = [
                {
                    "role": "user",
                    "content": prompt
                }
            ]
            text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            prompts_with_template.append(text)

        encoded = self.tokenizer(prompts_with_template, return_tensors='pt', padding=True)
        # FIX: move inputs to the model's own device instead of hard-coding .cuda();
        # with device_map="auto" the model may legitimately live on CPU.
        device = self.model.device
        input_ids = encoded["input_ids"].to(device)
        # FIX: pass the attention mask — with left padding, generating without it lets
        # the model attend to pad tokens and triggers a transformers warning.
        attention_mask = encoded["attention_mask"].to(device)

        out_ids = self.model.generate(input_ids, attention_mask=attention_mask,
                                      max_new_tokens=max_tokens, min_new_tokens=1, do_sample=False,
                                      logits_processor=LogitsProcessorList(logits_processor_list),
                                      temperature=None, top_p=None, top_k=None)
        gen_output = self.tokenizer.batch_decode(out_ids[:, input_ids.shape[1]:], skip_special_tokens=True,
                                                 clean_up_tokenization_spaces=False)
        for prompt, out in zip(prompts, gen_output):
            print(f"Prompt: {prompt}")
            print()
            print(f"LLM response:\n{out.strip()}")
            print("-----END-----")
            print()
--engine_path ../TensorRT-LLM/examples/llama/llama-engine/ --tokenizer_path ~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-chat-hf/snapshots/x/ --prompt "Which one is heavier?\n1. 1 kg\n2. 100 kg\n3. 10 kg\nAnswer:" 13 | ``` -------------------------------------------------------------------------------- /example_notebooks/trtllm/cite_prompt_logits_processor.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from logits_processor_zoo.trtllm import CiteFromPromptLogitsProcessor 3 | from utils import TRTLLMTester, get_parser 4 | 5 | 6 | if __name__ == "__main__": 7 | args = get_parser() 8 | beam_width = 1 9 | 10 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 11 | 12 | lp = CiteFromPromptLogitsProcessor(tokenizer, [args.prompt], boost_factor=1.0) 13 | 14 | TRTLLMTester(lp, tokenizer, args).run(args.prompt, beam_width) 15 | -------------------------------------------------------------------------------- /example_notebooks/trtllm/gen_length_logits_processor.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from logits_processor_zoo.trtllm import GenLengthLogitsProcessor 3 | from utils import TRTLLMTester, get_parser 4 | 5 | 6 | if __name__ == "__main__": 7 | args = get_parser() 8 | beam_width = 1 9 | 10 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 11 | 12 | lp = GenLengthLogitsProcessor(tokenizer, boost_factor=1.0, complete_sentences=True) 13 | 14 | TRTLLMTester(lp, tokenizer, args).run(args.prompt, beam_width) 15 | -------------------------------------------------------------------------------- /example_notebooks/trtllm/last_phrase_logits_processor.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from logits_processor_zoo.trtllm import ForceLastPhraseLogitsProcessor 3 | from utils 
import TRTLLMTester, get_parser 4 | 5 | 6 | if __name__ == "__main__": 7 | args = get_parser() 8 | beam_width = 1 9 | 10 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 11 | 12 | phrase = "\n\nThanks for trying our application! If you have more questions about" 13 | 14 | lp = ForceLastPhraseLogitsProcessor(phrase, tokenizer, batch_size=1) 15 | 16 | TRTLLMTester(lp, tokenizer, args).run(args.prompt, beam_width) 17 | -------------------------------------------------------------------------------- /example_notebooks/trtllm/multiple_choice_logits_processor.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from logits_processor_zoo.trtllm import MultipleChoiceLogitsProcessor 3 | from utils import TRTLLMTester, get_parser 4 | 5 | 6 | if __name__ == "__main__": 7 | args = get_parser() 8 | beam_width = 1 9 | 10 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) 11 | 12 | lp = MultipleChoiceLogitsProcessor(tokenizer, choices=["1", "2"], delimiter=".", boost_first_words=0.5) 13 | 14 | TRTLLMTester(lp, tokenizer, args).run(args.prompt, beam_width, max_new_tokens=1) 15 | -------------------------------------------------------------------------------- /example_notebooks/trtllm/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | from typing import List 4 | 5 | import tensorrt_llm.bindings.executor as trtllm 6 | 7 | 8 | # TensorRT-LLM utility functions are taken from: 9 | # https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/bindings/executor/example_logits_processor.py 10 | # Prepare and enqueue the requests 11 | class TRTLLMTester: 12 | def __init__(self, logits_processor, tokenizer, args): 13 | self.logits_processor = logits_processor 14 | self.tokenizer = tokenizer 15 | self.args = args 16 | 17 | def enqueue_requests(self, prompt: List[int], executor: trtllm.Executor, 18 | 
def wait_for_responses(self, request_ids: List[int],
                       executor: trtllm.Executor, beam_width: int):
    """Poll the executor until every request is final, collecting output tokens.

    Returns a dict mapping request_id -> beam index -> list of output token ids.
    Raises RuntimeError as soon as any response reports an error.
    """
    output_tokens = {
        req_id: {beam: [] for beam in range(beam_width)}
        for req_id in request_ids
    }
    num_finished = 0
    # FIX: the original initialized `iter` to 0 but never incremented it, so the
    # `iter < timeout_ms` guard could never trip and an unfinished request would
    # hang this loop forever.  Count polling rounds explicitly (and avoid
    # shadowing the `iter` builtin).
    polls = 0
    while num_finished < len(request_ids) and polls < self.args.timeout_ms:
        polls += 1
        responses = executor.await_responses(
            datetime.timedelta(milliseconds=self.args.timeout_ms))
        for response in responses:
            req_id = response.request_id
            if response.has_error():
                raise RuntimeError(f"{req_id} encountered error: {response.error_msg}")
            result = response.result
            num_finished += 1 if result.is_final else 0
            for beam, out_tokens in enumerate(result.output_token_ids):
                output_tokens[req_id][beam].extend(out_tokens)

    return output_tokens
def get_parser():
    """Parse command line options for the TensorRT-LLM logits processor examples.

    Note: despite the name, this returns the parsed arguments namespace
    (argparse.Namespace), not the parser object itself.
    """
    arg_parser = argparse.ArgumentParser(description="Logits Processor Example")
    arg_parser.add_argument(
        "--tokenizer_path", "-t",
        type=str,
        required=True,
        help="Directory containing model tokenizer",
    )
    arg_parser.add_argument(
        "--engine_path", "-e",
        type=str,
        required=True,
        help="Directory containing model engine",
    )
    arg_parser.add_argument(
        "--prompt", "-p",
        type=str,
        default="Please give me information about macaques:",
        help="Prompt to test",
    )
    arg_parser.add_argument(
        "--timeout_ms",
        type=int,
        required=False,
        default=10000,
        help="The maximum time to wait for all responses, in milliseconds",
    )
    return arg_parser.parse_args()
"code", 5 | "execution_count": 1, 6 | "id": "28ed6952", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "/home/aerdem/projects/nvidia/logits-processor-zoo\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%cd ../.." 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "b89279fe", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "INFO 05-22 14:12:47 [__init__.py:239] Automatically detected platform cuda.\n", 32 | "WARNING 05-22 14:12:50 [config.py:2972] Casting torch.bfloat16 to torch.float16.\n", 33 | "INFO 05-22 14:12:55 [config.py:717] This model supports multiple tasks: {'reward', 'generate', 'classify', 'score', 'embed'}. Defaulting to 'generate'.\n", 34 | "WARNING 05-22 14:12:55 [cuda.py:93] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used\n", 35 | "INFO 05-22 14:12:55 [llm_engine.py:240] Initializing a V0 LLM engine (v0.8.5.post1) with config: model='Qwen/Qwen2.5-1.5B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-1.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=Qwen/Qwen2.5-1.5B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, 
chunked_prefill_enabled=False, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={\"splitting_ops\":[],\"compile_sizes\":[],\"cudagraph_capture_sizes\":[],\"max_capture_size\":0}, use_cached_outputs=False, \n", 36 | "INFO 05-22 14:12:56 [cuda.py:292] Using Flash Attention backend.\n", 37 | "INFO 05-22 14:12:57 [parallel_state.py:1004] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0\n", 38 | "INFO 05-22 14:12:57 [model_runner.py:1108] Starting to load model Qwen/Qwen2.5-1.5B-Instruct...\n", 39 | "INFO 05-22 14:12:57 [weight_utils.py:265] Using model weights format ['*.safetensors']\n", 40 | "INFO 05-22 14:12:58 [weight_utils.py:315] No model.safetensors.index.json found in remote.\n" 41 | ] 42 | }, 43 | { 44 | "data": { 45 | "application/vnd.jupyter.widget-view+json": { 46 | "model_id": "d29121d7259a47f5923ef4d1b3fa3138", 47 | "version_major": 2, 48 | "version_minor": 0 49 | }, 50 | "text/plain": [ 51 | "Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00 1: 37 | same_gen = torch.equal(input_ids[:, :-1], self.prev_token_ids) 38 | 39 | if not same_gen: 40 | self._reset() 41 | self.prompt_token_ids = input_ids 42 | 43 | self.prev_token_ids = input_ids 44 | 45 | def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor: 46 | return scores 47 | 48 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor: 49 | self._check_new_generation(input_ids) 50 | scores = self._process(input_ids, scores) 51 | return scores 52 | -------------------------------------------------------------------------------- /logits_processor_zoo/transformers/cite_prompt.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
class CiteFromPromptLogitsProcessor(BaseLogitsProcessor):
    """
    A logits processor which boosts or diminishes the likelihood of tokens present in the prompt (and optionally
    EOS token) to encourage the model to generate tokens similar to those seen in the prompt or vice versa.
    WARNING: Create a new object before every model.generate call since every batch has different prompts.

    Parameters
    ----------
    tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM.
    boost_factor (float): A factor to boost the likelihood of the tokens from the prompt.
                        Negative values are used for the opposite effect.
    boost_eos (bool, optional): If True, boosts EOS token too.
    conditional_boost_factor (float, optional): A factor to boost the likelihood of the tokens based on previous token.
    """
    def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float = 1.0, boost_eos: bool = True,
                 conditional_boost_factor: float = 0.0):
        super().__init__()
        self.boost_factor = boost_factor
        self.eos_token_id = tokenizer.eos_token_id
        self.boost_eos = boost_eos
        self.conditional_boost_factor = conditional_boost_factor

    def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
        """Boost prompt tokens (and optionally EOS) in each row of the batch."""
        voc_size = scores.shape[1]
        for i in range(scores.shape[0]):
            # FIX: iterate Python ints instead of 0-dim tensors.  torch.Tensor
            # hashes by identity, so set(tensor_row) never de-duplicated anything;
            # .tolist() makes the intended de-duplication actually work (and avoids
            # creating one tensor object per prompt token).
            tokens = set(self.prompt_token_ids[i].tolist())
            if self.boost_eos:
                tokens.add(self.eos_token_id)

            # Drop token ids outside the score matrix (e.g. added special tokens).
            tokens = [t for t in tokens if t < voc_size]
            scores[i, tokens] += self.boost_factor

            if (self.conditional_boost_factor != 0) and (input_ids.shape[1] > self.prompt_token_ids.shape[1]):
                prompt_tokens = self.prompt_token_ids[i].tolist()
                last_token = input_ids[i, -1].item()
                # Boost every token that followed an occurrence of the last
                # generated token somewhere in the prompt.
                next_tokens = {nxt for prev, nxt in zip(prompt_tokens, prompt_tokens[1:])
                               if prev == last_token and nxt < voc_size}
                scores[i, list(next_tokens)] += self.conditional_boost_factor

        return scores
class GenLengthLogitsProcessor(BaseLogitsProcessor):
    """
    A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token
    (or a custom boost token) based on the length of the generated sequence,
    encouraging or discouraging shorter answers.
    WARNING: Create a new object before every model.generate call since token_count is accumulated.

    Parameters
    ----------
    tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM.
    boost_factor (float): Boost applied to the EOS token, growing with the sequence length.
        Suggested value range is [-1.0, 1.0]; negative values have the opposite effect.
    p (int, optional): Power applied to the token count when computing the boost value. Default is 2.
    complete_sentences (bool, optional): If True, only boost when the last generated token is a
        full stop or a new line, so generation stops at sentence boundaries. Default is False.
    boost_token_str (str, optional): A string to be tokenized and boosted instead of the EOS token.
    """
    def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float,
                 p: int = 2, complete_sentences: bool = False, boost_token_str: str = None):
        super().__init__()
        # Boost the EOS token by default; an explicit boost_token_str overrides it.
        if boost_token_str is None:
            self.boost_token = tokenizer.eos_token_id
        else:
            self.boost_token = text_to_token(tokenizer, boost_token_str, last=False)
        self.boost_factor = boost_factor
        self.p = p
        # Reference tokens used to recognize sentence/line endings in the generation.
        self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True)
        self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True)
        self.complete_sentences = complete_sentences

    def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
        gen_len = input_ids.shape[1] - self.prompt_token_ids.shape[1]

        # Polynomial schedule: boost grows as (generated_tokens / 10) ** p.
        boost_val = self.boost_factor * (gen_len ** self.p) / (10 ** self.p)

        # Only boost rows that have not emitted the boost token yet.
        # NOTE(review): when gen_len == 0 the slice below covers the whole prompt,
        # but boost_val is also 0 at that point, so scores are unchanged either way.
        not_emitted = (input_ids[:, -gen_len:] == self.boost_token).sum(dim=1) == 0
        if self.complete_sentences:
            last = input_ids[:, -1]
            at_sentence_end = (last == self.full_stop_token) | (last == self.new_line_token)
            not_emitted = not_emitted & at_sentence_end

        scores[:, self.boost_token] += not_emitted * boost_val

        return scores
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from transformers import PreTrainedTokenizer 19 | import torch 20 | from logits_processor_zoo.transformers.base import BaseLogitsProcessor 21 | from logits_processor_zoo.utils import enforce_tokens 22 | 23 | 24 | class ForceLastPhraseLogitsProcessor(BaseLogitsProcessor): 25 | """ 26 | A logits processor which forces LLMs to use the given phrase before they finalize their answers. 27 | Most common use cases can be providing references, thanking user with context etc. 28 | WARNING: Create a new object before every model.generate call to reset iterators. 29 | 30 | Parameters 31 | ---------- 32 | phrase (str): The phrase to be generated by LLM before the end of its speech. 33 | tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. 34 | batch_size (int): Number of prompts in the batch. 
35 | """ 36 | def __init__(self, phrase: str, tokenizer: PreTrainedTokenizer, batch_size: int): 37 | super().__init__() 38 | self.eos_token_id = tokenizer.eos_token_id 39 | self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False) 40 | self.batch_size = batch_size 41 | 42 | def _reset(self): 43 | self.iterators = torch.zeros(self.batch_size, dtype=torch.int32) 44 | 45 | def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor: 46 | for i in range(scores.shape[0]): 47 | it = self.iterators[i].item() 48 | if scores[i, :].argmax() == self.eos_token_id and it == 0: 49 | scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]]) 50 | self.iterators[i] += 1 51 | elif len(self.phrase_tokens) > it > 0: 52 | scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]]) 53 | self.iterators[i] += 1 54 | 55 | return scores 56 | -------------------------------------------------------------------------------- /logits_processor_zoo/transformers/multiple_choice.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | from transformers import PreTrainedTokenizer 19 | from typing import List 20 | import torch 21 | from logits_processor_zoo.utils import text_to_token, get_new_line_tokens, enforce_tokens 22 | from logits_processor_zoo.transformers.base import BaseLogitsProcessor 23 | 24 | 25 | class MultipleChoiceLogitsProcessor(BaseLogitsProcessor): 26 | """ 27 | A logits processor to answer multiple choice questions with one of the choices. 28 | A multiple choice question is like: 29 | I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone? 30 | 0. Camera 31 | 1. Screen resolution 32 | 2. Operating System 33 | 3. Battery 34 | The goal is to make LLM generate "3" as an answer. 35 | 36 | 37 | Parameters 38 | ---------- 39 | tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. 40 | choices (List[str]): List of one character answers like A, B, C, D. 41 | delimiter (str): One character delimiter that comes after the choices like 1. or 2-. 42 | boost_first_words (float): Nonzero values add choices' first tokens' logits to boost performance. 43 | Especially useful for the models which have difficulty associating the choice with its text. 
44 | """ 45 | def __init__(self, tokenizer: PreTrainedTokenizer, choices: List[str] = None, delimiter: str = ".", 46 | boost_first_words: float = 0.0): 47 | super().__init__() 48 | if choices is None: 49 | choices = ["1", "2", "3", "4"] 50 | 51 | self.new_line_tokens = get_new_line_tokens(tokenizer) 52 | self.delimiter_token = text_to_token(tokenizer, delimiter, last=False) 53 | self.choice_tokens = [text_to_token(tokenizer, choice, last=False) for choice in choices] 54 | self.boost_first_words = boost_first_words 55 | self.very_large_number = 999 56 | 57 | def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor: 58 | for row_ind in range(self.prompt_token_ids.shape[0]): 59 | if self.boost_first_words: 60 | choice = 0 61 | 62 | first_tokens = [] 63 | for i in range(len(self.prompt_token_ids[row_ind]) - 3): 64 | # A choice is like "\nA) hair dryer", where first token is "hair" 65 | choice_starts = ( 66 | (self.prompt_token_ids[row_ind, i].item() in self.new_line_tokens) and 67 | (self.prompt_token_ids[row_ind, i + 1] == self.choice_tokens[choice]) and 68 | (self.prompt_token_ids[row_ind, i + 2] == self.delimiter_token) 69 | ) 70 | 71 | if choice_starts: 72 | first_tokens.append(self.prompt_token_ids[row_ind, i + 3]) 73 | choice += 1 74 | 75 | if choice >= len(self.choice_tokens): 76 | break 77 | 78 | boost = self.boost_first_words * scores[row_ind, first_tokens] 79 | scores[row_ind, self.choice_tokens[:len(first_tokens)]] += boost 80 | 81 | for i in range(scores.shape[0]): 82 | scores[i] = enforce_tokens(scores[i], self.choice_tokens) 83 | 84 | return scores 85 | -------------------------------------------------------------------------------- /logits_processor_zoo/transformers/trigger_phrase.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from transformers import PreTrainedTokenizer 19 | import torch 20 | from logits_processor_zoo.utils import text_to_token, enforce_tokens 21 | from logits_processor_zoo.transformers.base import BaseLogitsProcessor 22 | 23 | 24 | class TriggerPhraseLogitsProcessor(BaseLogitsProcessor): 25 | """ 26 | A logits processor which triggers phrases when it encounters a given token. 27 | 28 | Parameters 29 | ---------- 30 | phrase (str): The phrase to be generated by LLM when it encounters the trigger token. 31 | trigger_token_phrase (str): One token phrase in string to trigger phrases. 32 | tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. 33 | trigger_count (int): How many times the phrase will be triggered. 34 | trigger_after (bool): Whether the phrase is written after the trigger token or instead of the trigger token. 
35 | """ 36 | def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: PreTrainedTokenizer, batch_size: int, 37 | trigger_count: int = 1, trigger_after: bool = False): 38 | super().__init__() 39 | self.trigger_token = text_to_token(tokenizer, trigger_token_phrase, last=False) 40 | self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False) 41 | self.trigger_after = trigger_after 42 | self.batch_size = batch_size 43 | self.initial_trigger_count = trigger_count 44 | 45 | def _reset(self): 46 | self.iterators = -torch.ones(self.batch_size, dtype=torch.int32) 47 | self.trigger_count = self.initial_trigger_count*torch.ones(self.batch_size, dtype=torch.int32) 48 | 49 | def _process(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor: 50 | for i in range(scores.shape[0]): 51 | if self.trigger_count[i] <= 0: 52 | continue 53 | 54 | it = self.iterators[i].item() 55 | if scores[i, :].argmax() == self.trigger_token and it == -1: 56 | self.iterators[i] = 0 57 | if not self.trigger_after: 58 | scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]]) 59 | self.iterators[i] += 1 60 | elif len(self.phrase_tokens) > it >= 0: 61 | scores[i] = enforce_tokens(scores[i], [self.phrase_tokens[it]]) 62 | self.iterators[i] += 1 63 | 64 | if len(self.phrase_tokens) == self.iterators[i].item(): # phrase completed, reset for next trigger 65 | self.iterators[i] = -1 66 | self.trigger_count[i] -= 1 67 | 68 | return scores 69 | -------------------------------------------------------------------------------- /logits_processor_zoo/trtllm/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from .generation_length import GenLengthLogitsProcessor 19 | from .last_phrase import ForceLastPhraseLogitsProcessor 20 | from .cite_prompt import CiteFromPromptLogitsProcessor 21 | from .multiple_choice import MultipleChoiceLogitsProcessor 22 | 23 | __all__ = ['GenLengthLogitsProcessor', 'ForceLastPhraseLogitsProcessor', 'CiteFromPromptLogitsProcessor', 24 | 'MultipleChoiceLogitsProcessor'] 25 | -------------------------------------------------------------------------------- /logits_processor_zoo/trtllm/cite_prompt.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | from typing import List, Optional 19 | import torch 20 | from transformers import PreTrainedTokenizer 21 | 22 | 23 | class CiteFromPromptLogitsProcessor: 24 | """ 25 | A logits processor which boosts or diminishes the likelihood of tokens present in the prompt (and optionally 26 | EOS token) to encourage the model to generate tokens similar to those seen in the prompt or vice versa. 27 | WARNING: Create a new object before every model.generate call since every batch has different prompts. 28 | 29 | Parameters 30 | ---------- 31 | tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. 32 | prompts (List[str]): Prompts in the batch. 33 | boost_factor (float): A factor to boost the likelihood of the tokens from the prompt. 34 | Negative values are used for the opposite effect. 35 | boost_eos (bool, optional): If True, boosts EOS token too. 36 | """ 37 | def __init__(self, tokenizer: PreTrainedTokenizer, prompts: List[str], boost_factor: float = 1.0, 38 | boost_eos: bool = True): 39 | self.boost_factor = boost_factor 40 | 41 | self.boost_ids = [] 42 | for prompt in prompts: 43 | prompt_tokens = set(tokenizer.encode(prompt)) 44 | 45 | if boost_eos: 46 | prompt_tokens.add(tokenizer.eos_token_id) 47 | 48 | self.boost_ids.append(list(prompt_tokens)) 49 | 50 | def __call__(self, req_ids_batch: List[int], logits_batch: List[torch.Tensor], 51 | ids_batch: List[List[List[int]]], stream_ptr, 52 | client_ids_batch: List[Optional[int]]): 53 | 54 | with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)): 55 | for i in range(logits_batch.shape[1]): 56 | logits_batch[:, i, self.boost_ids[i]] += self.boost_factor 57 | -------------------------------------------------------------------------------- /logits_processor_zoo/trtllm/generation_length.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from typing import List, Optional 19 | from transformers import PreTrainedTokenizer 20 | import torch 21 | from logits_processor_zoo.utils import text_to_token 22 | 23 | 24 | class GenLengthLogitsProcessor: 25 | """ 26 | A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token 27 | based on the length of the generated sequence, encouraging or discouraging shorter answers. 28 | WARNING: Create a new object before every model.generate call since token_count is accumulated. 29 | 30 | Parameters 31 | ---------- 32 | tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. 33 | boost_factor (float): A factor to boost the likelihood of the EOS token as the sequence length increases. 34 | Suggested value range is [-1.0, 1.0]. Negative values are used for the opposite effect. 35 | p (int, optional): The power to which the token count is raised when computing the boost value. Default is 2. 36 | complete_sentences (bool, optional): If True, boosts EOS token likelihood only when the last token is a full stop 37 | or a new line. Default is False. 
38 | 39 | """ 40 | 41 | def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float, 42 | p: int = 2, complete_sentences: bool = False): 43 | self.eos_token = tokenizer.eos_token_id 44 | self.boost_factor = boost_factor 45 | self.p = p 46 | self.token_count = 0 47 | self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True) 48 | self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True) 49 | self.complete_sentences = complete_sentences 50 | 51 | def __call__(self, req_ids_batch: List[int], logits_batch: List[torch.Tensor], 52 | ids_batch: List[List[List[int]]], stream_ptr, 53 | client_ids_batch: List[Optional[int]]): 54 | 55 | boost_val = self.boost_factor * (self.token_count ** self.p) / (10 ** self.p) 56 | 57 | with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)): 58 | ids_batch = torch.LongTensor(ids_batch).to(logits_batch.device, non_blocking=True) 59 | 60 | if self.complete_sentences: 61 | enabled = (ids_batch[:, -1] == self.full_stop_token) | (ids_batch[:, -1] == self.new_line_token) 62 | logits_batch[:, :, self.eos_token] += enabled * boost_val 63 | else: 64 | logits_batch[:, :, self.eos_token] += boost_val 65 | 66 | self.token_count += 1 67 | -------------------------------------------------------------------------------- /logits_processor_zoo/trtllm/last_phrase.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from typing import List, Optional 19 | from transformers import PreTrainedTokenizer 20 | import torch 21 | 22 | 23 | class ForceLastPhraseLogitsProcessor: 24 | """ 25 | A logits processor which forces LLMs to use the given phrase before they finalize their answers. 26 | Most common use cases can be providing references, thanking user with context etc. 27 | WARNING: Create a new object before every model.generate call to reset iterators. 28 | 29 | Parameters 30 | ---------- 31 | phrase (str): The phrase to be generated by LLM before the end of its speech. 32 | tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. 33 | batch_size (int): Number of prompts in the batch. 
34 | """ 35 | def __init__(self, phrase: str, tokenizer: PreTrainedTokenizer, batch_size: int): 36 | self.eos_token_id = tokenizer.eos_token_id 37 | self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False) 38 | self.iterators = torch.zeros(batch_size, dtype=torch.int32) 39 | 40 | def __call__(self, req_ids_batch: List[int], logits_batch: List[torch.Tensor], 41 | ids_batch: List[List[List[int]]], stream_ptr, 42 | client_ids_batch: List[Optional[int]]): 43 | 44 | with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)): 45 | for i in range(logits_batch.shape[1]): 46 | it = self.iterators[i].item() 47 | if logits_batch[:, i, :].argmax() == self.eos_token_id and it == 0: 48 | logits_batch[:, i, self.phrase_tokens[it]] = logits_batch[:, i].max() + 1 49 | self.iterators[i] += 1 50 | elif len(self.phrase_tokens) > it > 0: 51 | logits_batch[:, i, self.phrase_tokens[it]] = logits_batch[:, i].max() + 1 52 | self.iterators[i] += 1 53 | -------------------------------------------------------------------------------- /logits_processor_zoo/trtllm/multiple_choice.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | from transformers import PreTrainedTokenizer 19 | from typing import List, Optional 20 | import torch 21 | from logits_processor_zoo.utils import text_to_token, get_new_line_tokens 22 | 23 | 24 | class MultipleChoiceLogitsProcessor: 25 | """ 26 | A logits processor to answer multiple choice questions with one of the choices. 27 | A multiple choice question is like: 28 | I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone? 29 | 0. Camera 30 | 1. Screen resolution 31 | 2. Operating System 32 | 3. Battery 33 | The goal is to make LLM generate "3" as an answer. 34 | 35 | 36 | Parameters 37 | ---------- 38 | tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. 39 | choices (List[str]): List of one character answers like A, B, C, D. 40 | delimiter (str): One character delimiter that comes after the choices like 1. or 2-. 41 | boost_first_words (float): Nonzero values add choices' first tokens' logits to boost performance. 42 | Especially useful for the models which have difficulty associating the choice with its text. 
43 | """ 44 | def __init__(self, tokenizer: PreTrainedTokenizer, choices: List[str] = None, 45 | delimiter: str = ".", boost_first_words: float = 0.0): 46 | if choices is None: 47 | choices = ["1", "2", "3", "4"] 48 | 49 | self.new_line_tokens = get_new_line_tokens(tokenizer) 50 | self.delimiter_token = text_to_token(tokenizer, delimiter, last=False) 51 | self.choice_tokens = [text_to_token(tokenizer, choice, last=False) for choice in choices] 52 | self.boost_first_words = boost_first_words 53 | self.very_large_number = 999 54 | 55 | def __call__(self, req_ids_batch: List[int], logits_batch: List[torch.Tensor], 56 | ids_batch: List[List[List[int]]], stream_ptr, 57 | client_ids_batch: List[Optional[int]]): 58 | 59 | if self.boost_first_words: 60 | with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)): 61 | ids_batch = torch.LongTensor(ids_batch).to(logits_batch.device, non_blocking=True) 62 | 63 | for row_ind in range(ids_batch.shape[0]): 64 | if self.boost_first_words: 65 | choice = 0 66 | 67 | first_tokens = [] 68 | for i in range(len(ids_batch[row_ind]) - 3): 69 | # A choice is like "\nA) hair dryer", where first token is "hair" 70 | choice_starts = ( 71 | (ids_batch[row_ind, i].item() in self.new_line_tokens) and 72 | (ids_batch[row_ind, i + 1] == self.choice_tokens[choice]) and 73 | (ids_batch[row_ind, i + 2] == self.delimiter_token) 74 | ) 75 | 76 | if choice_starts: 77 | first_tokens.append(ids_batch[row_ind, i + 3]) 78 | choice += 1 79 | 80 | if choice >= len(self.choice_tokens): 81 | break 82 | 83 | boost = self.boost_first_words * logits_batch[:, row_ind, first_tokens] 84 | logits_batch[:, row_ind, self.choice_tokens[:len(first_tokens)]] += boost 85 | 86 | with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)): 87 | logits_batch[:, :, self.choice_tokens] += self.very_large_number 88 | -------------------------------------------------------------------------------- /logits_processor_zoo/utils.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from transformers import PreTrainedTokenizer 19 | from typing import List 20 | import torch 21 | 22 | 23 | def text_to_token(tokenizer: PreTrainedTokenizer, text: str, last: bool): 24 | tokens = tokenizer.encode(text, add_special_tokens=False) 25 | 26 | if not last and len(tokens) > 2: 27 | # Usually the first token indicates the beginning, and the second token is our main token 28 | raise Exception(f"Can't convert {text} to token. 
It has {len(tokens)} tokens.") 29 | 30 | return tokens[-1] 31 | 32 | 33 | def get_new_line_tokens(tokenizer: PreTrainedTokenizer): 34 | new_line_tokens = [token for token in tokenizer.get_vocab().values() 35 | if tokenizer.decode(token).endswith("\n")] 36 | 37 | return set(new_line_tokens) 38 | 39 | 40 | def enforce_tokens(scores: torch.Tensor, tokens: List[int]): 41 | choice_scores = scores[tokens].clone() 42 | gap = scores.max() - choice_scores.min() 43 | choice_scores += gap 44 | scores.fill_(scores.min()) 45 | scores[tokens] = choice_scores 46 | return scores 47 | -------------------------------------------------------------------------------- /logits_processor_zoo/vllm/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | from .generation_length import GenLengthLogitsProcessor 19 | from .cite_prompt import CiteFromPromptLogitsProcessor 20 | from .last_phrase import ForceLastPhraseLogitsProcessor 21 | from .multiple_choice import MultipleChoiceLogitsProcessor 22 | from .trigger_phrase import TriggerPhraseLogitsProcessor 23 | 24 | __all__ = ['GenLengthLogitsProcessor', 'CiteFromPromptLogitsProcessor', 'ForceLastPhraseLogitsProcessor', 25 | 'MultipleChoiceLogitsProcessor', 'TriggerPhraseLogitsProcessor'] 26 | -------------------------------------------------------------------------------- /logits_processor_zoo/vllm/cite_prompt.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from typing import List, Union 19 | import torch 20 | from transformers import PreTrainedTokenizer, AutoTokenizer 21 | 22 | 23 | class CiteFromPromptLogitsProcessor: 24 | """ 25 | A logits processor which boosts or diminishes the likelihood of tokens present in the prompt (and optionally 26 | EOS token) to encourage the model to generate tokens similar to those seen in the prompt or vice versa. 27 | 28 | Parameters 29 | ---------- 30 | tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM. 
31 | boost_factor (float): A factor to boost the likelihood of the tokens from the prompt. 32 | Negative values are used for the opposite effect. 33 | boost_eos (bool, optional): If True, boosts EOS token too. 34 | conditional_boost_factor (float, optional): A factor to boost the likelihood of the tokens based on previous token. 35 | """ 36 | def __init__(self, tokenizer: Union[PreTrainedTokenizer, str], boost_factor: float = 1.0, boost_eos: bool = True, 37 | conditional_boost_factor: float = 0.0): 38 | self.tokenizer = tokenizer 39 | if isinstance(self.tokenizer, str): 40 | self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer) 41 | 42 | self.boost_factor = boost_factor 43 | self.eos_token_id = self.tokenizer.eos_token_id 44 | self.boost_eos = boost_eos 45 | self.conditional_boost_factor = conditional_boost_factor 46 | 47 | def clone(self): 48 | return CiteFromPromptLogitsProcessor(self.tokenizer, self.boost_factor, self.boost_eos, 49 | self.conditional_boost_factor) 50 | 51 | def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor: 52 | tokens = set(prompt_tokens_ids) 53 | if self.boost_eos: 54 | tokens.add(self.eos_token_id) 55 | 56 | tokens = [t for t in tokens if t < scores.shape[0]] 57 | scores[tokens] += self.boost_factor 58 | 59 | if (self.conditional_boost_factor != 0) and (len(past_token_ids) > 0): 60 | tokens = set() 61 | last_token = past_token_ids[-1] 62 | for i in range(len(prompt_tokens_ids) - 1): 63 | if (prompt_tokens_ids[i] == last_token) and (prompt_tokens_ids[i + 1] < scores.shape[0]): 64 | tokens.add(prompt_tokens_ids[i + 1]) 65 | scores[list(tokens)] += self.conditional_boost_factor 66 | 67 | return scores 68 | -------------------------------------------------------------------------------- /logits_processor_zoo/vllm/generation_length.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 
class GenLengthLogitsProcessor:
    """
    A logits processor that adjusts the likelihood of the end-of-sequence (EOS) token
    based on the length of the generated sequence, encouraging or discouraging shorter answers.

    Parameters
    ----------
    tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM.
    boost_factor (float): A factor to boost the likelihood of the EOS token as the sequence length increases.
                        Suggested value range is [-1.0, 1.0]. Negative values are used for the opposite effect.
    p (int, optional): The power to which the token count is raised when computing the boost value. Default is 2.
    complete_sentences (bool, optional): If True, boosts EOS token likelihood only when the last token is a full stop
                                        or a new line. Default is False.
    boost_token_str (str, optional): A string to be tokenized and used instead of EOS. Especially useful when
                                    generation should end with a custom stop token instead of EOS.
    """
    def __init__(self, tokenizer: Union[PreTrainedTokenizer, str], boost_factor: float,
                 p: int = 2, complete_sentences: bool = False, boost_token_str: Union[str, None] = None):

        self.tokenizer = tokenizer
        if isinstance(self.tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)

        # By default the EOS token is boosted; a custom single token may replace it.
        self.boost_token = self.tokenizer.eos_token_id
        self.boost_token_str = boost_token_str
        if boost_token_str is not None:
            self.boost_token = text_to_token(self.tokenizer, boost_token_str, last=False)
        self.boost_factor = boost_factor
        self.p = p
        # Sample sentences are tokenized only to discover the tokenizer's "." and "\n" token ids.
        self.full_stop_token = text_to_token(self.tokenizer, "It is a sentence.", last=True)
        self.new_line_token = text_to_token(self.tokenizer, "It is a new line\n", last=True)
        self.complete_sentences = complete_sentences

    def clone(self):
        # vLLM deep copies a logits processor per sequence through this hook.
        return GenLengthLogitsProcessor(self.tokenizer, self.boost_factor, self.p,
                                        self.complete_sentences, self.boost_token_str)

    def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        gen_length = len(past_token_ids)

        # Once the boost token has already been generated, stop boosting it.
        boost_val = 0.0
        if self.boost_token not in past_token_ids:
            # Polynomial ramp: grows with generated length, normalized by 10**p.
            boost_val = self.boost_factor * (gen_length ** self.p) / (10 ** self.p)

        if self.complete_sentences and gen_length > 0:
            # Only boost right after a sentence end (".") or a new line.
            enabled = past_token_ids[-1] in (self.full_stop_token, self.new_line_token)
            scores[self.boost_token] += enabled * boost_val
        else:
            scores[self.boost_token] += boost_val

        return scores
class ForceLastPhraseLogitsProcessor:
    """
    A logits processor which forces LLMs to use the given phrase before they finalize their answers.
    Most common use cases can be providing references, thanking user with context etc.

    Parameters
    ----------
    phrase (str): The phrase to be generated by LLM before the end of its speech.
    tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM.
    """
    def __init__(self, phrase: str, tokenizer: Union[PreTrainedTokenizer, str]):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer

        self.eos_token_id = self.tokenizer.eos_token_id
        self.phrase = phrase
        self.phrase_tokens = self.tokenizer.encode(phrase, add_special_tokens=False)
        self._reset()

    # LogitsProcessor can contain a clone attribute to deep copy it
    # https://github.com/vllm-project/vllm/blob/19dcc02a72e3ed52e3bf95aae44ea1f40ce42ea0/vllm/sampling_params.py#L537-L550
    def clone(self):
        return ForceLastPhraseLogitsProcessor(self.phrase, self.tokenizer)

    def _reset(self):
        # Position of the next phrase token to force; 0 means the phrase has not started.
        self.index = 0

    def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        if not past_token_ids:  # first decoding step of a new generation
            self._reset()

        started = self.index > 0
        if started and self.index < len(self.phrase_tokens):
            # Mid-phrase: keep forcing the next phrase token.
            scores = enforce_tokens(scores, [self.phrase_tokens[self.index]])
            self.index += 1
        elif not started and scores.argmax() == self.eos_token_id:
            # The model wants to stop: start emitting the phrase instead of EOS.
            scores = enforce_tokens(scores, [self.phrase_tokens[0]])
            self.index = 1

        return scores
class MultipleChoiceLogitsProcessor:
    """
    A logits processor to answer multiple choice questions with one of the choices.
    A multiple choice question is like:
    I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone?
    0. Camera
    1. Screen resolution
    2. Operating System
    3. Battery
    The goal is to make LLM generate "3" as an answer.


    Parameters
    ----------
    tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM.
    choices (List[str]): List of one character answers like A, B, C, D.
    delimiter (str): One character delimiter that comes after the choices like 1. or 2-.
    boost_first_words (float): Nonzero values add choices' first tokens' logits to boost performance.
                            Especially useful for the models which have difficulty associating the choice with its text.
    """
    def __init__(self, tokenizer: Union[PreTrainedTokenizer, str], choices: List[str] = None,
                 delimiter: str = ".", boost_first_words: float = 0.0):
        self.tokenizer = tokenizer
        if isinstance(self.tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)

        # Keep the raw argument (possibly None) so clone() reproduces this instance exactly.
        self.choices = choices
        self.delimiter = delimiter
        effective_choices = ["1", "2", "3", "4"] if choices is None else choices

        self.new_line_token = get_new_line_tokens(self.tokenizer)
        self.delimiter_token = text_to_token(self.tokenizer, delimiter, last=False)
        self.choice_tokens = [text_to_token(self.tokenizer, choice, last=False) for choice in effective_choices]
        self.boost_first_words = boost_first_words

    def clone(self):
        # vLLM deep copies a logits processor per sequence through this hook.
        return MultipleChoiceLogitsProcessor(self.tokenizer, self.choices, self.delimiter, self.boost_first_words)

    def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        if self.boost_first_words:
            # A choice line looks like "\nA) hair dryer"; grab the token right after each
            # "<newline><choice><delimiter>" pattern ("hair" in the example), in choice order.
            first_tokens = []
            next_choice = 0
            for i in range(len(prompt_tokens_ids) - 3):
                starts_choice = (prompt_tokens_ids[i] in self.new_line_token
                                 and prompt_tokens_ids[i + 1] == self.choice_tokens[next_choice]
                                 and prompt_tokens_ids[i + 2] == self.delimiter_token)
                if starts_choice:
                    first_tokens.append(prompt_tokens_ids[i + 3])
                    next_choice += 1
                    if next_choice >= len(self.choice_tokens):
                        break

            # Add each choice's first answer-word logit onto the choice token itself.
            scores[self.choice_tokens[:len(first_tokens)]] += self.boost_first_words * scores[first_tokens]

        # Finally, restrict generation to the choice tokens only.
        return enforce_tokens(scores, self.choice_tokens)
class TriggerPhraseLogitsProcessor:
    """
    A logits processor which triggers phrases when it encounters a given token.

    Parameters
    ----------
    phrase (str): The phrase to be generated by LLM when it encounters the trigger token.
    trigger_token_phrase (str): One token phrase in string to trigger phrases.
    tokenizer (PreTrainedTokenizer): The tokenizer used by the LLM.
    trigger_count (int): How many times the phrase will be triggered.
    trigger_after (bool): Whether the phrase is written after the trigger token or instead of the trigger token.
    """
    def __init__(self, phrase: str, trigger_token_phrase: str, tokenizer: Union[PreTrainedTokenizer, str],
                 trigger_count: int = 1, trigger_after: bool = False):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer

        self.phrase = phrase
        self.trigger_token_phrase = trigger_token_phrase
        self.trigger_token = text_to_token(self.tokenizer, trigger_token_phrase, last=False)
        self.phrase_tokens = self.tokenizer.encode(phrase, add_special_tokens=False)
        self.initial_trigger_count = trigger_count
        self.trigger_after = trigger_after
        self._reset()

    def clone(self):
        # vLLM deep copies a logits processor per sequence through this hook.
        return TriggerPhraseLogitsProcessor(self.phrase, self.trigger_token_phrase, self.tokenizer,
                                            self.initial_trigger_count, self.trigger_after)

    def _reset(self):
        # index == -1 means "waiting for the trigger"; >= 0 is the next phrase position.
        self.index = -1
        self.trigger_count = self.initial_trigger_count

    def __call__(self, prompt_tokens_ids: List[int], past_token_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        if not past_token_ids:  # first decoding step of a new generation
            self._reset()

        if self.trigger_count <= 0:  # all allowed triggers already consumed
            return scores

        waiting = self.index == -1
        if waiting and scores.argmax() == self.trigger_token:
            self.index = 0
            if not self.trigger_after:
                # Replace the trigger token itself with the first phrase token.
                scores = enforce_tokens(scores, [self.phrase_tokens[0]])
                self.index = 1
        elif not waiting and self.index < len(self.phrase_tokens):
            # Mid-phrase: keep forcing the next phrase token.
            scores = enforce_tokens(scores, [self.phrase_tokens[self.index]])
            self.index += 1

        if self.index == len(self.phrase_tokens):  # phrase finished: rearm for the next trigger
            self.index = -1
            self.trigger_count -= 1

        return scores
"logits-processor-zoo" 3 | version = "0.1.9" 4 | description = "A collection of LogitsProcessors to customize and enhance LLM behavior for specific tasks." 5 | authors = ["Ahmet Erdem", "Ivan Sorokin", "Maximilian Jeblick", "Darragh Hanley", "David Austin"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.10" 10 | torch = "*" 11 | transformers = ">=4.41.2" 12 | accelerate = ">=0.26.1" 13 | vllm = { version = ">=0.5.0.post1", optional = true } 14 | 15 | [tool.poetry.extras] 16 | vllm = ["vllm"] 17 | 18 | 19 | [build-system] 20 | requires = ["poetry-core"] 21 | build-backend = "poetry.core.masonry.api" 22 | 23 | [tool.flake8] 24 | max-line-length = 120 25 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
class LLMRunner:
    """Small helper that wraps a HF causal LM + tokenizer for generation in tests."""

    def __init__(self, model_name='google/gemma-1.1-2b-it'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Left padding + EOS-as-pad so batched prompts align at the right edge.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )

    def generate_response(self, prompts, logits_processor_list=None, max_new_tokens=1000):
        """Generate completions for `prompts`, returning only the newly generated text."""
        processors = LogitsProcessorList(logits_processor_list or [])

        input_ids = self.tokenizer(prompts, return_tensors='pt', padding=True)["input_ids"]
        out_ids = self.model.generate(input_ids.to(self.model.device),
                                      max_new_tokens=max_new_tokens, min_new_tokens=1,
                                      logits_processor=processors)

        decoded = self.tokenizer.batch_decode(out_ids, skip_special_tokens=True,
                                              clean_up_tokenization_spaces=False)

        # Strip the echoed prompt so callers see the continuation only.
        return [full[len(prompt):].strip() for prompt, full in zip(prompts, decoded)]


@pytest.fixture(scope='session')
def llm_runner():
    # Tiny 0-parameter test model keeps the suite fast.
    return LLMRunner(model_name="MaxJeblick/llama2-0b-unit-test")
def test_text_to_token(llm_runner):
    tok = llm_runner.tokenizer
    assert text_to_token(tok, ",", last=False) == 1919
    assert text_to_token(tok, "apple, orange,", last=True) == 29892
    assert text_to_token(tok, "apple, orange\n", last=True) == 13

    # Requesting a single (non-last) token from multi-token text must raise.
    token = -1
    try:
        token = text_to_token(tok, "apple, orange,", last=False)
    except Exception:
        pass

    assert token == -1


def test_get_new_line_tokens(llm_runner):
    assert get_new_line_tokens(llm_runner.tokenizer) == {13}


def test_enforce_tokens():
    scores = torch.FloatTensor([0.1, -0.4, -0.2, -0.6, 1.1])
    allowed = [1, 2]

    scores = enforce_tokens(scores, allowed)

    # The two allowed tokens must now dominate everything else.
    _, top2_tokens = torch.topk(scores, k=2)
    assert torch.equal(top2_tokens, torch.tensor([2, 1]))


# tests/transformers/test_cite_prompt.py
from logits_processor_zoo.transformers import CiteFromPromptLogitsProcessor


def test_cite_from_prompt_logits_processor(llm_runner):
    example_prompts = [
        "Please describe what macaques are.",
        "Tell me a story about a kid lost in forest."
    ]

    default_gen_output = llm_runner.generate_response(example_prompts, max_new_tokens=10)

    logits_processors = [CiteFromPromptLogitsProcessor(llm_runner.tokenizer, boost_factor=50.0,
                                                       conditional_boost_factor=50.0)]
    boosted_gen_output = llm_runner.generate_response(example_prompts, logits_processors, max_new_tokens=10)

    # With a strong citation boost, outputs should share more words with the prompt.
    for prompt, default_out, boosted_out in zip(example_prompts, default_gen_output, boosted_gen_output):
        prompt_words = set(prompt.split())
        default_shared = prompt_words & set(default_out.split())
        boosted_shared = prompt_words & set(boosted_out.split())

        assert len(boosted_shared) > len(default_shared)
from logits_processor_zoo.transformers import GenLengthLogitsProcessor


def test_gen_length_logits_processor(llm_runner):
    example_prompts = [
        "Please describe what macaques are.",
        "Tell me a story about a kid lost in forest."
    ]

    default_gen_output = llm_runner.generate_response(example_prompts)

    processors = [GenLengthLogitsProcessor(llm_runner.tokenizer, boost_factor=1.0)]
    boosted_gen_output = llm_runner.generate_response(example_prompts, processors)

    # A positive boost factor should make every generation shorter.
    assert all(len(before) > len(after) for before, after in zip(default_gen_output, boosted_gen_output))

    # Repeating the run with the same processors must be deterministic.
    repeat_gen_output = llm_runner.generate_response(example_prompts, processors)
    assert all(first == second for first, second in zip(boosted_gen_output, repeat_gen_output))
from logits_processor_zoo.transformers import ForceLastPhraseLogitsProcessor, GenLengthLogitsProcessor


def test_force_last_phrase_logits_processor(llm_runner):
    """ForceLastPhraseLogitsProcessor must inject the phrase into every output, deterministically."""
    # NOTE: renamed from the copy-pasted `test_cite_from_prompt_logits_processor`,
    # which misdescribed what this test exercises.
    example_prompts = [
        "Please describe what macaques are.",
        "Tell me a story about a kid lost in forest."
    ]

    phrase = "This is a test phrase."

    # GenLength pushes the model towards EOS so the last-phrase trigger actually fires.
    logits_processors = [GenLengthLogitsProcessor(llm_runner.tokenizer, boost_factor=1.0),
                         ForceLastPhraseLogitsProcessor(phrase, llm_runner.tokenizer,
                                                        batch_size=len(example_prompts))]
    processed_gen_output = llm_runner.generate_response(example_prompts, logits_processors, max_new_tokens=100)

    assert all((phrase in out) for out in processed_gen_output)

    # Repeated runs with the same processors must be deterministic.
    processed_gen_output_repeat = llm_runner.generate_response(example_prompts, logits_processors, max_new_tokens=100)
    assert all(p1 == p2 for p1, p2 in zip(processed_gen_output, processed_gen_output_repeat))
from logits_processor_zoo.transformers import MultipleChoiceLogitsProcessor


def test_multiple_choice_logits_processor(llm_runner):
    """MultipleChoiceLogitsProcessor must constrain the single-token answer to the given choices."""
    # NOTE: renamed from the copy-pasted `test_cite_from_prompt_logits_processor`,
    # which misdescribed what this test exercises.
    example_prompts = [
        """
    I am getting a lot of calls during the day. What is more important for me to consider when I buy a new phone?
    a) Camera
    b) Screen resolution
    c) Operating System
    d) Battery

    Answer:
    """,

        """
    Which user review doesn't belong to a summer dress?
    a) Looks good
    b) Keeps warm
    c) Too long
    d) Liked the color

    Answer:
    """
    ]

    # Letter choices with ")" delimiter.
    choices = ["a", "b", "c", "d"]
    logits_processors = [MultipleChoiceLogitsProcessor(llm_runner.tokenizer, choices=choices, delimiter=")")]
    processed_gen_output = llm_runner.generate_response(example_prompts, logits_processors, max_new_tokens=1)

    assert all((out in choices) for out in processed_gen_output)

    # Numeric choices with "." delimiter, plus first-word boosting.
    example_prompts = [prompt.replace("a)", "1.").replace("b)", "2.").replace("c)", "3.").replace("d)", "4.")
                       for prompt in example_prompts]

    choices = ["1", "2", "3", "4"]
    logits_processors = [MultipleChoiceLogitsProcessor(llm_runner.tokenizer, choices=choices, delimiter=".",
                                                       boost_first_words=1.0)]
    processed_gen_output = llm_runner.generate_response(example_prompts, logits_processors, max_new_tokens=1)

    assert all((out in choices) for out in processed_gen_output)