├── .github └── workflows │ └── python.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode ├── launch.json └── settings.json ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── pyproject.toml ├── src └── openai_messages_token_helper │ ├── __init__.py │ ├── function_format.py │ ├── images_helper.py │ ├── message_builder.py │ ├── model_helper.py │ └── py.typed └── tests ├── __init__.py ├── conftest.py ├── functions.py ├── image_large.png ├── image_messages.py ├── messages.py ├── test_imageshelper.py ├── test_messagebuilder.py ├── test_modelhelper.py ├── verify_functions.py └── verify_openai.py /.github/workflows/python.yaml: -------------------------------------------------------------------------------- 1 | name: Python checks 2 | 3 | on: 4 | push: 5 | branches: [ main, master ] 6 | pull_request: 7 | branches: [ main, master ] 8 | 9 | jobs: 10 | build: 11 | name: Test with Python ${{ matrix.python_version }} 12 | runs-on: ubuntu-latest 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | python_version: ["3.9", "3.10", "3.11", "3.12"] 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Set up Python 3 20 | uses: actions/setup-python@v3 21 | with: 22 | python-version: ${{ matrix.python_version }} 23 | - name: Install dependencies 24 | run: | 25 | python3 -m pip install --upgrade pip 26 | python3 -m pip install -e '.[dev]' 27 | - name: Lint with ruff 28 | run: ruff check . 29 | - name: Check formatting with black 30 | run: black . --check --verbose 31 | - name: Run unit tests 32 | run: | 33 | python3 -m pytest -s -vv --cov --cov-fail-under=97 34 | - name: Run type checks 35 | run: python3 -m mypy . 
36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/astral-sh/ruff-pre-commit 9 | rev: v0.9.0 10 | hooks: 11 | - id: ruff 12 | - repo: https://github.com/psf/black 13 | rev: 24.10.0 14 | hooks: 15 | - id: black 16 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Debug Tests", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${file}", 12 | "purpose": ["debug-test"], 13 | "console": "integratedTerminal", 14 | "justMyCode": false 15 | } 16 | ] 17 | } 18 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "tests" 4 | ], 5 | "python.testing.unittestEnabled": false, 6 | "python.testing.pytestEnabled": true, 7 | "files.exclude": { 8 | ".coverage": true, 9 | ".pytest_cache": true, 10 | "__pycache__": true, 11 | ".ruff_cache": true, 12 | ".mypy_cache": true, 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | ## [0.1.11] - Jan 10, 2025 6 | 7 | - When no "detail" is provided for an "image_url" message part, "auto" is now assumed. 8 | 9 | ## [0.1.10] - Aug 7, 2024 10 | 11 | - Add additional OpenAI.com model names to the `get_token_limit` function. 12 | 13 | ## [0.1.9] - Aug 7, 2024 14 | 15 | - Add gpt-4o-mini support by applying a 33.3x multiplier to the image token cost. 16 | 17 | ## [0.1.8] - Aug 3, 2024 18 | 19 | - Fix the type for the tool_choice param to be inclusive of "auto" and other options. 20 | 21 | ## [0.1.7] - Aug 3, 2024 22 | 23 | - Fix bug where you couldn't pass in example tool calls in `few_shots` to `build_messages`. 24 | 25 | ## [0.1.6] - Aug 2, 2024 26 | 27 | - Fix bug where you couldn't pass in `tools` and set `default_to_cl100k` to True with a non-OpenAI model.
28 | 29 | ## [0.1.5] - June 4, 2024 30 | 31 | - Remove spurious `print` call when counting tokens for function calling. 32 | 33 | ## [0.1.4] - May 14, 2024 34 | 35 | - Add support and tests for gpt-4o, which has a different tokenizer. 36 | 37 | ## [0.1.3] - May 2, 2024 38 | 39 | - Use openai type annotations for more precise type hints, and add a typing test. 40 | 41 | ## [0.1.2] - May 2, 2024 42 | 43 | - Add `py.typed` file so that mypy can find the type hints in this package. 44 | 45 | ## [0.1.0] - May 2, 2024 46 | 47 | - Add `count_tokens_for_system_and_tools` to count tokens for system message and tools. You should count the tokens for both together, since the token count for tools varies based on whether a system message is provided. 48 | - Updated `build_messages` to allow for `tools` and `tool_choice` to be passed in. 49 | - Breaking change: Changed `new_user_message` to `new_user_content` in `build_messages` for clarity. 50 | 51 | ## [0.0.6] - April 24, 2024 52 | 53 | - Add keyword argument `fallback_to_default` to `build_messages` function to allow for defaulting to the CL100k token encoder and minimum GPT token limit if the model is not found. 54 | - Fixed usage of `past_messages` argument of `build_messages` to not skip the last past message. (The new user message should *not* be included in `past_messages`.) 55 | 56 | ## [0.0.5] - April 24, 2024 57 | 58 | - Add keyword argument `default_to_cl100k` to `count_tokens_for_message` function to allow for defaulting to the CL100k encoding if the model is not found. 59 | - Add keyword argument `default_to_minimum` to `get_token_limit` function to allow for defaulting to the minimum token limit if the model is not found. 60 | 61 | ## [0.0.4] - April 21, 2024 62 | 63 | - Rename to openai-messages-token-helper from llm-messages-token-helper to reflect library's current OpenAI focus.
64 | 65 | ## [0.0.3] - April 21, 2024 66 | 67 | - Fix for `count_tokens_for_message` function to match OpenAI output precisely, particularly for calls with images to GPT-4 vision. 68 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | ## Development 4 | 5 | Install the project dependencies: 6 | 7 | ```sh 8 | python3 -m pip install -e '.[dev]' 9 | pre-commit install 10 | ``` 11 | 12 | Run the tests: 13 | 14 | ```sh 15 | python3 -m pytest 16 | ``` 17 | 18 | ## Publishing 19 | 20 | 1. Update the CHANGELOG with a description of the changes 21 | 22 | 2. Update the version number in pyproject.toml 23 | 24 | 3. Push the changes to the main branch 25 | 26 | 4. Publish to PyPI: 27 | 28 | ```shell 29 | export FLIT_USERNAME=__token__ 30 | export FLIT_PASSWORD= 31 | flit publish 32 | ``` 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2023 Brian Okken 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # openai-messages-token-helper 2 | 3 | A helper library for estimating tokens used by messages and building messages lists that fit within the token limits of a model. 4 | Currently designed to work with the OpenAI GPT models (including GPT-4 turbo with vision). 5 | Uses the tiktoken library for tokenizing text and the Pillow library for image-related calculations. 6 | 7 | ## Installation 8 | 9 | Install the package: 10 | 11 | ```sh 12 | python3 -m pip install openai-messages-token-helper 13 | ``` 14 | 15 | ## Usage 16 | 17 | The library provides the following functions: 18 | 19 | * [`build_messages`](#build_messages) 20 | * [`count_tokens_for_message`](#count_tokens_for_message) 21 | * [`count_tokens_for_image`](#count_tokens_for_image) 22 | * [`get_token_limit`](#get_token_limit) 23 | 24 | ### `build_messages` 25 | 26 | Build a list of messages for a chat conversation, given the system prompt, new user message, 27 | and past messages. The function will truncate the history of past messages if necessary to 28 | stay within the token limit. 29 | 30 | Arguments: 31 | 32 | * `model` (`str`): The model name to use for token calculation, like gpt-3.5-turbo. 33 | * `system_prompt` (`str`): The initial system prompt message. 
34 | * `tools` (`List[openai.types.chat.ChatCompletionToolParam]`): (Optional) The tools that will be used in the conversation. These won't be part of the final returned messages, but they will be used to calculate the token count. 35 | * `tool_choice` (`openai.types.chat.ChatCompletionToolChoiceOptionParam`): (Optional) The tool choice that will be used in the conversation. This won't be part of the final returned messages, but it will be used to calculate the token count. 36 | * `new_user_content` (`str | List[openai.types.chat.ChatCompletionContentPartParam]`): (Optional) The content of the new user message to append after the few-shot messages. 37 | * `past_messages` (`list[openai.types.chat.ChatCompletionMessageParam]`): (Optional) The list of past messages in the conversation, oldest first, not including the system prompt or the new user message. When the token limit is reached, the oldest messages are dropped first. 38 | * `few_shots` (`list[openai.types.chat.ChatCompletionMessageParam]`): (Optional) A few-shot list of messages to insert after the system prompt. 39 | * `max_tokens` (`int`): (Optional) The maximum number of tokens allowed for the conversation. Defaults to the model's limit from `get_token_limit`. 40 | * `fallback_to_default` (`bool`): (Optional) Whether to fall back to the CL100k encoding and the minimum known token limit if the model is not found. Defaults to `False`. 41 | 42 | 43 | Returns: 44 | 45 | * `list[openai.types.chat.ChatCompletionMessageParam]` 46 | 47 | Example: 48 | 49 | ```python 50 | from openai_messages_token_helper import build_messages 51 | 52 | messages = build_messages( 53 | model="gpt-35-turbo", 54 | system_prompt="You are a bot.", 55 | new_user_content="That wasn't a good poem.", 56 | past_messages=[ 57 | { 58 | "role": "user", 59 | "content": "Write me a poem", 60 | }, 61 | { 62 | "role": "assistant", 63 | "content": "Tuna tuna I love tuna", 64 | }, 65 | ], 66 | few_shots=[ 67 | { 68 | "role": "user", 69 | "content": "Write me a poem", 70 | }, 71 | { 72 | "role": "assistant", 73 | "content": "Tuna tuna is the best", 74 | }, 75 | ] 76 | ) 77 | ``` 78 | 79 | ### `count_tokens_for_message` 80 | 81 | Counts the number of tokens in a message.
82 | 83 | Arguments: 84 | 85 | * `model` (`str`): The model name to use for token calculation, like gpt-3.5-turbo. 86 | * `message` (`openai.types.chat.ChatCompletionMessageParam`): The message to count tokens for. 87 | * `default_to_cl100k` (`bool`): (Optional) Whether to fall back to the CL100k encoding if the model is not found. Defaults to `False`. 88 | 89 | Returns: 90 | 91 | * `int`: The number of tokens in the message. 92 | 93 | Example: 94 | 95 | ```python 96 | from openai_messages_token_helper import count_tokens_for_message 97 | 98 | message = { 99 | "role": "user", 100 | "content": "Hello, how are you?", 101 | } 102 | model = "gpt-4" 103 | num_tokens = count_tokens_for_message(model, message) 104 | ``` 105 | 106 | ### `count_tokens_for_image` 107 | 108 | Count the number of tokens for an image sent to GPT-4-vision, provided as a base64 data URI. 109 | 110 | Arguments: 111 | 112 | * `image_uri` (`str`): The image as a base64 data URI, like `data:image/png;base64,...`. 113 | * `detail` (`str`): (Optional) The image detail level: `"low"`, `"high"`, or `"auto"` (currently treated as `"high"`). Defaults to `"auto"`. 114 | * `model` (`str`): (Optional) The model name. For `"gpt-4o-mini"`, the image token cost is multiplied by 100/3 (about 33.3x). 115 | 116 | Returns: 117 | 118 | * `int`: The number of tokens used up for the image. 119 | 120 | Example: 121 | 122 | Count the number of tokens for an image sent to GPT-4-vision: 123 | 124 | ```python 125 | from openai_messages_token_helper import count_tokens_for_image 126 | 127 | image = "data:image/png;base64,..." 128 | num_tokens = count_tokens_for_image(image) 129 | ``` 130 | 131 | ### `get_token_limit` 132 | 133 | Get the token limit for a given GPT model name (OpenAI.com or Azure OpenAI supported). 134 | 135 | Arguments: 136 | 137 | * `model` (`str`): The model name to get the token limit for, like gpt-3.5-turbo (OpenAI.com) or gpt-35-turbo (Azure). 138 | * `default_to_minimum` (`bool`): (Optional) Whether to default to the minimum token limit if the model is not found; otherwise a `ValueError` is raised. Defaults to `False`. 139 | 140 | Returns: 141 | 142 | * `int`: The token limit for the model.
143 | 144 | Example: 145 | 146 | ```python 147 | from openai_messages_token_helper import get_token_limit 148 | 149 | model = "gpt-4" 150 | max_tokens = get_token_limit(model) 151 | ``` 152 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "openai-messages-token-helper" 3 | description = "A helper library for estimating tokens used by messages sent through OpenAI Chat Completions API." 4 | version = "0.1.11" 5 | authors = [{name = "Pamela Fox"}] 6 | requires-python = ">=3.9" 7 | readme = "README.md" 8 | license = {file = "LICENSE"} 9 | dependencies = [ 10 | "openai", 11 | "tiktoken", 12 | "pillow" 13 | ] 14 | classifiers = [ 15 | "License :: OSI Approved :: MIT License", 16 | "Programming Language :: Python", 17 | "Programming Language :: Python :: 3", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Framework :: Pytest" 22 | ] 23 | 24 | [project.urls] 25 | Home = "https://github.com/pamelafox/openai-messages-token-helper" 26 | 27 | [project.optional-dependencies] 28 | dev = [ 29 | "pytest", 30 | "pytest-cov", 31 | "pre-commit", 32 | "ruff", 33 | "black", 34 | "flit", 35 | "azure-identity", 36 | "python-dotenv", 37 | "mypy" 38 | ] 39 | 40 | [build-system] 41 | requires = ["flit_core >=3.2,<4"] 42 | build-backend = "flit_core.buildapi" 43 | 44 | [tool.ruff] 45 | line-length = 120 46 | target-version = "py39" 47 | output-format = "full" 48 | 49 | [tool.ruff.lint] 50 | select = ["E", "F", "I", "UP"] 51 | ignore = ["D203", "E501"] 52 | 53 | [tool.black] 54 | line-length = 120 55 | target-version = ["py39"] 56 | 57 | [tool.pytest.ini_options] 58 | addopts = "-ra --cov" 59 | 60 | [tool.coverage.report] 61 | show_missing = true 62 | -------------------------------------------------------------------------------- 
/src/openai_messages_token_helper/__init__.py: -------------------------------------------------------------------------------- 1 | from .images_helper import count_tokens_for_image 2 | from .message_builder import build_messages 3 | from .model_helper import count_tokens_for_message, count_tokens_for_system_and_tools, get_token_limit 4 | 5 | __all__ = [ 6 | "build_messages", 7 | "count_tokens_for_message", 8 | "count_tokens_for_image", 9 | "get_token_limit", 10 | "count_tokens_for_system_and_tools", 11 | ] 12 | -------------------------------------------------------------------------------- /src/openai_messages_token_helper/function_format.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/forestwanglin/openai-java/blob/main/jtokkit/src/main/java/xyz/felh/openai/jtokkit/utils/TikTokenUtils.java 2 | 3 | 4 | def format_function_definitions(tools): 5 | lines = [] 6 | lines.append("namespace functions {") 7 | lines.append("") 8 | for tool in tools: 9 | function = tool.get("function") 10 | if function_description := function.get("description"): 11 | lines.append(f"// {function_description}") 12 | function_name = function.get("name") 13 | parameters = function.get("parameters", {}) 14 | properties = parameters.get("properties") 15 | if properties and properties.keys(): 16 | lines.append(f"type {function_name} = (_: {{") 17 | lines.append(format_object_parameters(parameters, 0)) 18 | lines.append("}) => any;") 19 | else: 20 | lines.append(f"type {function_name} = () => any;") 21 | lines.append("") 22 | lines.append("} // namespace functions") 23 | return "\n".join(lines) 24 | 25 | 26 | def format_object_parameters(parameters, indent): 27 | properties = parameters.get("properties") 28 | if not properties: 29 | return "" 30 | required_params = parameters.get("required", []) 31 | lines = [] 32 | for key, props in properties.items(): 33 | description = props.get("description") 34 | if description: 35 | 
lines.append(f"// {description}") 36 | question = "?" 37 | if required_params and key in required_params: 38 | question = "" 39 | lines.append(f"{key}{question}: {format_type(props, indent)},") 40 | return "\n".join([" " * max(0, indent) + line for line in lines]) 41 | 42 | 43 | def format_type(props, indent): 44 | type = props.get("type") 45 | if type == "string": 46 | if "enum" in props: 47 | return " | ".join([f'"{item}"' for item in props["enum"]]) 48 | return "string" 49 | elif type == "array": 50 | # items is required, OpenAI throws an error if it's missing 51 | return f"{format_type(props['items'], indent)}[]" 52 | elif type == "object": 53 | return f"{{\n{format_object_parameters(props, indent + 2)}\n}}" 54 | elif type in ["integer", "number"]: 55 | if "enum" in props: 56 | return " | ".join([f'"{item}"' for item in props["enum"]]) 57 | return "number" 58 | elif type == "boolean": 59 | return "boolean" 60 | elif type == "null": 61 | return "null" 62 | else: 63 | # This is a guess, as an empty string doesn't yield the expected token count 64 | return "any" 65 | -------------------------------------------------------------------------------- /src/openai_messages_token_helper/images_helper.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import math 3 | import re 4 | from fractions import Fraction 5 | from io import BytesIO 6 | from typing import Optional 7 | 8 | from PIL import Image 9 | 10 | 11 | def get_image_dims(image_uri: str) -> tuple[int, int]: 12 | # From https://github.com/openai/openai-cookbook/pull/881/files 13 | if re.match(r"data:image\/\w+;base64", image_uri): 14 | image_uri = re.sub(r"data:image\/\w+;base64,", "", image_uri) 15 | image = Image.open(BytesIO(base64.b64decode(image_uri))) 16 | return image.size 17 | else: 18 | raise ValueError("Image must be a base64 string.") 19 | 20 | 21 | def count_tokens_for_image(image_uri: str, detail: str = "auto", model: Optional[str] = None) -> 
int: 22 | # From https://github.com/openai/openai-cookbook/pull/881/files 23 | # Based on https://platform.openai.com/docs/guides/vision 24 | multiplier = Fraction(1, 1) 25 | if model == "gpt-4o-mini": 26 | multiplier = Fraction(100, 3) 27 | COST_PER_TILE = 85 * multiplier 28 | LOW_DETAIL_COST = COST_PER_TILE 29 | HIGH_DETAIL_COST_PER_TILE = COST_PER_TILE * 2 30 | 31 | if detail == "auto": 32 | # assume high detail for now 33 | detail = "high" 34 | 35 | if detail == "low": 36 | # Low detail images have a fixed cost 37 | return int(LOW_DETAIL_COST) 38 | elif detail == "high": 39 | # Calculate token cost for high detail images 40 | width, height = get_image_dims(image_uri) 41 | # Check if resizing is needed to fit within a 2048 x 2048 square 42 | if max(width, height) > 2048: 43 | # Resize dimensions to fit within a 2048 x 2048 square 44 | ratio = 2048 / max(width, height) 45 | width = int(width * ratio) 46 | height = int(height * ratio) 47 | # Further scale down to 768px on the shortest side 48 | if min(width, height) > 768: 49 | ratio = 768 / min(width, height) 50 | width = int(width * ratio) 51 | height = int(height * ratio) 52 | # Calculate the number of 512px squares 53 | num_squares = math.ceil(width / 512) * math.ceil(height / 512) 54 | # Calculate the total token cost 55 | total_cost = num_squares * HIGH_DETAIL_COST_PER_TILE + COST_PER_TILE 56 | return math.ceil(total_cost) 57 | else: 58 | # Invalid detail_option 59 | raise ValueError("Invalid value for detail parameter. 
Use 'low' or 'high'.") 60 | -------------------------------------------------------------------------------- /src/openai_messages_token_helper/message_builder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import unicodedata 3 | from collections.abc import Iterable 4 | from typing import Optional, Union 5 | 6 | from openai.types.chat import ( 7 | ChatCompletionAssistantMessageParam, 8 | ChatCompletionContentPartParam, 9 | ChatCompletionMessageParam, 10 | ChatCompletionMessageToolCallParam, 11 | ChatCompletionRole, 12 | ChatCompletionSystemMessageParam, 13 | ChatCompletionToolChoiceOptionParam, 14 | ChatCompletionToolMessageParam, 15 | ChatCompletionToolParam, 16 | ChatCompletionUserMessageParam, 17 | ) 18 | 19 | from .model_helper import count_tokens_for_message, count_tokens_for_system_and_tools, get_token_limit 20 | 21 | 22 | def normalize_content(content: Union[str, Iterable[ChatCompletionContentPartParam], None]): 23 | if content is None: 24 | return None 25 | if isinstance(content, str): 26 | return unicodedata.normalize("NFC", content) 27 | else: 28 | for part in content: 29 | if part["type"] == "text": 30 | part["text"] = unicodedata.normalize("NFC", part["text"]) 31 | return content 32 | 33 | 34 | class _MessageBuilder: 35 | """ 36 | A class for building and managing messages in a chat conversation. 37 | Attributes: 38 | message (list): A list of dictionaries representing chat messages. 39 | model (str): The name of the ChatGPT model. 40 | token_count (int): The total number of tokens in the conversation. 41 | Methods: 42 | __init__(self, system_content: str, chatgpt_model: str): Initializes the MessageBuilder instance. 43 | insert_message(self, role: str, content: str, index: int = 1): Inserts a new message to the conversation. 
44 | """ 45 | 46 | def __init__(self, system_content: str): 47 | self.system_message = ChatCompletionSystemMessageParam(role="system", content=normalize_content(system_content)) 48 | self.messages: list[ChatCompletionMessageParam] = [] 49 | 50 | @property 51 | def all_messages(self) -> list[ChatCompletionMessageParam]: 52 | return [self.system_message] + self.messages 53 | 54 | def insert_message( 55 | self, 56 | role: ChatCompletionRole, 57 | content: Union[str, Iterable[ChatCompletionContentPartParam], None], 58 | index: int = 0, 59 | tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] = None, 60 | tool_call_id: Optional[str] = None, 61 | ): 62 | """ 63 | Inserts a message into the conversation at the specified index, 64 | or at index 0 if no index is specified. 65 | Args: 66 | role (str): The role of the message sender (either "user", "system", or "assistant"). 67 | content (str | List[ChatCompletionContentPartParam]): The content of the message. 68 | index (int): The index at which to insert the message. 
69 | """ 70 | message: ChatCompletionMessageParam 71 | if role == "user": 72 | message = ChatCompletionUserMessageParam(role="user", content=normalize_content(content)) 73 | elif role == "assistant" and isinstance(content, str): 74 | message = ChatCompletionAssistantMessageParam(role="assistant", content=normalize_content(content)) 75 | elif role == "assistant" and tool_calls is not None: 76 | message = ChatCompletionAssistantMessageParam(role="assistant", tool_calls=tool_calls) 77 | elif role == "tool" and tool_call_id is not None: 78 | message = ChatCompletionToolMessageParam( 79 | role="tool", tool_call_id=tool_call_id, content=normalize_content(content) 80 | ) 81 | else: 82 | raise ValueError("Invalid message for builder") 83 | self.messages.insert(index, message) 84 | 85 | 86 | def build_messages( 87 | model: str, 88 | system_prompt: str, 89 | *, 90 | tools: Optional[list[ChatCompletionToolParam]] = None, 91 | tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None, 92 | new_user_content: Union[str, list[ChatCompletionContentPartParam], None] = None, # list is for GPT4v usage 93 | past_messages: list[ChatCompletionMessageParam] = [], # *not* including system prompt 94 | few_shots: list[ChatCompletionMessageParam] = [], # will always be inserted after system prompt 95 | max_tokens: Optional[int] = None, 96 | fallback_to_default: bool = False, 97 | ) -> list[ChatCompletionMessageParam]: 98 | """ 99 | Build a list of messages for a chat conversation, given the system prompt, new user message, 100 | and past messages. The function will truncate the history of past messages if necessary to 101 | stay within the token limit. 102 | Args: 103 | model (str): The model name to use for token calculation, like gpt-3.5-turbo. 104 | system_prompt (str): The initial system prompt message. 105 | tools (list[ChatCompletionToolParam]): A list of tools to include in the conversation. 
106 | tool_choice (ChatCompletionToolChoiceOptionParam): The tool to use in the conversation. 107 | new_user_content (str | List[ChatCompletionContentPartParam]): Content of new user message to append. 108 | past_messages (list[ChatCompletionMessageParam]): The list of past messages in the conversation. 109 | few_shots (list[ChatCompletionMessageParam]): A few-shot list of messages to insert after the system prompt. 110 | max_tokens (int): The maximum number of tokens allowed for the conversation. 111 | fallback_to_default (bool): Whether to fallback to default model if the model is not found. 112 | """ 113 | if max_tokens is None: 114 | max_tokens = get_token_limit(model, default_to_minimum=fallback_to_default) 115 | 116 | # Start with the required messages: system prompt, few-shots, and new user message 117 | message_builder = _MessageBuilder(system_prompt) 118 | 119 | for shot in reversed(few_shots): 120 | if shot["role"] is None or (shot.get("content") is None and shot.get("tool_calls") is None): 121 | raise ValueError("Few-shot messages must have role and either content or tool_calls") 122 | tool_call_id = shot.get("tool_call_id") 123 | if tool_call_id is not None and not isinstance(tool_call_id, str): 124 | raise ValueError("tool_call_id must be a string value") 125 | tool_calls = shot.get("tool_calls") 126 | if tool_calls is not None and not isinstance(tool_calls, Iterable): 127 | raise ValueError("tool_calls must be a list of tool calls") 128 | message_builder.insert_message( 129 | shot["role"], shot.get("content"), tool_calls=tool_calls, tool_call_id=tool_call_id # type: ignore[arg-type] 130 | ) 131 | 132 | append_index = len(few_shots) 133 | 134 | if new_user_content: 135 | message_builder.insert_message("user", new_user_content, index=append_index) 136 | 137 | total_token_count = count_tokens_for_system_and_tools( 138 | model, message_builder.system_message, tools, tool_choice, default_to_cl100k=fallback_to_default 139 | ) 140 | for existing_message in 
message_builder.messages: 141 | total_token_count += count_tokens_for_message(model, existing_message, default_to_cl100k=fallback_to_default) 142 | 143 | newest_to_oldest = list(reversed(past_messages)) 144 | for message in newest_to_oldest: 145 | potential_message_count = count_tokens_for_message(model, message, default_to_cl100k=fallback_to_default) 146 | if (total_token_count + potential_message_count) > max_tokens: 147 | logging.info("Reached max tokens of %d, history will be truncated", max_tokens) 148 | break 149 | 150 | if message["role"] is None or message["content"] is None: 151 | raise ValueError("Few-shot messages must have both role and content") 152 | message_builder.insert_message(message["role"], message["content"], index=append_index) # type: ignore[arg-type] 153 | total_token_count += potential_message_count 154 | return message_builder.all_messages 155 | -------------------------------------------------------------------------------- /src/openai_messages_token_helper/model_helper.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | import tiktoken 6 | from openai.types.chat import ( 7 | ChatCompletionMessageParam, 8 | ChatCompletionSystemMessageParam, 9 | ChatCompletionToolChoiceOptionParam, 10 | ChatCompletionToolParam, 11 | ) 12 | 13 | from .function_format import format_function_definitions 14 | from .images_helper import count_tokens_for_image 15 | 16 | MODELS_2_TOKEN_LIMITS = { 17 | "gpt-35-turbo": 4000, 18 | "gpt-3.5-turbo": 4000, 19 | "gpt-35-turbo-16k": 16000, 20 | "gpt-3.5-turbo-16k": 16000, 21 | "gpt-4": 8100, 22 | "gpt-4-32k": 32000, 23 | "gpt-4v": 128000, 24 | "gpt-4o": 128000, 25 | "gpt-4o-mini": 128000, 26 | # OpenAI specific model names: 27 | # https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4 28 | "gpt-4-0613": 8192, 29 | "gpt-4-turbo": 128000, 30 | "gpt-4-turbo-2024-04-09": 128000, 31 | "gpt-4-turbo-preview": 128000, 
32 | "gpt-4-0125-preview": 128000, 33 | "gpt-4-1106-preview": 128000, 34 | } 35 | 36 | 37 | AOAI_2_OAI = {"gpt-35-turbo": "gpt-3.5-turbo", "gpt-35-turbo-16k": "gpt-3.5-turbo-16k", "gpt-4v": "gpt-4-turbo-vision"} 38 | 39 | logger = logging.getLogger("openai_messages_token_helper") 40 | 41 | 42 | def get_token_limit(model: str, default_to_minimum=False) -> int: 43 | """ 44 | Get the token limit for a given GPT model name (OpenAI.com or Azure OpenAI supported). 45 | Args: 46 | model (str): The name of the model to get the token limit for. 47 | default_to_minimum (bool): Whether to default to the minimum token limit if the model is not found. 48 | Returns: 49 | int: The token limit for the model. 50 | """ 51 | if model not in MODELS_2_TOKEN_LIMITS: 52 | if default_to_minimum: 53 | min_token_limit = min(MODELS_2_TOKEN_LIMITS.values()) 54 | logger.warning("Model %s not found, defaulting to minimum token limit %d", model, min_token_limit) 55 | return min_token_limit 56 | else: 57 | raise ValueError(f"Called with unknown model name: {model}") 58 | return MODELS_2_TOKEN_LIMITS[model] 59 | 60 | 61 | def encoding_for_model(model: str, default_to_cl100k=False) -> tiktoken.Encoding: 62 | """ 63 | Get the encoding for a given GPT model name (OpenAI.com or Azure OpenAI supported). 64 | Args: 65 | model (str): The name of the model to get the encoding for. 66 | default_to_cl100k (bool): Whether to default to the CL100k encoding if the model is not found. 67 | Returns: 68 | tiktoken.Encoding: The encoding for the model. 
69 | """ 70 | if ( 71 | model == "" 72 | or model is None 73 | or (model not in AOAI_2_OAI and model not in MODELS_2_TOKEN_LIMITS and not default_to_cl100k) 74 | ): 75 | raise ValueError("Expected valid OpenAI GPT model name") 76 | model = AOAI_2_OAI.get(model, model) 77 | try: 78 | return tiktoken.encoding_for_model(model) 79 | except KeyError: 80 | if default_to_cl100k: 81 | logger.warning("Model %s not found, defaulting to CL100k encoding", model) 82 | return tiktoken.get_encoding("cl100k_base") 83 | else: 84 | raise 85 | 86 | 87 | def count_tokens_for_message(model: str, message: ChatCompletionMessageParam, default_to_cl100k=False) -> int: 88 | """ 89 | Calculate the number of tokens required to encode a message. Based off cookbook: 90 | https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb 91 | 92 | Args: 93 | model (str): The name of the model to use for encoding. 94 | message (Mapping): The message to encode, in a dictionary-like object. 95 | default_to_cl100k (bool): Whether to default to the CL100k encoding if the model is not found. 96 | Returns: 97 | int: The total number of tokens required to encode the message. 
98 | 99 | >> model = 'gpt-3.5-turbo' 100 | >> message = {'role': 'user', 'content': 'Hello, how are you?'} 101 | >> count_tokens_for_message(model, message) 102 | 13 103 | """ 104 | encoding = encoding_for_model(model, default_to_cl100k) 105 | 106 | # Assumes we're using a recent model 107 | tokens_per_message = 3 108 | 109 | num_tokens = tokens_per_message 110 | for key, value in message.items(): 111 | if isinstance(value, list): 112 | # For GPT-4-vision support, based on https://github.com/openai/openai-cookbook/pull/881/files 113 | for item in value: 114 | # Note: item[type] does not seem to be counted in the token count 115 | if item["type"] == "text": 116 | num_tokens += len(encoding.encode(item["text"])) 117 | elif item["type"] == "image_url": 118 | num_tokens += count_tokens_for_image( 119 | item["image_url"]["url"], item["image_url"].get("detail", "auto"), model 120 | ) 121 | elif isinstance(value, str): 122 | num_tokens += len(encoding.encode(value)) 123 | else: 124 | raise ValueError(f"Could not encode unsupported message value type: {type(value)}") 125 | if key == "name": 126 | num_tokens += 1 127 | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> 128 | return num_tokens 129 | 130 | 131 | def count_tokens_for_system_and_tools( 132 | model: str, 133 | system_message: ChatCompletionSystemMessageParam | None = None, 134 | tools: list[ChatCompletionToolParam] | None = None, 135 | tool_choice: ChatCompletionToolChoiceOptionParam | None = None, 136 | default_to_cl100k: bool = False, 137 | ) -> int: 138 | """ 139 | Calculate the number of tokens required to encode a system message and tools. 140 | Both must be calculated together because the count is lower if both are present. 141 | Based on https://github.com/forestwanglin/openai-java/blob/main/jtokkit/src/main/java/xyz/felh/openai/jtokkit/utils/TikTokenUtils.java 142 | 143 | Args: 144 | model (str): The name of the model to use for encoding. 
145 | tools (list[dict[str, dict]]): The tools to encode. 146 | tool_choice (str | dict): The tool choice to encode. 147 | system_message (dict): The system message to encode. 148 | default_to_cl100k (bool): Whether to default to the CL100k encoding if the model is not found. 149 | Returns: 150 | int: The total number of tokens required to encode the system message and tools. 151 | """ 152 | encoding = encoding_for_model(model, default_to_cl100k) 153 | 154 | tokens = 0 155 | if system_message: 156 | tokens += count_tokens_for_message(model, system_message, default_to_cl100k) 157 | if tools: 158 | tokens += len(encoding.encode(format_function_definitions(tools))) 159 | tokens += 9 # Additional tokens for function definition of tools 160 | # If there's a system message and tools are present, subtract four tokens 161 | if tools and system_message: 162 | tokens -= 4 163 | # If tool_choice is 'none', add one token. 164 | # If it's an object, add 4 + the number of tokens in the function name. 165 | # If it's undefined or 'auto', don't add anything. 
166 | if tool_choice == "none": 167 | tokens += 1 168 | elif isinstance(tool_choice, dict): 169 | tokens += 7 170 | tokens += len(encoding.encode(tool_choice["function"]["name"])) 171 | return tokens 172 | -------------------------------------------------------------------------------- /src/openai_messages_token_helper/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pamelafox/openai-messages-token-helper/5c579fd05ca592f4813bae2a49af4bab9eafb6b5/src/openai_messages_token_helper/py.typed -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pamelafox/openai-messages-token-helper/5c579fd05ca592f4813bae2a49af4bab9eafb6b5/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pamelafox/openai-messages-token-helper/5c579fd05ca592f4813bae2a49af4bab9eafb6b5/tests/conftest.py -------------------------------------------------------------------------------- /tests/functions.py: -------------------------------------------------------------------------------- 1 | search_sources_toolchoice_auto = { 2 | "system_message": { 3 | "role": "system", 4 | "content": "You are a bot.", 5 | }, 6 | "tools": [ 7 | { 8 | "type": "function", 9 | "function": { 10 | "name": "search_sources", 11 | "description": "Retrieve sources from the Azure AI Search index", 12 | "parameters": { 13 | "type": "object", 14 | "properties": { 15 | "search_query": { 16 | "type": "string", 17 | "description": "Query string to retrieve documents from azure search eg: 'Health care plan'", 18 | } 19 | }, 20 | "required": ["search_query"], 21 | }, 22 | }, 23 | } 24 | ], 25 | "tool_choice": "auto", 26 | 
"count": 66, 27 | } 28 | 29 | search_sources_toolchoice_none = { 30 | "system_message": { 31 | "role": "system", 32 | "content": "You are a bot.", 33 | }, 34 | "tools": [ 35 | { 36 | "type": "function", 37 | "function": { 38 | "name": "search_sources", 39 | "description": "Retrieve sources from the Azure AI Search index", 40 | "parameters": { 41 | "type": "object", 42 | "properties": { 43 | "search_query": { 44 | "type": "string", 45 | "description": "Query string to retrieve documents from azure search eg: 'Health care plan'", 46 | } 47 | }, 48 | "required": ["search_query"], 49 | }, 50 | }, 51 | } 52 | ], 53 | "tool_choice": "none", 54 | "count": 67, 55 | } 56 | 57 | search_sources_toolchoice_name = { 58 | "system_message": { 59 | "role": "system", 60 | "content": "You are a bot.", 61 | }, 62 | "tools": [ 63 | { 64 | "type": "function", 65 | "function": { 66 | "name": "search_sources", 67 | "description": "Retrieve sources from the Azure AI Search index", 68 | "parameters": { 69 | "type": "object", 70 | "properties": { 71 | "search_query": { 72 | "type": "string", 73 | "description": "Query string to retrieve documents from azure search eg: 'Health care plan'", 74 | } 75 | }, 76 | "required": ["search_query"], 77 | }, 78 | }, 79 | } 80 | ], 81 | "tool_choice": {"type": "function", "function": {"name": "search_sources"}}, 82 | "count": 75, 83 | } 84 | 85 | integer_enum = { 86 | "system_message": { 87 | "role": "system", 88 | "content": "You are a bot.", 89 | }, 90 | "tools": [ 91 | { 92 | "type": "function", 93 | "function": { 94 | "name": "data_demonstration", 95 | "description": "This is the main function description", 96 | "parameters": {"type": "object", "properties": {"integer_enum": {"type": "integer", "enum": [-1, 1]}}}, 97 | }, 98 | } 99 | ], 100 | "tool_choice": "none", 101 | "count": 54, 102 | } 103 | 104 | 105 | integer_enum_tool_choice_name = { 106 | "system_message": { 107 | "role": "system", 108 | "content": "You are a bot.", 109 | }, 110 | "tools": 
[ 111 | { 112 | "type": "function", 113 | "function": { 114 | "name": "data_demonstration", 115 | "description": "This is the main function description", 116 | "parameters": {"type": "object", "properties": {"integer_enum": {"type": "integer", "enum": [-1, 1]}}}, 117 | }, 118 | } 119 | ], 120 | "tool_choice": { 121 | "type": "function", 122 | "function": {"name": "data_demonstration"}, 123 | }, # 4 tokens for "data_demonstration" 124 | "count": 64, 125 | } 126 | 127 | no_parameters = { 128 | "system_message": { 129 | "role": "system", 130 | "content": "You are a bot.", 131 | }, 132 | "tools": [ 133 | { 134 | "type": "function", 135 | "function": { 136 | "name": "search_sources", 137 | "description": "Retrieve sources from the Azure AI Search index", 138 | }, 139 | } 140 | ], 141 | "tool_choice": "auto", 142 | "count": 42, 143 | } 144 | 145 | no_parameters_tool_choice_name = { 146 | "system_message": { 147 | "role": "system", 148 | "content": "You are a bot.", 149 | }, 150 | "tools": [ 151 | { 152 | "type": "function", 153 | "function": { 154 | "name": "search_sources", 155 | "description": "Retrieve sources from the Azure AI Search index", 156 | }, 157 | } 158 | ], 159 | "tool_choice": {"type": "function", "function": {"name": "search_sources"}}, # 2 tokens for "search_sources" 160 | "count": 51, 161 | } 162 | 163 | no_parameter_description_or_required = { 164 | "system_message": { 165 | "role": "system", 166 | "content": "You are a bot.", 167 | }, 168 | "tools": [ 169 | { 170 | "type": "function", 171 | "function": { 172 | "name": "search_sources", 173 | "description": "Retrieve sources from the Azure AI Search index", 174 | "parameters": {"type": "object", "properties": {"search_query": {"type": "string"}}}, 175 | }, 176 | } 177 | ], 178 | "tool_choice": "auto", 179 | "count": 49, 180 | } 181 | 182 | no_parameter_description = { 183 | "system_message": { 184 | "role": "system", 185 | "content": "You are a bot.", 186 | }, 187 | "tools": [ 188 | { 189 | "type": 
"function", 190 | "function": { 191 | "name": "search_sources", 192 | "description": "Retrieve sources from the Azure AI Search index", 193 | "parameters": { 194 | "type": "object", 195 | "properties": {"search_query": {"type": "string"}}, 196 | "required": ["search_query"], 197 | }, 198 | }, 199 | } 200 | ], 201 | "tool_choice": "auto", 202 | "count": 49, 203 | } 204 | 205 | string_enum = { 206 | "system_message": { 207 | "role": "system", 208 | "content": "You are a bot.", 209 | }, 210 | "tools": [ 211 | { 212 | "type": "function", 213 | "function": { 214 | "name": "summarize_order", 215 | "description": "Summarize the customer order request", 216 | "parameters": { 217 | "type": "object", 218 | "properties": { 219 | "product_name": { 220 | "type": "string", 221 | "description": "Product name ordered by customer", 222 | }, 223 | "quantity": { 224 | "type": "integer", 225 | "description": "Quantity ordered by customer", 226 | }, 227 | "unit": { 228 | "type": "string", 229 | "enum": ["meals", "days"], 230 | "description": "unit of measurement of the customer order", 231 | }, 232 | }, 233 | "required": ["product_name", "quantity", "unit"], 234 | }, 235 | }, 236 | } 237 | ], 238 | "tool_choice": "none", 239 | "count": 86, 240 | } 241 | 242 | inner_object = { 243 | "system_message": { 244 | "role": "system", 245 | "content": "You are a bot.", 246 | }, 247 | "tools": [ 248 | { 249 | "type": "function", 250 | "function": { 251 | "name": "data_demonstration", 252 | "description": "This is the main function description", 253 | "parameters": { 254 | "type": "object", 255 | "properties": { 256 | "object_1": { 257 | "type": "object", 258 | "description": "The object data type as a property", 259 | "properties": { 260 | "string1": {"type": "string"}, 261 | }, 262 | } 263 | }, 264 | "required": ["object_1"], 265 | }, 266 | }, 267 | } 268 | ], 269 | "tool_choice": "none", 270 | "count": 65, # counted 67, over by 2 271 | } 272 | """ 273 | namespace functions { 274 | 275 | // This 
is the main function description 276 | type data_demonstration = (_: { 277 | // The object data type as a property 278 | object_1: { 279 | string1?: string, 280 | }, 281 | }) => any; 282 | 283 | } // namespace functions 284 | """ 285 | 286 | inner_object_with_enum_only = { 287 | "system_message": { 288 | "role": "system", 289 | "content": "You are a bot.", 290 | }, 291 | "tools": [ 292 | { 293 | "type": "function", 294 | "function": { 295 | "name": "data_demonstration", 296 | "description": "This is the main function description", 297 | "parameters": { 298 | "type": "object", 299 | "properties": { 300 | "object_1": { 301 | "type": "object", 302 | "description": "The object data type as a property", 303 | "properties": {"string_2a": {"type": "string", "enum": ["Happy", "Sad"]}}, 304 | } 305 | }, 306 | "required": ["object_1"], 307 | }, 308 | }, 309 | } 310 | ], 311 | "tool_choice": "none", 312 | "count": 73, # counted 74, over by 1 313 | } 314 | """ 315 | namespace functions { 316 | 317 | // This is the main function description 318 | type data_demonstration = (_: { 319 | // The object data type as a property 320 | object_1: { 321 | string_2a?: "Happy" | "Sad", 322 | }, 323 | }) => any; 324 | 325 | } // namespace functions 326 | """ 327 | 328 | inner_object_with_enum = { 329 | "system_message": { 330 | "role": "system", 331 | "content": "You are a bot.", 332 | }, 333 | "tools": [ 334 | { 335 | "type": "function", 336 | "function": { 337 | "name": "data_demonstration", 338 | "description": "This is the main function description", 339 | "parameters": { 340 | "type": "object", 341 | "properties": { 342 | "object_1": { 343 | "type": "object", 344 | "description": "The object data type as a property", 345 | "properties": { 346 | "string_2a": {"type": "string", "enum": ["Happy", "Sad"]}, 347 | "string_2b": { 348 | "type": "string", 349 | "description": "Description in a second object is lost", 350 | }, 351 | }, 352 | } 353 | }, 354 | "required": ["object_1"], 355 | }, 356 
| }, 357 | } 358 | ], 359 | "tool_choice": "none", 360 | "count": 89, # counted 92, over by 3 361 | } 362 | """ 363 | namespace functions { 364 | 365 | // This is the main function description 366 | type data_demonstration = (_: { 367 | // The object data type as a property 368 | object_1: { 369 | string_2a?: "Happy" | "Sad", 370 | // Description in a second object is lost 371 | string_2b?: string, 372 | }, 373 | }) => any; 374 | 375 | } // namespace functions 376 | """ 377 | 378 | inner_object_and_string = { 379 | "system_message": { 380 | "role": "system", 381 | "content": "You are a bot.", 382 | }, 383 | "tools": [ 384 | { 385 | "type": "function", 386 | "function": { 387 | "name": "data_demonstration", 388 | "description": "This is the main function description", 389 | "parameters": { 390 | "type": "object", 391 | "properties": { 392 | "object_1": { 393 | "type": "object", 394 | "description": "The object data type as a property", 395 | "properties": { 396 | "string_2a": {"type": "string", "enum": ["Happy", "Sad"]}, 397 | "string_2b": { 398 | "type": "string", 399 | "description": "Description in a second object is lost", 400 | }, 401 | }, 402 | }, 403 | "string_1": {"type": "string", "description": "Not required gets a question mark"}, 404 | }, 405 | "required": ["object_1"], 406 | }, 407 | }, 408 | } 409 | ], 410 | "tool_choice": "none", 411 | "count": 103, # counted 106, over by 3 412 | } 413 | """ 414 | namespace functions { 415 | 416 | // This is the main function description 417 | type data_demonstration = (_: { 418 | // The object data type as a property 419 | object_1: { 420 | string_2a?: "Happy" | "Sad", 421 | // Description in a second object is lost 422 | string_2b?: string, 423 | }, 424 | // Not required gets a question mark 425 | string_1?: string, 426 | }) => any; 427 | 428 | } // namespace functions 429 | """ 430 | 431 | boolean = { 432 | "system_message": { 433 | "role": "system", 434 | "content": "You are a bot.", 435 | }, 436 | "tools": [ 437 
| { 438 | "type": "function", 439 | "function": { 440 | "name": "human_escalation", 441 | "description": "Check if user wants to escalate to a human", 442 | "parameters": { 443 | "type": "object", 444 | "properties": { 445 | "requires_escalation": { 446 | "type": "boolean", 447 | "description": "If user is showing signs of frustration or anger in the query. Also if the user says they want to talk to a real person and not a chat bot.", 448 | } 449 | }, 450 | "required": ["requires_escalation"], 451 | }, 452 | }, 453 | } 454 | ], 455 | "tool_choice": "none", 456 | "count": 89, # over by 3 457 | } 458 | 459 | array = { 460 | "system_message": { 461 | "role": "system", 462 | "content": "You are a bot.", 463 | }, 464 | "tools": [ 465 | { 466 | "type": "function", 467 | "function": { 468 | "name": "get_coordinates", 469 | "description": "Get the latitude and longitude of multiple mailing addresses", 470 | "parameters": { 471 | "type": "object", 472 | "properties": { 473 | "addresses": { 474 | "type": "array", 475 | "description": "The mailing addresses to be located", 476 | "items": {"type": "string"}, 477 | } 478 | }, 479 | "required": ["addresses"], 480 | }, 481 | }, 482 | } 483 | ], 484 | "tool_choice": "none", 485 | "count": 59, 486 | } 487 | 488 | null = { 489 | "system_message": { 490 | "role": "system", 491 | "content": "You are a bot.", 492 | }, 493 | "tools": [ 494 | { 495 | "type": "function", 496 | "function": { 497 | "name": "get_null", 498 | "description": "Get the null value", 499 | "parameters": { 500 | "type": "object", 501 | "properties": { 502 | "null_value": { 503 | "type": "null", 504 | "description": "The null value to be returned", 505 | } 506 | }, 507 | "required": ["null_value"], 508 | }, 509 | }, 510 | } 511 | ], 512 | "tool_choice": "none", 513 | "count": 55, 514 | } 515 | 516 | no_type = { 517 | "system_message": { 518 | "role": "system", 519 | "content": "You are a bot.", 520 | }, 521 | "tools": [ 522 | { 523 | "type": "function", 524 | 
"function": { 525 | "name": "get_no_type", 526 | "description": "Get the no type value", 527 | "parameters": { 528 | "type": "object", 529 | "properties": { 530 | "no_type_value": { 531 | "description": "The no type value to be returned", 532 | } 533 | }, 534 | "required": ["no_type_value"], 535 | }, 536 | }, 537 | } 538 | ], 539 | "tool_choice": "none", 540 | "count": 59, 541 | } 542 | 543 | FUNCTION_COUNTS = [ 544 | inner_object, 545 | inner_object_and_string, 546 | inner_object_with_enum_only, 547 | inner_object_with_enum, 548 | search_sources_toolchoice_auto, 549 | search_sources_toolchoice_none, 550 | search_sources_toolchoice_name, 551 | integer_enum, 552 | integer_enum_tool_choice_name, 553 | no_parameters, 554 | no_parameters_tool_choice_name, 555 | no_parameter_description_or_required, 556 | no_parameter_description, 557 | string_enum, 558 | boolean, 559 | array, 560 | no_type, 561 | null, 562 | ] 563 | -------------------------------------------------------------------------------- /tests/image_large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pamelafox/openai-messages-token-helper/5c579fd05ca592f4813bae2a49af4bab9eafb6b5/tests/image_large.png -------------------------------------------------------------------------------- /tests/image_messages.py: -------------------------------------------------------------------------------- 1 | text_and_tiny_image_message = { 2 | "message": { 3 | "role": "user", 4 | "content": [ 5 | {"type": "text", "text": "Describe this picture:"}, 6 | { 7 | "type": "image_url", 8 | "image_url": { 9 | "url": "", 10 | "detail": "auto", 11 | }, 12 | }, 13 | ], 14 | }, 15 | "count": 266, 16 | "count_4o_mini": 8511, 17 | } 18 | 19 | text_and_tiny_image_message_nodetail = { 20 | "message": { 21 | "role": "user", 22 | "content": [ 23 | {"type": "text", "text": "Describe this picture:"}, 24 | { 25 | "type": "image_url", 26 | "image_url": { 27 | "url": "" 28 | }, 29 
| }, 30 | ], 31 | }, 32 | "count": 266, 33 | "count_4o_mini": 8511, 34 | } 35 | 36 | text_and_tiny_image_message_low = { 37 | "message": { 38 | "role": "user", 39 | "content": [ 40 | {"type": "text", "text": "Describe this picture:"}, 41 | { 42 | "type": "image_url", 43 | "image_url": { 44 | "url": "", 45 | "detail": "low", 46 | }, 47 | }, 48 | ], 49 | }, 50 | "count": 96, # 11 + 85 51 | "count_4o_mini": 2844, # 11 + 2833 52 | } 53 | 54 | text_and_large_image_message = { 55 | "message": { 56 | "role": "user", 57 | "content": [ 58 | {"text": "hi", "type": "text"}, 59 | { 60 | "image_url": { 61 | "url": "", 62 | "detail": "auto", 63 | }, 64 | "type": "image_url", 65 | }, 66 | ], 67 | }, 68 | "count": 603, 69 | "count_4o_mini": 19842, 70 | } 71 | 72 | IMAGE_MESSAGE_COUNTS = [ 73 | text_and_tiny_image_message, 74 | text_and_tiny_image_message_nodetail, 75 | text_and_tiny_image_message_low, 76 | text_and_large_image_message, 77 | ] 78 | -------------------------------------------------------------------------------- /tests/messages.py: -------------------------------------------------------------------------------- 1 | system_message_short = { 2 | "message": { 3 | "role": "system", 4 | "content": "You are a bot.", 5 | }, 6 | "count": 12, 7 | "count_omni": 12, 8 | } 9 | 10 | system_message = { 11 | "message": { 12 | "role": "system", 13 | "content": "You are a helpful, pattern-following assistant that translates corporate jargon into plain English.", 14 | }, 15 | "count": 25, 16 | "count_omni": 24, 17 | } 18 | 19 | system_message_long = { 20 | "message": { 21 | "role": "system", 22 | "content": "Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. 
Be brief in your answers.", 23 | }, 24 | "count": 31, 25 | "count_omni": 31, 26 | } 27 | 28 | system_message_unicode = { 29 | "message": { 30 | "role": "system", 31 | "content": "á", 32 | }, 33 | "count": 8, 34 | "count_omni": 8, 35 | } 36 | 37 | system_message_with_name = { 38 | "message": { 39 | "role": "system", 40 | "name": "example_user", 41 | "content": "New synergies will help drive top-line growth.", 42 | }, 43 | "count": 20, # Less tokens in older vision preview models 44 | "count_omni": 20, 45 | } 46 | 47 | user_message = { 48 | "message": { 49 | "role": "user", 50 | "content": "Hello, how are you?", 51 | }, 52 | "count": 13, 53 | "count_omni": 13, 54 | } 55 | 56 | user_message_unicode = { 57 | "message": { 58 | "role": "user", 59 | "content": "á", 60 | }, 61 | "count": 8, 62 | "count_omni": 8, 63 | } 64 | 65 | user_message_perf = { 66 | "message": { 67 | "role": "user", 68 | "content": "What happens in a performance review?", 69 | }, 70 | "count": 14, 71 | "count_omni": 14, 72 | } 73 | 74 | assistant_message_perf = { 75 | "message": { 76 | "role": "assistant", 77 | "content": "During the performance review at Contoso Electronics, the supervisor will discuss the employee's performance over the past year and provide feedback on areas for improvement. They will also provide an opportunity for the employee to discuss their goals and objectives for the upcoming year. The review is a two-way dialogue between managers and employees, and employees will receive a written summary of their performance review which will include a rating of their performance, feedback, and goals and objectives for the upcoming year [employee_handbook-3.pdf].", 78 | }, 79 | "count": 106, 80 | "count_omni": 106, 81 | } 82 | 83 | assistant_message_perf_short = { 84 | "message": { 85 | "role": "assistant", 86 | "content": "The supervisor will discuss the employee's performance and provide feedback on areas for improvement. 
They will also provide an opportunity for the employee to discuss their goals and objectives for the upcoming year. The review is a two-way dialogue between managers and employees, and employees will receive a written summary of their performance review which will include a rating of their performance, feedback, and goals for the upcoming year [employee_handbook-3.pdf].", 87 | }, 88 | "count": 91, 89 | "count_omni": 91, 90 | } 91 | 92 | user_message_dresscode = { 93 | "message": { 94 | "role": "user", 95 | "content": "Is there a dress code?", 96 | }, 97 | "count": 13, 98 | "count_omni": 13, 99 | } 100 | 101 | assistant_message_dresscode = { 102 | "message": { 103 | "role": "assistant", 104 | "content": "Yes, there is a dress code at Contoso Electronics. Look sharp! [employee_handbook-1.pdf]", 105 | }, 106 | "count": 30, 107 | "count_omni": 30, 108 | } 109 | user_message_pm = { 110 | "message": { 111 | "role": "user", 112 | "content": "What does a Product Manager do?", 113 | }, 114 | "count": 14, 115 | "count_omni": 14, 116 | } 117 | text_and_image_message = { 118 | "message": { 119 | "role": "user", 120 | "content": [ 121 | {"type": "text", "text": "Describe this picture:"}, 122 | { 123 | "type": "image_url", 124 | "image_url": { 125 | "url": "", 126 | "detail": "auto", 127 | }, 128 | }, 129 | ], 130 | }, 131 | "count": 266, 132 | "count_omni": 266, 133 | } 134 | 135 | MESSAGE_COUNTS = [ 136 | system_message, 137 | system_message_short, 138 | system_message_long, 139 | system_message_unicode, 140 | system_message_with_name, 141 | user_message, 142 | user_message_unicode, 143 | user_message_perf, 144 | user_message_dresscode, 145 | user_message_pm, 146 | assistant_message_perf, 147 | assistant_message_perf_short, 148 | assistant_message_dresscode, 149 | ] 150 | -------------------------------------------------------------------------------- /tests/test_imageshelper.py: -------------------------------------------------------------------------------- 1 | import base64 
2 | 3 | import pytest 4 | 5 | from openai_messages_token_helper import count_tokens_for_image 6 | 7 | 8 | @pytest.fixture 9 | def small_image(): 10 | return "" 11 | 12 | 13 | @pytest.fixture 14 | def large_image(): 15 | large_image = open("tests/image_large.png", "rb").read() 16 | img = base64.b64encode(large_image).decode("utf-8") 17 | return f"data:image/png;base64,{img}" 18 | 19 | 20 | def test_count_tokens_for_image(small_image, large_image): 21 | assert count_tokens_for_image(small_image, "low") == 85 22 | assert count_tokens_for_image(small_image, "low", "gpt-4o-mini") == 2833 23 | assert count_tokens_for_image(small_image, "high") == 255 24 | assert count_tokens_for_image(small_image) == 255 25 | assert count_tokens_for_image(large_image, "low") == 85 26 | assert count_tokens_for_image(large_image, "high") == 1105 27 | with pytest.raises(ValueError, match="Invalid value for detail parameter."): 28 | assert count_tokens_for_image(large_image, "medium") 29 | with pytest.raises(ValueError, match="Image must be a base64 string."): 30 | assert count_tokens_for_image("http://domain.com/image.png") 31 | -------------------------------------------------------------------------------- /tests/test_messagebuilder.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import pytest 4 | from openai.types.chat import ( 5 | ChatCompletionMessageParam, 6 | ChatCompletionToolChoiceOptionParam, 7 | ChatCompletionToolParam, 8 | ) 9 | 10 | from openai_messages_token_helper import build_messages, count_tokens_for_message 11 | 12 | from .functions import search_sources_toolchoice_auto 13 | from .image_messages import text_and_tiny_image_message 14 | from .messages import ( 15 | assistant_message_dresscode, 16 | assistant_message_perf, 17 | assistant_message_perf_short, 18 | system_message_long, 19 | system_message_short, 20 | system_message_unicode, 21 | user_message, 22 | user_message_dresscode, 23 | user_message_perf, 24 | 
user_message_pm, 25 | user_message_unicode, 26 | ) 27 | 28 | 29 | def test_messagebuilder(): 30 | messages = build_messages("gpt-35-turbo", system_message_short["message"]["content"]) 31 | assert messages == [system_message_short["message"]] 32 | assert count_tokens_for_message("gpt-35-turbo", messages[0]) == system_message_short["count"] 33 | 34 | 35 | def test_messagebuilder_imagemessage(): 36 | messages = build_messages( 37 | "gpt-35-turbo", 38 | system_message_short["message"]["content"], 39 | new_user_content=text_and_tiny_image_message["message"]["content"], 40 | ) 41 | assert messages == [system_message_short["message"], text_and_tiny_image_message["message"]] 42 | 43 | 44 | def test_messagebuilder_append(): 45 | messages = build_messages( 46 | "gpt-35-turbo", system_message_short["message"]["content"], new_user_content=user_message["message"]["content"] 47 | ) 48 | assert messages == [system_message_short["message"], user_message["message"]] 49 | assert count_tokens_for_message("gpt-35-turbo", messages[0]) == system_message_short["count"] 50 | assert count_tokens_for_message("gpt-35-turbo", messages[1]) == user_message["count"] 51 | 52 | 53 | def test_messagebuilder_unicode(): 54 | messages = build_messages("gpt-35-turbo", system_message_unicode["message"]["content"]) 55 | assert messages == [system_message_unicode["message"]] 56 | assert count_tokens_for_message("gpt-35-turbo", messages[0]) == system_message_unicode["count"] 57 | 58 | 59 | def test_messagebuilder_unicode_append(): 60 | messages = build_messages( 61 | "gpt-35-turbo", 62 | system_message_unicode["message"]["content"], 63 | new_user_content=user_message_unicode["message"]["content"], 64 | ) 65 | assert messages == [system_message_unicode["message"], user_message_unicode["message"]] 66 | assert count_tokens_for_message("gpt-35-turbo", messages[0]) == system_message_unicode["count"] 67 | assert count_tokens_for_message("gpt-35-turbo", messages[1]) == user_message_unicode["count"] 68 | 69 | 70 | 
def test_messagebuilder_model_error():
    """An unrecognized model name raises ValueError when no fallback is requested."""
    model = "phi-3"
    system_prompt = system_message_short["message"]["content"]
    user_content = user_message["message"]["content"]
    with pytest.raises(ValueError, match="Called with unknown model name: phi-3"):
        build_messages(model, system_prompt, new_user_content=user_content)


def test_messagebuilder_model_fallback():
    """With fallback_to_default=True an unrecognized model still builds [system, user].

    Token counts are then checked with default_to_cl100k=True.
    """
    model = "phi-3"
    built = build_messages(
        model,
        system_message_short["message"]["content"],
        new_user_content=user_message["message"]["content"],
        fallback_to_default=True,
    )
    assert built == [system_message_short["message"], user_message["message"]]
    for built_message, fixture in zip(built, (system_message_short, user_message)):
        assert count_tokens_for_message(model, built_message, default_to_cl100k=True) == fixture["count"]


def test_messagebuilder_pastmessages():
    """With a generous max_tokens budget every past message is kept, in order."""
    history = [
        user_message_perf["message"],  # 14 tokens
        assistant_message_perf["message"],  # 106 tokens
    ]
    built = build_messages(
        model="gpt-35-turbo",
        system_prompt=system_message_short["message"]["content"],  # 12 tokens
        past_messages=history,
        new_user_content=user_message_pm["message"]["content"],  # 14 tokens
        max_tokens=3000,
    )
    expected = [
        system_message_short["message"],
        user_message_perf["message"],
        assistant_message_perf["message"],
        user_message_pm["message"],
    ]
    assert built == expected


def test_messagebuilder_pastmessages_truncated():
    """A budget too small for any past message leaves only the system prompt and new user message."""
    history = [
        user_message_perf["message"],  # 14 tokens
        assistant_message_perf["message"],  # 106 tokens
    ]
    built = build_messages(
        model="gpt-35-turbo",
        system_prompt=system_message_short["message"]["content"],  # 12 tokens
        past_messages=history,
        new_user_content=user_message_pm["message"]["content"],  # 14 tokens
        max_tokens=10,
    )
    expected = [system_message_short["message"], user_message_pm["message"]]
    assert built == expected


def test_messagebuilder_pastmessages_truncated_longer():
    """Only the most recent past exchange fits; the older one is dropped.

    Budget: 12 (system) + 13 + 30 (dresscode pair) + 14 (new user) = 69 tokens.
    """
    messages = build_messages(
        model="gpt-35-turbo",
        system_prompt=system_message_short["message"]["content"],  # 12 tokens
        past_messages=[
            user_message_perf["message"],  # 14 tokens
            assistant_message_perf["message"],  # 106 tokens
            user_message_dresscode["message"],  # 13 tokens
            assistant_message_dresscode["message"],  # 30 tokens
        ],
        new_user_content=user_message_pm["message"]["content"],  # 14 tokens
        max_tokens=69,
    )
    assert messages == [
        system_message_short["message"],
        user_message_dresscode["message"],
        assistant_message_dresscode["message"],
        user_message_pm["message"],
    ]


def test_messagebuilder_pastmessages_truncated_break_pair():
    """Tests that truncation may keep an assistant message without its preceding user message.

    Budget: 12 + 91 + 13 + 30 + 14 = 160 tokens, so user_message_perf (14 tokens) is dropped.
    """
    messages = build_messages(
        model="gpt-35-turbo",
        system_prompt=system_message_short["message"]["content"],  # 12 tokens
        past_messages=[
            user_message_perf["message"],  # 14 tokens
            assistant_message_perf_short["message"],  # 91 tokens
            user_message_dresscode["message"],  # 13 tokens
            assistant_message_dresscode["message"],  # 30 tokens
        ],
        new_user_content=user_message_pm["message"]["content"],  # 14 tokens
        max_tokens=160,
    )
    assert messages == [
        system_message_short["message"],
        assistant_message_perf_short["message"],
        user_message_dresscode["message"],
        assistant_message_dresscode["message"],
        user_message_pm["message"],
    ]


def test_messagebuilder_system():
    """Tests that the system message token count is considered.

    The long system prompt (31 tokens) plus the new user message already exceed
    max_tokens=36, so no past messages are kept.
    """
    messages = build_messages(
        model="gpt-35-turbo",
        system_prompt=system_message_long["message"]["content"],  # 31 tokens
        past_messages=[
            user_message_perf["message"],  # 14 tokens
            assistant_message_perf["message"],  # 106 tokens
            user_message_dresscode["message"],  # 13 tokens
            assistant_message_dresscode["message"],  # 30 tokens
        ],
        new_user_content=user_message_pm["message"]["content"],  # 14 tokens
        max_tokens=36,
    )
    assert messages == [system_message_long["message"], user_message_pm["message"]]


def test_messagebuilder_system_fewshots():
    """Tests that few-shot examples sit between the system message and the new user message."""
    messages = build_messages(
        model="gpt-35-turbo",
        system_prompt=system_message_short["message"]["content"],
        new_user_content=user_message_pm["message"]["content"],
        past_messages=[],
        few_shots=[
            {"role": "user", "content": "How did crypto do last year?"},
            {"role": "assistant", "content": "Summarize Cryptocurrency Market Dynamics from last year"},
            {"role": "user", "content": "What are my health plans?"},
            {"role": "assistant", "content": "Show available health plans"},
        ],
    )
    # Make sure messages are in the right order
    assert messages[0]["role"] == "system"
    assert messages[1]["role"] == "user"
    assert messages[2]["role"] == "assistant"
    assert messages[3]["role"] == "user"
    assert messages[4]["role"] == "assistant"
    assert messages[5]["role"] == "user"
    assert messages[5]["content"] == user_message_pm["message"]["content"]


def test_messagebuilder_system_fewshotstools():
    """Tests that few-shot examples with tool calls and tool results keep their order."""
    messages = build_messages(
        model="gpt-35-turbo",
        system_prompt=system_message_short["message"]["content"],
        new_user_content=user_message_pm["message"]["content"],
        past_messages=[],
        few_shots=[
            {"role": "user", "content": "good options for climbing gear that can be used outside?"},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call_abc123",
                        "type": "function",
                        "function": {
                            "arguments": '{"search_query":"climbing gear outside"}',
                            "name": "search_database",
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call_abc123",
                "content": "Search results for climbing gear that can be used outside: ...",
            },
            {"role": "user", "content": "are there any shoes less than $50?"},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call_abc456",
                        "type": "function",
                        "function": {
                            "arguments": '{"search_query":"shoes","price_filter":{"comparison_operator":"<","value":50}}',
                            "name": "search_database",
                        },
                    }
                ],
            },
            {"role": "tool", "tool_call_id": "call_abc456", "content": "Search results for shoes cheaper than 50: ..."},
        ],
    )
    # Make sure messages are in the right order
    assert messages[0]["role"] == "system"
    assert messages[1]["role"] == "user"
    assert messages[2]["role"] == "assistant"
    assert messages[3]["role"] == "tool"
    assert messages[4]["role"] == "user"
    assert messages[5]["role"] == "assistant"
    assert messages[6]["role"] == "tool"
    assert messages[7]["role"] == "user"
    assert messages[7]["content"] == user_message_pm["message"]["content"]


def test_messagebuilder_system_tools():
    """Tests that the tokens for the system message plus tools and tool_choice count against max_tokens.

    66 (system + tools + tool_choice) + 14 (new user) leaves too little of the
    90-token budget for any past message.
    """
    messages = build_messages(
        model="gpt-35-turbo",
        system_prompt=search_sources_toolchoice_auto["system_message"]["content"],
        tools=search_sources_toolchoice_auto["tools"],
        tool_choice=search_sources_toolchoice_auto["tool_choice"],
        # 66 tokens for system + tools + tool_choice ^
        past_messages=[
            user_message_perf["message"],  # 14 tokens
            assistant_message_perf["message"],  # 106 tokens
        ],
        new_user_content=user_message_pm["message"]["content"],  # 14 tokens
        max_tokens=90,
    )
    assert messages == [search_sources_toolchoice_auto["system_message"], user_message_pm["message"]]


def test_messagebuilder_typing() -> None:
    """Tests that build_messages accepts the openai typed params and returns ChatCompletionMessageParam items."""
    tools: list[ChatCompletionToolParam] = [
        {
            "type": "function",
            "function": {
                "name": "search_sources",
                "description": "Retrieve sources from the Azure AI Search index",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "search_query": {
                            "type": "string",
                            "description": "Query string to retrieve documents from azure search eg: 'Health care plan'",
                        }
                    },
                    "required": ["search_query"],
                },
            },
        }
    ]
    tool_choice: ChatCompletionToolChoiceOptionParam = {
        "type": "function",
        "function": {"name": "search_sources"},
    }

    past_messages: list[ChatCompletionMessageParam] = [
        {"role": "user", "content": "What are my health plans?"},
        {"role": "assistant", "content": "Here are some tools you can use to search for sources."},
    ]

    messages = build_messages(
        model="gpt-35-turbo",
        system_prompt="Here are some tools you can use to search for sources.",
        tools=tools,
        tool_choice=tool_choice,
        past_messages=past_messages,
        new_user_content="What are my health plans?",
        max_tokens=90,
    )

    assert isinstance(messages, list)
    # typing.assert_type only exists on Python 3.11+; CI also runs 3.9/3.10.
    if hasattr(typing, "assert_type"):
        typing.assert_type(messages[0], ChatCompletionMessageParam)

    # Same call, but with the string form of tool_choice.
    messages = build_messages(
        model="gpt-35-turbo",
        system_prompt="Here are some tools you can use to search for sources.",
        tools=tools,
        tool_choice="auto",
        past_messages=past_messages,
        new_user_content="What are my health plans?",
        max_tokens=90,
    )

    assert isinstance(messages, list)
    if hasattr(typing, "assert_type"):
        typing.assert_type(messages[0], ChatCompletionMessageParam)


# --------------------------------------------------------------------------------
# /tests/test_modelhelper.py:
# --------------------------------------------------------------------------------
import pytest

from openai_messages_token_helper import count_tokens_for_message, count_tokens_for_system_and_tools, get_token_limit

from .functions import FUNCTION_COUNTS, search_sources_toolchoice_auto
from .image_messages import IMAGE_MESSAGE_COUNTS
from .messages import system_message, system_message_with_name, user_message


def test_get_token_limit():
    """Known model names (both "gpt-35" and "gpt-3.5" spellings) map to their token limits."""
    assert get_token_limit("gpt-35-turbo") == 4000
    assert get_token_limit("gpt-3.5-turbo") == 4000
    assert get_token_limit("gpt-35-turbo-16k") == 16000
    assert get_token_limit("gpt-3.5-turbo-16k") == 16000
    assert get_token_limit("gpt-4") == 8100
    assert get_token_limit("gpt-4-32k") == 32000
    assert get_token_limit("gpt-4o") == 128000


def test_get_token_limit_error():
    """An unknown model name raises ValueError."""
    with pytest.raises(ValueError, match="Called with unknown model name: gpt-3"):
        get_token_limit("gpt-3")


def test_get_token_limit_default(caplog):
    """With default_to_minimum=True an unknown model gets 4000 tokens and a warning is logged."""
    with caplog.at_level("WARNING"):
        assert get_token_limit("gpt-3", default_to_minimum=True) == 4000
        assert "Model gpt-3 not found, defaulting to minimum token limit 4000" in caplog.text


# parameterize the model and the expected number of tokens
@pytest.mark.parametrize(
    "model, count_key",
    [
        ("gpt-35-turbo", "count"),
        ("gpt-3.5-turbo", "count"),
        ("gpt-35-turbo-16k", "count"),
        ("gpt-3.5-turbo-16k", "count"),
        ("gpt-4", "count"),
        ("gpt-4-32k", "count"),
        ("gpt-4v", "count"),
        ("gpt-4o", "count_omni"),
    ],
)
@pytest.mark.parametrize(
    "message",
    [
        user_message,
        system_message,
        system_message_with_name,
    ],
)
def test_count_tokens_for_message(model, count_key, message):
    """Each text fixture counts as recorded; gpt-4o is checked against the "count_omni" key."""
    assert count_tokens_for_message(model, message["message"]) == message[count_key]


@pytest.mark.parametrize(
    "model, count_key",
    [
        ("gpt-4", "count"),
        ("gpt-4o", "count"),
        ("gpt-4o-mini", "count_4o_mini"),
    ],
)
def test_count_tokens_for_message_list(model, count_key):
    """Every IMAGE_MESSAGE_COUNTS fixture counts as recorded for the given model."""
    for message_count_pair in IMAGE_MESSAGE_COUNTS:
        assert count_tokens_for_message(model, message_count_pair["message"]) == message_count_pair[count_key]


def test_count_tokens_for_message_error():
    """Message content of an unsupported type (here a dict) raises ValueError."""
    message = {
        "role": "user",
        "content": {"key": "value"},
    }
    model = "gpt-35-turbo"
    with pytest.raises(ValueError, match="Could not encode unsupported message value type"):
        count_tokens_for_message(model, message)


def test_count_tokens_for_message_model_error():
    """Empty, None, or unrecognized model names raise ValueError."""
    with pytest.raises(ValueError, match="Expected valid OpenAI GPT model name"):
        count_tokens_for_message("", user_message["message"])
    with pytest.raises(ValueError, match="Expected valid OpenAI GPT model name"):
        count_tokens_for_message(None, user_message["message"])
    with pytest.raises(ValueError, match="Expected valid OpenAI GPT model name"):
        count_tokens_for_message("gpt44", user_message["message"])


def test_count_tokens_for_message_model_default(caplog):
    """With default_to_cl100k=True an unknown model is counted and a warning is logged."""
    model = "phi-3"
    with caplog.at_level("WARNING"):
        assert count_tokens_for_message(model, user_message["message"], default_to_cl100k=True) == user_message["count"]
        assert "Model phi-3 not found, defaulting to CL100k encoding" in caplog.text


@pytest.mark.parametrize(
    "function_count_pair",
    FUNCTION_COUNTS,
)
def test_count_tokens_for_system_and_tools(function_count_pair):
    """System + tools + tool_choice counts may over-count by at most 3 tokens, never under-count."""
    counted_tokens = count_tokens_for_system_and_tools(
        "gpt-35-turbo",
        function_count_pair["system_message"],
        function_count_pair["tools"],
        function_count_pair["tool_choice"],
    )
    expected_tokens = function_count_pair["count"]
    diff = counted_tokens - expected_tokens
    assert (
        0 <= diff <= 3
    ), f"Expected {expected_tokens} tokens, got {counted_tokens}. Counted tokens is only allowed to be off by 3 in the over-counting direction."


def test_count_tokens_for_system_and_tools_fallback(caplog):
    """With default_to_cl100k=True an unknown model still counts system + tools and logs a warning."""
    function_count_pair = search_sources_toolchoice_auto
    with caplog.at_level("WARNING"):
        counted_tokens = count_tokens_for_system_and_tools(
            "llama-3.1",
            function_count_pair["system_message"],
            function_count_pair["tools"],
            function_count_pair["tool_choice"],
            default_to_cl100k=True,
        )
        assert counted_tokens == function_count_pair["count"]
        assert "Model llama-3.1 not found, defaulting to CL100k encoding" in caplog.text


# --------------------------------------------------------------------------------
# /tests/verify_functions.py:
# --------------------------------------------------------------------------------
# Manual check: sends each FUNCTION_COUNTS case to a live endpoint and compares the
# reported prompt_tokens with the recorded expectation.
import os
from typing import Union

import azure.identity
import openai
from dotenv import load_dotenv
from functions import FUNCTION_COUNTS  # type: ignore[import-not-found]

# Setup the OpenAI client to use either Azure OpenAI or OpenAI API
load_dotenv()
API_HOST = os.getenv("API_HOST")

client: Union[openai.OpenAI, openai.AzureOpenAI]

if API_HOST == "azure":
    if (azure_openai_version := os.getenv("AZURE_OPENAI_VERSION")) is None:
        raise ValueError("Missing Azure OpenAI version")
    if (azure_openai_endpoint := os.getenv("AZURE_OPENAI_ENDPOINT")) is None:
        raise ValueError("Missing Azure OpenAI endpoint")
    if (azure_openai_deployment := os.getenv("AZURE_OPENAI_DEPLOYMENT")) is None:
        raise ValueError("Missing Azure OpenAI deployment")

    token_provider = azure.identity.get_bearer_token_provider(
        azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )
    client = openai.AzureOpenAI(
        api_version=azure_openai_version,
        azure_endpoint=azure_openai_endpoint,
        azure_ad_token_provider=token_provider,
    )
    MODEL_NAME = azure_openai_deployment
else:
    if (openai_key := os.getenv("OPENAI_KEY")) is None:
        raise ValueError("Missing OpenAI API key")
    if (openai_model := os.getenv("OPENAI_MODEL")) is None:
        raise ValueError("Missing OpenAI model")
    client = openai.OpenAI(api_key=openai_key)
    MODEL_NAME = openai_model


# Test the token count for each message
for function_count_pair in FUNCTION_COUNTS:
    response = client.chat.completions.create(  # type: ignore[call-overload]
        model=MODEL_NAME,
        temperature=0.7,
        n=1,
        messages=[function_count_pair["system_message"]],
        tools=function_count_pair["tools"],
        tool_choice=function_count_pair["tool_choice"],
    )

    print(function_count_pair["tools"])
    assert response.usage is not None, "Expected usage to be present"
    assert (
        response.usage.prompt_tokens == function_count_pair["count"]
    ), f"Expected {function_count_pair['count']} tokens, got {response.usage.prompt_tokens}"
# --------------------------------------------------------------------------------
# /tests/verify_openai.py:
# --------------------------------------------------------------------------------
# Manual check: sends each text and image fixture to a live endpoint and compares the
# reported prompt_tokens with the recorded per-model expectation.
import os
from typing import Union

import azure.identity
import openai
from dotenv import load_dotenv
from image_messages import IMAGE_MESSAGE_COUNTS  # type: ignore[import-not-found]
from messages import MESSAGE_COUNTS  # type: ignore[import-not-found]

# Setup the OpenAI client to use either Azure OpenAI or OpenAI API
load_dotenv()
API_HOST = os.getenv("API_HOST")

client: Union[openai.OpenAI, openai.AzureOpenAI]

if API_HOST == "azure":
    if (azure_openai_version := os.getenv("AZURE_OPENAI_VERSION")) is None:
        raise ValueError("Missing Azure OpenAI version")
    if (azure_openai_endpoint := os.getenv("AZURE_OPENAI_ENDPOINT")) is None:
        raise ValueError("Missing Azure OpenAI endpoint")
    if (azure_openai_deployment := os.getenv("AZURE_OPENAI_DEPLOYMENT")) is None:
        raise ValueError("Missing Azure OpenAI deployment")

    token_provider = azure.identity.get_bearer_token_provider(
        azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )
    client = openai.AzureOpenAI(
        api_version=azure_openai_version,
        azure_endpoint=azure_openai_endpoint,
        azure_ad_token_provider=token_provider,
    )
    MODEL_NAME = azure_openai_deployment
else:
    if (openai_key := os.getenv("OPENAI_KEY")) is None:
        raise ValueError("Missing OpenAI API key")
    if (openai_model := os.getenv("OPENAI_MODEL")) is None:
        raise ValueError("Missing OpenAI model")
    client = openai.OpenAI(api_key=openai_key)
    MODEL_NAME = openai_model

# Test the token count for each message

for message_count_pair in MESSAGE_COUNTS:
    message = message_count_pair["message"]
    # Send each (model, expectation) pair to that model, mirroring the image loop below.
    # The expectation must come from the tuple: gpt-4o is checked against "count_omni".
    for model, expected_tokens in [("gpt-4o", message_count_pair["count_omni"])]:
        response = client.chat.completions.create(
            model=model,
            temperature=0.7,
            n=1,
            messages=[message],  # type: ignore[list-item]
        )

        print(message)
        assert response.usage is not None, "Expected usage to be present"
        assert (
            response.usage.prompt_tokens == expected_tokens
        ), f"Expected {expected_tokens} tokens, got {response.usage.prompt_tokens} for model {model}"


for message_count_pair in IMAGE_MESSAGE_COUNTS:
    for model, expected_tokens in [
        ("gpt-4o", message_count_pair["count"]),
        ("gpt-4o-mini", message_count_pair["count_4o_mini"]),
    ]:
        response = client.chat.completions.create(
            model=model,
            temperature=0.7,
            n=1,
            messages=[message_count_pair["message"]],  # type: ignore[list-item]
        )

        assert response.usage is not None, "Expected usage to be present"
        assert (
            response.usage.prompt_tokens == expected_tokens
        ), f"Expected {expected_tokens} tokens, got {response.usage.prompt_tokens} for model {model}"
# --------------------------------------------------------------------------------