├── texttunnel ├── __init__.py ├── models.py ├── utils.py ├── chat.py └── processor.py ├── tests ├── test_models.py ├── test_processor.py ├── test_utils.py ├── test_chat.py └── conftest.py ├── docs ├── Makefile ├── make.bat ├── conf.py └── index.rst ├── pyproject.toml ├── LICENSE ├── .github └── workflows │ ├── python-package.yml │ └── build-docs.yml ├── examples ├── text_classification.py ├── named_entity_recognition.py └── sentiment_analysis.py ├── CHANGELOG.md ├── .gitignore └── README.md /texttunnel/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from texttunnel import models 4 | 5 | 6 | def test_model_errors_on_negative(): 7 | with pytest.raises(ValueError): 8 | models.Model( 9 | name="gpt-3.5-turbo", 10 | context_size=0, 11 | input_token_price_per_1k=-1, 12 | output_token_price_per_1k=0.004, 13 | tokens_per_minute=90000, 14 | requests_per_minute=3500, 15 | ) 16 | 17 | 18 | def test_parameters_invalid_value(): 19 | with pytest.raises(ValueError): 20 | models.Parameters(max_tokens=128, frequency_penalty=3) 21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # pyproject.toml 2 | 3 | [build-system] 4 | requires = ["poetry-core>=1.0.0"] 5 | build-backend = "poetry.core.masonry.api" 6 | 7 | [tool.poetry] 8 | name = "texttunnel" 9 | version = "0.3.7" 10 | description = "Efficient text processing with the OpenAI API" 11 | authors = ["Q Agentur für Forschung GmbH "] 12 | readme = "README.md" 13 | homepage = "https://github.com/qagentur/texttunnel" 14 | license = "MIT" 15 | 16 | [tool.poetry.dependencies] 17 | python = ">=3.9 <4.0" 18 | aiohttp = ">=3.8.3 <4.0.0" 19 | jsonschema = ">=3.0.0 <5.0.0" 20 | tiktoken = ">=0.3.1 <1.0.0" 21 | aiohttp-client-cache = ">=0.8.0 <1.0.0" 22 | 23 | [tool.poetry.group.dev.dependencies] 24 | black = "^23.7.0" 25 | pytest = "^7.4.0" 26 | ipykernel = "^6.25.0" 27 | aioboto3 = "^11.3.0" 28 | aiosqlite = "^0.19.0" 29 | 30 | [tool.poetry.group.docs.dependencies] 31 | sphinx = "6.2.1" 32 | sphinx-rtd-theme = "^1.2.2" 33 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Q Agentur für Forschung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = "texttunnel" 10 | copyright = "2023, Q Agentur für Forschung GmbH" 11 | author = "Q Agentur für Forschung GmbH" 12 | release = "0.2.1" 13 | 14 | # -- General configuration --------------------------------------------------- 15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 16 | 17 | extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx.ext.autosummary"] 18 | 19 | templates_path = ["_templates"] 20 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 21 | 22 | 23 | # -- Options for HTML output ------------------------------------------------- 24 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 25 | 26 | html_theme = "sphinx_rtd_theme" 27 | html_static_path = ["_static"] 28 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | ci: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.9", "3.10", "3.11"] 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Install Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install poetry 24 | uses: abatilo/actions-poetry@v2 25 | - name: Setup a local virtual environment 26 | run: | 27 | poetry config virtualenvs.create true --local 28 | poetry config virtualenvs.in-project true --local 29 | - uses: actions/cache@v3 30 | name: Define a cache for the virtual environment 31 | with: 32 | path: ./.venv 33 | key: venv-${{ hashFiles('poetry.lock') }} 34 | - name: Install the project dependencies 35 | run: poetry install 36 | - name: Run linter 37 | run: poetry run black --check . 
38 | - name: Run the automated tests 39 | run: poetry run pytest -v 40 | -------------------------------------------------------------------------------- /.github/workflows/build-docs.yml: -------------------------------------------------------------------------------- 1 | name: Build and Deploy Docs 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | release: 7 | types: [ published ] 8 | 9 | jobs: 10 | build-and-deploy: 11 | runs-on: ubuntu-latest 12 | permissions: 13 | contents: write 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Install Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: 3.9.12 20 | - name: Install poetry 21 | uses: abatilo/actions-poetry@v2 22 | - name: Setup a local virtual environment 23 | run: | 24 | poetry config virtualenvs.create true --local 25 | poetry config virtualenvs.in-project true --local 26 | - uses: actions/cache@v3 27 | name: Define a cache for the virtual environment 28 | with: 29 | path: ./.venv 30 | key: venv-${{ hashFiles('poetry.lock') }} 31 | - name: Install the docs dependencies 32 | run: poetry install --only docs 33 | - name: Build docs with Sphinx 34 | run: | 35 | poetry run make -C docs html 36 | - name: Deploy docs to Github Pages 37 | uses: peaceiris/actions-gh-pages@v3 38 | with: 39 | github_token: ${{ secrets.GITHUB_TOKEN }} 40 | publish_dir: ./docs/_build/html 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /tests/test_processor.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pytest 3 | from texttunnel import processor 4 | 5 | 6 | def test_prepare_output_filepath_tempfile(): 7 | path = processor.prepare_output_filepath(None, keep_file=False) 8 | try: 9 | assert isinstance(path, Path) 10 | assert path.exists() 11 | finally: 12 | path.unlink() 13 | 14 | 15 | def test_is_valid_response(response_fixture): 16 | assert processor.is_valid_response(response_fixture, print_errors=True) 17 | 18 | 19 | def test_is_valid_response_fails_on_invalid_response(response_fixture): 20 | invalid_response = response_fixture.copy() 21 | del invalid_response[1]["usage"] 22 | assert not processor.is_valid_response(invalid_response) 23 | 24 | 25 | def test_parse_response(response_fixture): 26 | act = processor.parse_arguments(response_fixture) 27 | exp = {"feeling": "happy"} 28 | 29 | assert act == exp 30 | 31 | 32 | def test_parse_token_usage(response_fixture): 33 | act = processor.parse_token_usage(response_fixture) 34 | exp = { 35 | "prompt_tokens": 100, 36 | "completion_tokens": 50, 37 | "total_tokens": 150, 38 | } 39 | 40 | assert act == exp 41 | 42 | 43 | def test_usage_to_cost(response_fixture, model_fixture): 44 | usage = processor.parse_token_usage(response_fixture) 45 | cost = processor.usage_to_cost(usage, model_fixture) 46 | 47 | assert cost > 0 48 | 49 | 50 | def test_process_api_requests_fails_on_duplicate_requests(requests_fixture): 51 | with pytest.raises(ValueError): 52 | processor.process_api_requests( 53 | requests=[ 54 | requests_fixture[0], 55 | requests_fixture[0], 56 | ] 57 | ) 58 | -------------------------------------------------------------------------------- /examples/text_classification.py: -------------------------------------------------------------------------------- 1 | # Example of using the texttunnel package to perform text classification 2 | # Uses the aprocess_api_call for more control over the event loop 3 | 4 | # %% 5 | import logging 6 | from aiohttp_client_cache 
import SQLiteBackend 7 | import asyncio 8 | import nest_asyncio 9 | 10 | from texttunnel import chat, models, processor 11 | 12 | nest_asyncio.apply() # to allow for asyncio.run() within Jupyter 13 | 14 | # %% 15 | # Create a SQLite cache to store the results of the requests 16 | # When this script is run again, the results will be loaded from the cache 17 | # Requires the additional package aiosqlite (pip install aiosqlite) 18 | cache = SQLiteBackend(cache_name="openai_cache.sqlite", allowed_methods=["POST"]) 19 | 20 | logging.basicConfig(level=logging.WARN) 21 | logging.getLogger("texttunnel").setLevel(logging.INFO) 22 | 23 | # Texts that we'd like to classify by product category 24 | input_texts = [ 25 | "The 60% layout is great for travel, but I wish it had arrow keys", 26 | "The laser doesn't work on my glass desk. I'm returning it.", 27 | "I love the feel of the keys, but the RGB lighting is too bright.", 28 | "The scroll wheel is too sensitive. I keep scrolling past what I want.", 29 | ] 30 | 31 | # Describe the output format that we'd like to receive, 32 | # using JSON Schema 33 | function = { 34 | "name": "text_classification", 35 | "parameters": { 36 | "type": "object", 37 | "properties": { 38 | "category": { 39 | "type": "string", 40 | "enum": ["keyboard", "mouse"], 41 | }, 42 | }, 43 | "required": ["category"], 44 | }, 45 | } 46 | 47 | system_message = "Classify reviews by product category." 48 | 49 | model = models.GPT_3_5_TURBO 50 | 51 | requests = chat.build_requests( 52 | texts=input_texts, 53 | function=function, 54 | model=model, 55 | system_message=system_message, 56 | params=models.Parameters(max_tokens=50), 57 | ) 58 | 59 | # %% 60 | # Create an event loop and run the requests 61 | # Alternatively, use processor.process_api_requests() and let it handle the event loop 62 | loop = asyncio.get_event_loop() 63 | responses = loop.run_until_complete( 64 | processor.aprocess_api_requests(requests, cache=cache) 65 | ) 66 | 67 | # %% 68 | # Display the results 69 | results = [ 70 | processor.parse_arguments(response=response)["category"] for response in responses 71 | ] 72 | 73 | for text, result in zip(input_texts, results): 74 | print(f"{text}: {result}") 75 | 76 | # %% 77 | -------------------------------------------------------------------------------- /examples/named_entity_recognition.py: -------------------------------------------------------------------------------- 1 | # Example of using the texttunnel package to perform named entity recognition 2 | # Script requires that the OPENAI_API_KEY environment variable is set. 3 | 4 | from texttunnel import chat, models, processor 5 | 6 | # Texts that we'd like to extract entities from 7 | input_texts = [ 8 | "BioNTech SE is set to acquire InstaDeep, \ 9 | a Tunis-born and U.K.-based artificial intelligence \ 10 | (AI) startup, for up to £562 million", 11 | "The U.S. Food and Drug Administration (FDA) \ 12 | has approved Pfizer-BioNTech's COVID-19 vaccine for emergency use", 13 | "BioNTech founders, Dr. Ugur Sahin and Dr. Ozlem Tureci, \ 14 | receive a prestigious award for their vaccine research", 15 | ] 16 | 17 | 18 | # Describe the output format that we'd like to receive, 19 | # using JSON Schema. We specify that we want to extract 20 | # persons, organizations, and locations from the text.
21 | function = { 22 | "name": "ner", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "persons": { 27 | "type": "array", 28 | "items": { 29 | "type": "string", 30 | }, 31 | }, 32 | "organizations": { 33 | "type": "array", 34 | "items": { 35 | "type": "string", 36 | }, 37 | }, 38 | "locations": { 39 | "type": "array", 40 | "items": { 41 | "type": "string", 42 | }, 43 | }, 44 | }, 45 | "required": ["persons", "organizations", "locations"], 46 | }, 47 | } 48 | 49 | system_message = "Extract named entities from a text." 50 | 51 | model = models.GPT_4o 52 | 53 | requests = chat.build_requests( 54 | texts=input_texts, 55 | function=function, 56 | model=model, 57 | system_message=system_message, 58 | params=models.Parameters(max_tokens=256), 59 | ) 60 | 61 | # Estimate the cost of the requests 62 | estimated_cost_usd = sum([r.estimate_cost_usd() for r in requests]) 63 | print(f"Estimated cost: ${estimated_cost_usd:.4f}") 64 | 65 | 66 | responses = processor.process_api_requests(requests=requests) 67 | 68 | 69 | results = [processor.parse_arguments(response=response) for response in responses] 70 | print(results) 71 | 72 | actual_cost_usd = sum( 73 | [ 74 | processor.usage_to_cost( 75 | usage=processor.parse_token_usage(response=response), 76 | model=model, 77 | ) 78 | for response in responses 79 | ] 80 | ) 81 | 82 | print(f"Actual cost: ${actual_cost_usd:.4f}") 83 | -------------------------------------------------------------------------------- /examples/sentiment_analysis.py: -------------------------------------------------------------------------------- 1 | # Example of using the texttunnel package to perform sentiment analysis 2 | # Features binpacking to reduce the number of API calls 3 | 4 | # %% 5 | import logging 6 | from aiohttp_client_cache import SQLiteBackend 7 | 8 | from texttunnel import chat, models, processor 9 | 10 | # Create a SQLite cache to store the results of the requests 11 | # When this script is run again, the results will be loaded from the cache 12 | # Requires the additional package aiosqlite (pip install aiosqlite) 13 | cache = SQLiteBackend(cache_name="openai_cache.sqlite", allowed_methods=["POST"]) 14 | 15 | logging.basicConfig(level=logging.WARN) 16 | logging.getLogger("texttunnel").setLevel(logging.INFO) 17 | 18 | # Texts that we'd like to know the sentiment of 19 | input_texts = [ 20 | "I love sunshine", 21 | "I don't like rain", 22 | ] 23 | 24 | # Describe the output format that we'd like to receive, 25 | # using JSON Schema 26 | function = { 27 | "name": "sentiment_analysis", 28 | "parameters": { 29 | "type": "object", 30 | "properties": { 31 | "answers": { 32 | "type": "array", 33 | "items": { 34 | "type": "object", 35 | "properties": { 36 | "id": {"type": "string"}, 37 | "sentiment": {"type": "string"}, 38 | }, 39 | "required": ["id", "sentiment"], 40 | }, 41 | }, 42 | }, 43 | "required": ["answers"], 44 | }, 45 | } 46 | 47 | system_message = "You are a sentiment analysis expert. Analyze the following statements as positive or negative." 
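# Note (added for clarity): chat.build_binpacked_requests (used below) packs several
# input texts into as few API requests as possible. Each entry of the returned
# "answers" array carries an "id" that maps the answer back to the position of the
# corresponding text within the packed request.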
48 | 49 | model = models.GPT_3_5_TURBO 50 | 51 | requests = chat.build_binpacked_requests( 52 | texts=input_texts, 53 | function=function, 54 | model=model, 55 | system_message=system_message, 56 | params=models.Parameters(max_tokens=50), 57 | ) 58 | 59 | # %% 60 | # Estimate the cost of the requests 61 | estimated_cost_usd = sum([r.estimate_cost_usd() for r in requests]) 62 | print(f"Estimated cost: ${estimated_cost_usd:.4f}") 63 | 64 | # %% 65 | # Requires that the OPENAI_API_KEY environment variable is set. 66 | responses = processor.process_api_requests( 67 | requests=requests, 68 | cache=cache, 69 | ) 70 | 71 | 72 | # %% 73 | results = [processor.parse_arguments(response=response) for response in responses] 74 | 75 | for text, answer in zip(input_texts, results[0]["answers"]): 76 | print(f"{text}: {answer['sentiment']}") 77 | 78 | # %% 79 | actual_cost_usd = sum( 80 | [ 81 | processor.usage_to_cost( 82 | usage=processor.parse_token_usage(response=response), 83 | model=model, 84 | ) 85 | for response in responses 86 | ] 87 | ) 88 | 89 | print(f"Actual cost: ${actual_cost_usd:.4f}") 90 | 91 | # %% 92 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog for texttunnel 2 | 3 | ## 0.3.7 4 | 5 | - Added model configurations for gpt-4-turbo and gpt-4o (note that users can add their own model configurations too) 6 | - Added an example for named entity recognition 7 | 8 | ## 0.3.6 9 | 10 | Bug fixes: 11 | 12 | - Fixed a bug that caused retry requests to be overwritten by new requests in `aprocess_api_requests`. 13 | 14 | ## 0.3.5 15 | 16 | Changes: 17 | 18 | - texttunnel can now be used with any OpenAI Chat model, including your own fine-tuned models. Previously only a limited number of models were allowed due to peculiarities of token counting. This change comes at the cost of the possibility of miscounting tokens by 1 token per message in a chat, in case OpenAI changes token counting in future models. See https://github.com/qagentur/texttunnel/pull/70 for details. 19 | - Requests now use a seed by default, which makes the results more consistent (see https://platform.openai.com/docs/guides/text-generation/reproducible-outputs). 20 | 21 | Documentation: 22 | 23 | - Documentation for changing API quota limits has been added to Sphinx docs. 24 | - Documentation on texttunnel's model class support has been added to Sphinx docs. 25 | 26 | ## 0.3.4 27 | 28 | Changes: 29 | 30 | - additional DEBUG level logs for cached requests 31 | 32 | Bug fixes: 33 | 34 | - `aprocess_api_requests` no longer gets stuck after a request fails 35 | - aiohttp sessions are now properly closed after an error occurs in the request 36 | 37 | ## 0.3.3 38 | 39 | Changes: 40 | 41 | - `aprocess_api_requests` now makes cache lookup asynchronously to improve performance 42 | - the package is now compatible with jsonschema 3.0.0 and up, previously it was only compatible with 4.0.0 and up 43 | 44 | Bug fixes: 45 | 46 | - `aprocess_api_requests` now properly closes the connection to the cache backend 47 | 48 | ## 0.3.2 49 | 50 | Changes: 51 | 52 | - `chat.build_requests` and `chat.build_binpacked_requests` now raise a ValueError when the text argument contains duplicates 53 | - `aprocess_api_requests` now raises a ValueError when the requests passed to it have duplicate hashes 54 | 55 | Both of these changes are to prevent waste of money on duplicate API requests. 
They also prevent a sorting error where results wouldn't be returned in the same order as the requests were passed in. 56 | 57 | Bug fixes: 58 | 59 | - Fixed a bug where aiohttp sessions were not closed when an error occurred in the request 60 | 61 | ## 0.3.1 62 | 63 | - Made `aprocess_api_requests()` independently useable to allow advanced users to take full control of the asyncio event loop. 64 | - Added text classification example. 65 | 66 | ## 0.3.0 67 | 68 | - Breaking: Replaced diskcache with aiohttp_client_cache for caching requests. This provides support for SQLite, Redis, DynamoDB and MongoDB cache backends. 69 | 70 | ## 0.2.3 71 | 72 | - Added support for gpt-3.5-turbo-16k 73 | 74 | ## 0.2.2 75 | 76 | - Initial release 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Cache files from cache backends 2 | *.sqlite 3 | 4 | # macOS 5 | .DS_Store 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 166 | #.idea/ 167 | 168 | # VSCode 169 | .vscode/ 170 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # texttunnel: Efficient text processing with GPT-3.5 and GPT-4 2 | 3 |
4 | 5 |
6 | 7 | This package offers a straightforward interface for integrating the GPT-3.5 and GPT-4 models into your natural language processing pipelines. It is designed for the following scenario: 8 | 9 | Suppose you possess a corpus of text data that you want to analyze using the GPT-3.5 or GPT-4 models. The goal is to perform extractive NLP tasks such as classification, named entity recognition, translation, summarization, question answering, or sentiment analysis. In this context, the package prioritizes efficiency and tidiness to provide you with streamlined results. 10 | 11 | Features: 12 | 13 | - 📄 Output Schema: Utilizes [JSON Schema](https://json-schema.org) alongside OpenAI's function calling schema to define the output data structure. 14 | - ✔️ Input Validation: Ensures well-structured and error-free API requests by validating input data. 15 | - ✅ Output Validation: Checks the response data from OpenAI's API against the expected schema to maintain data integrity. 16 | - 🚦 Asynchronous Requests: Facilitates speedy data processing by sending simultaneous requests to OpenAI's API, while staying within API rate limits. 17 | - 🚀 Efficient Batching: Supports bulk processing by packing multiple input texts into a single request to the OpenAI API. 18 | - 💰 Cost Estimation: Aims for transparency in API utilization cost by providing cost estimates before sending API requests. 19 | - 💾 Caching: Uses [aiohttp_client_cache](https://github.com/requests-cache/aiohttp-client-cache) to avoid redundant requests and reduce cost by caching previous requests. Supports SQLite, MongoDB, DynamoDB and Redis cache backends. 20 | - 📝 Request Logging: Implements Python's native [logging](https://docs.python.org/3/library/logging.html) framework for tracking and logging all API requests. 21 | 22 | Note that this package only works with [function calling](https://platform.openai.com/docs/guides/function-calling) and only with the OpenAI API. If you're looking for a more flexible solution, consider [instructor](https://github.com/jxnl/instructor) and [litellm](https://github.com/BerriAI/litellm). You might also consider using the [OpenAI Batch API](https://platform.openai.com/docs/api-reference/batch) as it offers savings compared to synchronous API calls. 23 | 24 | ⚠️ **Maintenance mode**: At this time no new features or enhancements are being developed. Only critical bugfixes will be made. 25 | 26 | ## Installation 27 | 28 | The package is available on [PyPI](https://pypi.org/project/texttunnel/). To install it, run: 29 | 30 | ```bash 31 | pip install texttunnel 32 | ``` 33 | 34 | or via poetry: 35 | 36 | ```bash 37 | poetry add texttunnel 38 | ``` 39 | 40 | **Note**: If you want to use caching, you need to install the aiohttp_client_cache extras. Please refer to the [aiohttp_client_cache](https://github.com/requests-cache/aiohttp-client-cache#quickstart) documentation for more information. 41 | 42 | ## Usage 43 | 44 | Check the docs: [https://qagentur.github.io/texttunnel/](https://qagentur.github.io/texttunnel/) 45 | 46 | Create an account on [OpenAI](https://openai.com) and get an API key. Set it as an environment variable called `OPENAI_API_KEY`. 47 | 48 | Check the [examples](examples) directory for examples of how to use this package. 49 | 50 | If your account has been granted higher rate limits than the ones configured in the models module, you can override the default attributes of the Model class instances, as shown in the sketch below. See the documentation of the models module.
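For orientation, here is a minimal sketch of the typical flow, condensed from the scripts in the [examples](examples) directory. The JSON Schema, input texts, system message and the rate-limit override value are placeholders, not recommendations:

```python
from texttunnel import chat, models, processor

# Placeholder JSON Schema describing the desired output structure
function = {
    "name": "sentiment_analysis",
    "parameters": {
        "type": "object",
        "properties": {
            "sentiment": {"type": "string", "enum": ["positive", "negative"]},
        },
        "required": ["sentiment"],
    },
}

# Optional: if your account has higher rate limits, override the model defaults
model = models.GPT_3_5_TURBO
model.tokens_per_minute = 180_000  # example value, adjust to your account's limits

requests = chat.build_requests(
    texts=["I love sunshine", "I don't like rain"],
    function=function,
    model=model,
    system_message="Classify the sentiment of the text.",
    params=models.Parameters(max_tokens=50),
)
print(f"Estimated cost: ${sum(r.estimate_cost_usd() for r in requests):.4f}")

responses = processor.process_api_requests(requests=requests)  # needs OPENAI_API_KEY
results = [processor.parse_arguments(response=response) for response in responses]
```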
51 | 52 | ## Development 53 | 54 | To get started with development, follow these steps: 55 | 56 | - clone the repository 57 | - install [poetry](https://python-poetry.org/docs/) if you don't have it yet 58 | - navigate to the project folder 59 | - run `poetry install` to install the dependencies 60 | - run the tests with `poetry run pytest -v` 61 | 62 | This project uses [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings and [black](https://github.com/psf/black) formatting. The docs are automatically built based on the docstrings. 63 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import itertools 3 | 4 | from texttunnel import utils 5 | 6 | 7 | def test_num_tokens_from_text(texts_fixture): 8 | num_tokens = [utils.num_tokens_from_text(text) for text in texts_fixture] 9 | assert num_tokens == [4, 0, 15, 7] 10 | 11 | 12 | def test_binpack_texts_in_order(texts_fixture, encoding_fixture): 13 | max_tokens_per_bin = 40 14 | text_bins = utils.binpack_texts_in_order( 15 | texts=texts_fixture, 16 | max_tokens_per_bin=max_tokens_per_bin, 17 | formatter_function=utils.format_texts_as_json, 18 | ) 19 | 20 | tokens_in_bins = [ 21 | len(encoding_fixture.encode(utils.format_texts_as_json(text_bin))) 22 | for text_bin in text_bins 23 | ] 24 | assert all([tokens <= max_tokens_per_bin for tokens in tokens_in_bins]) 25 | 26 | 27 | def test_binpack_texts_in_order_overhead_too_long_error(texts_fixture): 28 | with pytest.raises(ValueError): 29 | utils.binpack_texts_in_order( 30 | texts=texts_fixture, 31 | max_tokens_per_bin=5, # too small for overhead (12) 32 | formatter_function=utils.format_texts_as_json, 33 | ) 34 | 35 | 36 | def test_binpack_texts_in_order_text_too_long_error(texts_long_fixture): 37 | max_tokens_per_bin = 1000 38 | with pytest.raises(ValueError): # Doesn't fit due to overhead 39 | utils.binpack_texts_in_order( 40 | texts=[texts_long_fixture[-1]], # Last text has 1000 tokens 41 | max_tokens_per_bin=max_tokens_per_bin, 42 | formatter_function=utils.format_texts_as_json, 43 | long_text_handling="error", 44 | ) 45 | 46 | 47 | def test_binpack_texts_in_order_truncation(texts_fixture, encoding_fixture): 48 | max_tokens_per_bin = 25 49 | text_bins = utils.binpack_texts_in_order( 50 | texts=texts_fixture, 51 | max_tokens_per_bin=max_tokens_per_bin, 52 | formatter_function=utils.format_texts_as_json, 53 | long_text_handling="truncate", 54 | ) 55 | 56 | tokens_in_bins = [ 57 | len(encoding_fixture.encode(utils.format_texts_as_json(text_bin))) 58 | for text_bin in text_bins 59 | ] 60 | assert all([tokens <= max_tokens_per_bin for tokens in tokens_in_bins]) 61 | 62 | 63 | def test_binpack_texts_in_order_long_texts(texts_long_fixture, encoding_fixture): 64 | max_tokens_per_bin = 1013 # exactly fits the longest text including overhead 65 | text_bins = utils.binpack_texts_in_order( 66 | texts=texts_long_fixture, 67 | max_tokens_per_bin=max_tokens_per_bin, 68 | formatter_function=utils.format_texts_as_json, 69 | long_text_handling="error", 70 | ) 71 | 72 | # All texts should be in a bin 73 | flattened_bins = list(itertools.chain.from_iterable(text_bins)) 74 | assert len(flattened_bins) == len(texts_long_fixture) 75 | 76 | tokens_in_bins = [ 77 | len(encoding_fixture.encode(utils.format_texts_as_json(text_bin))) 78 | for text_bin in text_bins 79 | ] 80 | assert all([tokens <= max_tokens_per_bin for tokens in 
tokens_in_bins]) 81 | 82 | 83 | def test_binpack_texts_in_order_max_texts_per_bin(texts_long_fixture): 84 | max_tokens_per_bin = 10000 # very large 85 | max_texts_per_bin = 3 86 | 87 | text_bins = utils.binpack_texts_in_order( 88 | texts=texts_long_fixture, 89 | max_tokens_per_bin=max_tokens_per_bin, 90 | max_texts_per_bin=max_texts_per_bin, 91 | formatter_function=utils.format_texts_as_json, 92 | ) 93 | 94 | assert max([len(text_bin) for text_bin in text_bins]) == max_texts_per_bin 95 | 96 | 97 | def test_format_texts_as_json(texts_fixture): 98 | act = utils.format_texts_as_json(texts_fixture[:2]) 99 | exp = '[{"id": 0, "text": "The first text."}, {"id": 1, "text": ""}]' 100 | 101 | assert act == exp 102 | 103 | 104 | def test_format_texts_as_json_keeps_non_ascii_characters(texts_nonascii_fixture): 105 | act = utils.format_texts_as_json(texts_nonascii_fixture) 106 | exp = '[{"id": 0, "text": "Äpfel"}, {"id": 1, "text": "👋 🌍"}, {"id": 2, "text": "你好世界"}]' 107 | 108 | assert act == exp 109 | 110 | 111 | def test_truncate_text_by_tokens(encoding_fixture): 112 | text = "Hello, world!" 113 | truncated_text = utils.truncate_text_by_tokens( 114 | text, max_tokens=2, encoding=encoding_fixture 115 | ) 116 | assert truncated_text == "Hello," 117 | -------------------------------------------------------------------------------- /tests/test_chat.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from texttunnel import chat, models 4 | 5 | 6 | def test_chat_add_message(chat_fixture): 7 | chat_fixture.add_message(message=chat.ChatMessage(role="user", content="Hi!")) 8 | assert len(chat_fixture.messages) == 3 9 | assert chat_fixture.messages[2].content == "Hi!" 10 | 11 | 12 | def test_is_valid_function_def(function_fixture): 13 | assert chat.is_valid_function_def(function_fixture) 14 | 15 | bad_function = function_fixture.copy() 16 | del bad_function["name"] 17 | 18 | assert not chat.is_valid_function_def(bad_function) 19 | 20 | 21 | def test_chat(chat_fixture): 22 | chat = chat_fixture 23 | 24 | assert chat.messages[0].role == "system" 25 | 26 | 27 | def test_chat_completion_request(model_fixture, chat_fixture, function_fixture): 28 | request = chat.ChatCompletionRequest( 29 | model=model_fixture, 30 | chat=chat_fixture, 31 | function=function_fixture, 32 | params=models.Parameters(max_tokens=128, temperature=0.5), 33 | ) 34 | 35 | assert request.function_call == {"name": "function_name"} 36 | assert request.count_total_tokens() > 0 37 | assert request.estimate_cost_usd() > 0 38 | assert isinstance(request.to_dict(), dict) 39 | assert request.to_dict()["temperature"] == 0.5 40 | 41 | 42 | def test_chat_completion_request_context_size_exceeded( 43 | model_fixture, chat_fixture, function_fixture 44 | ): 45 | with pytest.raises(ValueError): 46 | chat.ChatCompletionRequest( 47 | model=model_fixture, 48 | chat=chat_fixture, 49 | function=function_fixture, 50 | params=models.Parameters(max_tokens=4080), # doesn't fit 51 | ) 52 | 53 | 54 | def test_build_binpacked_requests_default_settings( 55 | model_fixture, function_fixture, texts_long_fixture, params_fixture 56 | ): 57 | requests = chat.build_binpacked_requests( 58 | system_message="You are a helpful assistant.", 59 | model=model_fixture, 60 | function=function_fixture, 61 | texts=texts_long_fixture, 62 | params=params_fixture, 63 | ) 64 | 65 | assert all([r.count_total_tokens() <= model_fixture.context_size for r in requests]) 66 | 67 | 68 | def 
test_build_binpacked_requests_max_texts_per_request( 69 | model_fixture, 70 | function_fixture, 71 | texts_fixture, 72 | params_fixture, 73 | ): 74 | requests = chat.build_binpacked_requests( 75 | system_message="You are a helpful assistant.", 76 | model=model_fixture, 77 | function=function_fixture, 78 | texts=texts_fixture, 79 | max_texts_per_request=2, 80 | params=params_fixture, 81 | ) 82 | 83 | assert len(requests) == 2 84 | 85 | 86 | def test_build_requests( 87 | model_fixture, 88 | function_fixture, 89 | texts_fixture, 90 | params_fixture, 91 | ): 92 | requests = chat.build_requests( 93 | system_message="You are a helpful assistant.", 94 | model=model_fixture, 95 | function=function_fixture, 96 | texts=texts_fixture, 97 | params=params_fixture, 98 | ) 99 | 100 | assert len(requests) == len(texts_fixture) 101 | 102 | 103 | def test_build_requests_fails_on_duplicate_texts( 104 | model_fixture, 105 | function_fixture, 106 | texts_fixture, 107 | params_fixture, 108 | ): 109 | with pytest.raises(ValueError): 110 | chat.build_requests( 111 | system_message="You are a helpful assistant.", 112 | model=model_fixture, 113 | function=function_fixture, 114 | texts=[texts_fixture[0], texts_fixture[0]], 115 | params=params_fixture, 116 | ) 117 | 118 | 119 | def test_chat_completion_request_context_size_check( 120 | chat_fixture, function_fixture, params_fixture 121 | ): 122 | tiny_model = chat.Model( 123 | name="gpt-3.5-turbo", 124 | context_size=1, # only for testing, real context size is 4096 125 | input_token_price_per_1k=0.002, 126 | output_token_price_per_1k=0.004, 127 | tokens_per_minute=90000, 128 | requests_per_minute=3500, 129 | ) 130 | 131 | with pytest.raises(ValueError): 132 | chat.ChatCompletionRequest( 133 | model=tiny_model, 134 | chat=chat_fixture, 135 | function=function_fixture, 136 | params=params_fixture, 137 | ) 138 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. texttunnel documentation master file, created by 2 | sphinx-quickstart on Fri Aug 18 14:26:34 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | texttunnel: Efficient text processing with GPT-3.5 and GPT-4 7 | ============================================================ 8 | 9 | This package offers a straightforward interface for integrating the GPT-3.5 and GPT-4 models into your natural language processing pipelines. It is designed for the following scenario: 10 | 11 | Suppose you possess a corpus of text data that you want to analyze using the GPT-3.5 or GPT-4 models. The goal is to perform extractive NLP tasks such as classification, named entity recognition, translation, summarization, question answering, or sentiment analysis. In this context, the package prioritizes efficiency and tidiness to provide you with streamlined results. 12 | 13 | Features: 14 | 15 | - 📄 Output Schema: Utilizes JSON Schema alongside OpenAI's function calling schema to define the output data structure. 16 | - ✔️ Input Validation: Ensures well-structured and error-free API requests by validating input data. 17 | - ✅ Output Validation: Checks the response data from OpenAI's API against the expected schema to maintain data integrity. 18 | - 🚀 Efficient Batching: Supports bulk processing by packing multiple input texts into a single request to the OpenAI API.
19 | - 🚦 Asynchronous Requests: Facilitates speedy data processing by sending simultaneous requests to OpenAI's API, while staying within API rate limits. 20 | - 💰 Cost Estimation: Aims for transparency in API utilization cost by providing cost estimates before sending API requests. 21 | - 💾 Caching: Uses aiohttp_client_cache to avoid redundant requests and reduce cost by caching previous requests. Supports SQLite, MongoDB, DynamoDB and Redis cache backends. 22 | - 📝 Request Logging: Implements Python's native logging framework for tracking and logging all API requests. 23 | 24 | To get started, check the examples: 25 | https://github.com/qagentur/texttunnel/tree/main/examples 26 | 27 | OpenAI's function calling guide is also a useful resource: 28 | https://platform.openai.com/docs/guides/gpt/function-calling 29 | 30 | .. toctree:: 31 | :maxdepth: 2 32 | :caption: Contents: 33 | 34 | Modules 35 | ======= 36 | 37 | Chat Module 38 | ^^^^^^^^^^^ 39 | .. automodule:: texttunnel.chat 40 | :members: 41 | 42 | Models Module 43 | ^^^^^^^^^^^^^ 44 | .. automodule:: texttunnel.models 45 | :members: 46 | 47 | .. autoattribute:: texttunnel.models.GPT_4 48 | .. autoattribute:: texttunnel.models.GPT_4_0613 49 | .. autoattribute:: texttunnel.models.GPT_4_32K 50 | .. autoattribute:: texttunnel.models.GPT_4_32K_0613 51 | .. autoattribute:: texttunnel.models.GPT_4_0314 52 | .. autoattribute:: texttunnel.models.GPT_4_32K_0314 53 | .. autoattribute:: texttunnel.models.GPT_3_5_TURBO 54 | .. autoattribute:: texttunnel.models.GPT_3_5_TURBO_16K 55 | .. autoattribute:: texttunnel.models.GPT_3_5_TURBO_0613 56 | .. autoattribute:: texttunnel.models.GPT_3_5_TURBO_16K_0613 57 | .. autoattribute:: texttunnel.models.GPT_3_5_TURBO_0301 58 | 59 | Models that are not included here can be created as custom instances of the Model class. Only chat models are supported; "instruct" models are not. 60 | 61 | Preview models can be used, but will not be added as default models to the package. To use a preview model, create a custom instance of the Model class. Models that OpenAI deprecates will be removed from the package. This primarily affects date-versioned models. 62 | 63 | Note that the model class attributes tokens_per_minute (TPM) and requests_per_minute (RPM) are based on tier 1 usage limits. See https://platform.openai.com/docs/guides/rate-limits?context=tier-free for more details. If your account has a higher usage tier, override the class attributes with your own values. 64 | 65 | texttunnel does not track tokens_per_day (TPD) limits and assumes that it is the only process that is using your model quota. 66 | 67 | Processor Module 68 | ^^^^^^^^^^^^^^^^ 69 | .. automodule:: texttunnel.processor 70 | :members: 71 | 72 | Utils Module 73 | ^^^^^^^^^^^^ 74 | .. automodule:: texttunnel.utils 75 | :members: 76 | 77 | Logging 78 | ======= 79 | 80 | The package uses the standard logging library and creates a logger named "texttunnel". 81 | 82 | To enable logging, add the following code to your script: 83 | 84 | ..
code-block:: python 85 | 86 | import logging 87 | logging.basicConfig(level=logging.WARNING) # choose whatever level you want 88 | logging.getLogger("texttunnel").setLevel(logging.INFO) # set to DEBUG for more verbose logging 89 | 90 | 91 | Indices and tables 92 | ================== 93 | 94 | * :ref:`genindex` 95 | * :ref:`modindex` 96 | * :ref:`search` 97 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Test fixtures shared across test files 2 | 3 | import pytest 4 | import json 5 | 6 | import tiktoken 7 | 8 | from texttunnel import chat, models 9 | 10 | 11 | @pytest.fixture 12 | def texts_fixture(): 13 | return [ 14 | "The first text.", 15 | "", # empty string 16 | "The third text has non-ASCII characters: 你好世界", # hello world in Chinese 17 | "The fourth text has a newline.\n", 18 | ] 19 | 20 | 21 | @pytest.fixture 22 | def texts_long_fixture(): 23 | n_texts = 100 24 | min_length = 10 25 | max_length = 1000 26 | 27 | text_lengths = range( 28 | min_length, max_length + 10, 10 # range is not inclusive 29 | ) # 100 variations of text length 30 | 31 | j = 0 32 | 33 | # Cycle through texts lengths to create a list of texts 34 | texts = [] 35 | for _ in range(n_texts): 36 | text_length = text_lengths[j] 37 | text = " ".join(["hello"] * text_length) # Nirvana lyrics generator 38 | texts.append(text) 39 | if j < len(text_lengths) - 1: 40 | j += 1 41 | else: 42 | j = 0 43 | 44 | return texts 45 | 46 | 47 | @pytest.fixture 48 | def texts_nonascii_fixture(): 49 | return [ 50 | "Äpfel", # apples in German 51 | "👋 🌍", 52 | "你好世界", # hello world in Chinese 53 | ] 54 | 55 | 56 | @pytest.fixture 57 | def encoding_fixture(): 58 | return tiktoken.get_encoding("cl100k_base") 59 | 60 | 61 | @pytest.fixture 62 | def requests_fixture(): 63 | function_def: chat.FunctionDef = { 64 | "name": "tell_feeling", 65 | "parameters": { 66 | "type": "object", 67 | "properties": {"feeling": {"type": "string"}}, 68 | }, 69 | } 70 | 71 | return chat.build_requests( 72 | system_message="You are a helpful assistant.", 73 | model=models.GPT_3_5_TURBO, 74 | function=function_def, 75 | texts=["I am happy.", "I am sad."], 76 | params=models.Parameters(max_tokens=128), 77 | ) 78 | 79 | 80 | @pytest.fixture 81 | def response_fixture(): 82 | return [ 83 | { 84 | "model": "gpt-3.5-turbo", 85 | "max_tokens": 50, 86 | "messages": [ 87 | { 88 | "role": "system", 89 | "content": "You are a helpful assistant", 90 | }, 91 | { 92 | "role": "user", 93 | "content": "How are you?", 94 | }, 95 | ], 96 | "functions": [ 97 | { 98 | "name": "tell_feeling", 99 | "parameters": { 100 | "type": "object", 101 | "properties": { 102 | "feeling": { 103 | "type": "string", 104 | "enum": ["happy", "sad", "angry"], 105 | } 106 | }, 107 | }, 108 | } 109 | ], 110 | }, 111 | { 112 | "id": "chatcmpl-7nQcrnnrqATiOktw8nY0AsbfGXqrn", 113 | "object": "chat.completion", 114 | "created": 1692014777, 115 | "model": "gpt-3.5-turbo", 116 | "choices": [ 117 | { 118 | "index": 0, 119 | "message": { 120 | "role": "assistant", 121 | "content": None, 122 | "function_call": { 123 | "name": "tell_feeling", 124 | "arguments": json.dumps({"feeling": "happy"}), 125 | }, 126 | }, 127 | "finish_reason": "stop", 128 | }, 129 | ], 130 | "usage": { 131 | "prompt_tokens": 100, 132 | "completion_tokens": 50, 133 | "total_tokens": 150, 134 | }, 135 | }, 136 | ] 137 | 138 | 139 | @pytest.fixture 140 | def chat_fixture(): 141 | return chat.Chat( 142 | 
messages=[ 143 | chat.ChatMessage( 144 | role="system", 145 | content="You are a helpful assistant.", 146 | ), 147 | chat.ChatMessage( 148 | role="user", 149 | content="Hello, world!", 150 | ), 151 | ] 152 | ) 153 | 154 | 155 | @pytest.fixture 156 | def function_fixture(): 157 | return { 158 | "name": "function_name", 159 | "parameters": { 160 | "type": "object", 161 | "properties": {"argument1": {"type": "string"}}, 162 | }, 163 | } 164 | 165 | 166 | @pytest.fixture 167 | def model_fixture(): 168 | return models.Model( 169 | name="gpt-3.5-turbo", 170 | context_size=4096, 171 | input_token_price_per_1k=0.002, 172 | output_token_price_per_1k=0.004, 173 | tokens_per_minute=90000, 174 | requests_per_minute=3500, 175 | ) 176 | 177 | 178 | @pytest.fixture 179 | def params_fixture(): 180 | return models.Parameters(max_tokens=128) 181 | -------------------------------------------------------------------------------- /texttunnel/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Model: 6 | """ 7 | Information about an OpenAI ChatCompletion model. 8 | Check prices here: https://openai.com/pricing 9 | 10 | Note that rate limits differ between OpenAI accounts. 11 | Check them here: https://platform.openai.com/account/rate-limits 12 | 13 | Args: 14 | name: The name of the model, e.g. "gpt-3.5-turbo". 15 | context_size: The maximum number of tokens that can be passed to the model. 16 | input_token_price_per_1k: The price in USD per 1000 tokens for input. 17 | output_token_price_per_1k: The price in USD per 1000 tokens for output. 18 | tokens_per_minute: The maximum number of tokens that can be processed per minute. 19 | Note that this may differ between OpenAI accounts. Override the default 20 | models' values with your own values. 21 | requests_per_minute: The maximum number of requests that can be made per minute. 22 | Note that this may differ between OpenAI accounts. Override the default 23 | models' values with your own values. 
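Example (illustrative; the values simply mirror the gpt-3.5-turbo configuration defined further down in this module and are not authoritative prices or limits)::

        model = Model(
            name="gpt-3.5-turbo",
            context_size=4096,
            input_token_price_per_1k=0.0015,
            output_token_price_per_1k=0.002,
            tokens_per_minute=90000,
            requests_per_minute=3500,
        )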
24 | """ 25 | 26 | name: str 27 | context_size: int 28 | input_token_price_per_1k: float 29 | output_token_price_per_1k: float 30 | tokens_per_minute: int 31 | requests_per_minute: int 32 | 33 | def __post_init__(self): 34 | # Check that inputs are positive 35 | 36 | for arg in [ 37 | "context_size", 38 | "input_token_price_per_1k", 39 | "output_token_price_per_1k", 40 | "tokens_per_minute", 41 | "requests_per_minute", 42 | ]: 43 | if getattr(self, arg) < 0: 44 | raise ValueError(f"{arg} must be positive") 45 | 46 | 47 | # Look up information on models, pricing and rate limits 48 | # Note that tokens_per_minute and requests_per_minute differ by usage tier 49 | # https://platform.openai.com/docs/models/overview 50 | # https://openai.com/pricing 51 | # https://platform.openai.com/docs/guides/rate-limits/usage-tiers 52 | 53 | GPT_4o = Model( 54 | name="gpt-4o", 55 | context_size=128000, 56 | input_token_price_per_1k=0.005, 57 | output_token_price_per_1k=0.015, 58 | tokens_per_minute=30000, 59 | requests_per_minute=500, 60 | ) 61 | 62 | GPT_4_TURBO = Model( 63 | name="gpt-4-turbo", 64 | context_size=128000, 65 | input_token_price_per_1k=0.01, 66 | output_token_price_per_1k=0.03, 67 | tokens_per_minute=30000, 68 | requests_per_minute=500, 69 | ) 70 | 71 | GPT_4 = Model( 72 | name="gpt-4", 73 | context_size=8192, 74 | input_token_price_per_1k=0.03, 75 | output_token_price_per_1k=0.06, 76 | tokens_per_minute=10000, 77 | requests_per_minute=500, 78 | ) 79 | 80 | GPT_4_0613 = Model( 81 | name="gpt-4-0613", 82 | context_size=8192, 83 | input_token_price_per_1k=0.03, 84 | output_token_price_per_1k=0.06, 85 | tokens_per_minute=10000, 86 | requests_per_minute=500, 87 | ) 88 | 89 | GPT_4_32K = Model( 90 | name="gpt-4-32k", 91 | context_size=32768, 92 | input_token_price_per_1k=0.06, 93 | output_token_price_per_1k=0.12, 94 | tokens_per_minute=20000, 95 | requests_per_minute=500, 96 | ) 97 | 98 | GPT_4_32K_0613 = Model( 99 | name="gpt-4-32k-0613", 100 | context_size=32768, 101 | input_token_price_per_1k=0.06, 102 | output_token_price_per_1k=0.12, 103 | tokens_per_minute=20000, 104 | requests_per_minute=500, 105 | ) 106 | 107 | # legacy 108 | GPT_4_0314 = Model( 109 | name="gpt-4-0314", 110 | context_size=8192, 111 | input_token_price_per_1k=0.03, 112 | output_token_price_per_1k=0.06, 113 | tokens_per_minute=10000, 114 | requests_per_minute=500, 115 | ) 116 | 117 | # legacy 118 | GPT_4_32K_0314 = Model( 119 | name="gpt-4-32k-0314", 120 | context_size=32768, 121 | input_token_price_per_1k=0.06, 122 | output_token_price_per_1k=0.12, 123 | tokens_per_minute=10000, 124 | requests_per_minute=500, 125 | ) 126 | 127 | GPT_3_5_TURBO = Model( 128 | name="gpt-3.5-turbo", 129 | context_size=4096, 130 | input_token_price_per_1k=0.0015, 131 | output_token_price_per_1k=0.002, 132 | tokens_per_minute=90000, 133 | requests_per_minute=3500, 134 | ) 135 | 136 | GPT_3_5_TURBO_16K = Model( 137 | name="gpt-3.5-turbo-16k", 138 | context_size=16384, 139 | input_token_price_per_1k=0.003, 140 | output_token_price_per_1k=0.004, 141 | tokens_per_minute=180000, 142 | requests_per_minute=3500, 143 | ) 144 | 145 | GPT_3_5_TURBO_0613 = Model( 146 | name="gpt-3.5-turbo-0613", 147 | context_size=4096, 148 | input_token_price_per_1k=0.0015, 149 | output_token_price_per_1k=0.002, 150 | tokens_per_minute=90000, 151 | requests_per_minute=3500, 152 | ) 153 | 154 | GPT_3_5_TURBO_16K_0613 = Model( 155 | name="gpt-3.5-turbo-16k-0613", 156 | context_size=16384, 157 | input_token_price_per_1k=0.003, 158 | output_token_price_per_1k=0.004, 159 |
tokens_per_minute=180000, 160 | requests_per_minute=3500, 161 | ) 162 | 163 | # legacy 164 | GPT_3_5_TURBO_0301 = Model( 165 | name="gpt-3.5-turbo-0301", 166 | context_size=4096, 167 | input_token_price_per_1k=0.0015, 168 | output_token_price_per_1k=0.002, 169 | tokens_per_minute=9000, 170 | requests_per_minute=3500, 171 | ) 172 | 173 | 174 | class Parameters: 175 | """ 176 | Set of parameters that can be passed to an API request. 177 | 178 | The parameters are explained in the OpenAI API documentation: 179 | https://platform.openai.com/docs/api-reference/chat/create 180 | 181 | Args: 182 | max_tokens: The maximum number of tokens to generate. Note: 183 | This can't be greater than the model's context size and should be at least 184 | long enough to fit the whole expected JSON output. This parameter is used 185 | to estimate the cost of the request. 186 | temperature: What sampling temperature to use, between 0 and 2. 187 | Higher values like 0.8 will make the output more random, while 188 | lower values like 0.2 will make it more focused and deterministic. 189 | Defaults to 0.0 because this package is designed for deterministic 190 | JSON-schema compliant output. 191 | presence_penalty: Number between -2.0 and 2.0. Positive values penalize 192 | new tokens based on whether they appear in the text so far, 193 | increasing the model's likelihood to talk about new topics. Defaults to 0.0. 194 | frequency_penalty: Number between -2.0 and 2.0. Positive values penalize 195 | new tokens based on their existing frequency in the text so far, 196 | decreasing the model's likelihood to repeat the same line verbatim. 197 | Defaults to 0.0. 198 | seed: Integer seed for random number generation. Defaults to 42. 199 | 200 | Parameters that are not listed here are not supported by this package. The 201 | reason is that they're not relevant for the use case of this package. 202 | """ 203 | 204 | def __init__( 205 | self, 206 | max_tokens: int, 207 | temperature: float = 0.0, 208 | presence_penalty: float = 0.0, 209 | frequency_penalty: float = 0.0, 210 | seed: int = 42, 211 | ): 212 | if max_tokens < 1: 213 | raise ValueError("max_tokens must be positive") 214 | 215 | if temperature < 0 or temperature > 1: 216 | raise ValueError("temperature must be between 0 and 1") 217 | 218 | if frequency_penalty < -2 or frequency_penalty > 2: 219 | raise ValueError("frequency_penalty must be between -2 and 2") 220 | 221 | if presence_penalty < -2 or presence_penalty > 2: 222 | raise ValueError("presence_penalty must be between -2 and 2") 223 | 224 | self.max_tokens = max_tokens 225 | self.temperature = temperature 226 | self.presence_penalty = presence_penalty 227 | self.frequency_penalty = frequency_penalty 228 | self.seed = seed 229 | 230 | def to_dict(self): 231 | """ 232 | Returns: 233 | A dictionary representation of the parameters. 
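For example (illustrative), Parameters(max_tokens=128).to_dict() returns::

            {
                "max_tokens": 128,
                "temperature": 0.0,
                "presence_penalty": 0.0,
                "frequency_penalty": 0.0,
                "seed": 42,
            }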
234 | """ 235 | 236 | return { 237 | "max_tokens": self.max_tokens, 238 | "temperature": self.temperature, 239 | "presence_penalty": self.presence_penalty, 240 | "frequency_penalty": self.frequency_penalty, 241 | "seed": self.seed, 242 | } 243 | -------------------------------------------------------------------------------- /texttunnel/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from hashlib import sha256 3 | from typing import Callable, List, Optional 4 | 5 | import tiktoken 6 | 7 | 8 | def num_tokens_from_text(text: str, encoding_name: str = "cl100k_base") -> int: 9 | """ 10 | Returns the number of tokens in a string. 11 | 12 | Args: 13 | text: The text to count tokens in. 14 | encoding_name: The name of the token encoding to use. Defaults to "cl100k_base". 15 | 16 | Returns: 17 | The number of tokens in the string. 18 | """ 19 | encoding = tiktoken.get_encoding(encoding_name) 20 | num_tokens = len(encoding.encode(text)) 21 | return num_tokens 22 | 23 | 24 | def truncate_text_by_tokens( 25 | text: str, 26 | max_tokens: int, 27 | encoding: tiktoken.core.Encoding, 28 | ) -> str: 29 | """ 30 | Truncates a text to a maximum number of tokens. 31 | 32 | Args: 33 | text: The text to truncate. 34 | max_tokens: The maximum number of tokens to truncate the text to. 35 | encoding: The encoding to use. 36 | 37 | Returns: 38 | The truncated text. 39 | """ 40 | tokens = encoding.encode(text) 41 | truncated_tokens = tokens[:max_tokens] 42 | truncated_text = encoding.decode(truncated_tokens) 43 | 44 | return truncated_text 45 | 46 | 47 | def format_texts_as_json(texts: List[str]) -> str: 48 | """ 49 | Formats a list of texts into a single string to be used as a user message. 50 | Each text is assigned an ID, starting from 0. The returned JSON format 51 | helps the model distinguish between different texts, at the cost of 52 | increasing the number of tokens used. 53 | 54 | The token overhead for a single text that doesn't require escaping characters 55 | is 12 tokens. Escaping characters like quotes increases the overhead. 56 | 57 | The format is a JSON list of dictionaries, where each dictionary has an 58 | "id" key and a "text" key. The "id" key is an integer, and the "text" key 59 | is a string. This array of maps structure is easiest to parse by GPT models 60 | and handles edge cases like newlines in the text. 61 | 62 | Args: 63 | texts: A list of texts to format. 64 | 65 | Returns: 66 | A formatted string that can be used as a user message. 67 | """ 68 | 69 | if not isinstance(texts, list): 70 | raise ValueError("texts must be a list.") 71 | 72 | text_dicts = [ 73 | { 74 | "id": i, 75 | "text": text, 76 | } 77 | for i, text in enumerate(texts) 78 | ] 79 | 80 | # json.dumps escapes characters like quotes, which increases the token overhead 81 | # OpenAI models understand non-ascii characters, so we can use ensure_ascii=False 82 | # to avoid escaping characters. 83 | return json.dumps(text_dicts, ensure_ascii=False) 84 | 85 | 86 | def format_texts_with_spaces(texts: List[str]) -> str: 87 | """ 88 | Simple formatter that joins texts with spaces. 
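Example (doctest-style, for illustration)::

        >>> format_texts_with_spaces(["Hello,", "world!"])
        'Hello, world!'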
89 | """
90 | return " ".join(texts)
91 | 
92 | 
93 | def binpack_texts_in_order(
94 | texts: List[str],
95 | formatter_function: Callable[[List[str]], str],
96 | max_tokens_per_bin: int,
97 | max_texts_per_bin: Optional[int] = None,
98 | encoding_name: str = "cl100k_base",
99 | long_text_handling: str = "error",
100 | ) -> List[List[str]]:
101 | """
102 | Binpacks a list of texts into bins (lists of texts), such that each bin
103 | has a total number of tokens of formatted text less than or equal to
104 | max_tokens_per_bin and contains at most max_texts_per_bin texts.
105 | 
106 | The binpacking uses a naive greedy algorithm that maintains the order of the texts.
107 | 
108 | Args:
109 | texts: List of texts to binpack. Empty texts are accepted, counted as 0 tokens
110 | each and count against max_texts_per_bin.
111 | formatter_function: A function that takes a list of texts and returns a single
112 | text, for example format_texts_as_json or format_texts_with_spaces.
113 | This function is used to include the overhead of the formatter function in
114 | the binpacking. It is not used to format the output. Make sure to use
115 | the same formatter function when formatting the output for the model.
116 | max_tokens_per_bin: The maximum number of tokens per bin of formatted texts.
117 | Leave some headroom relative to the model's context size to account for the tokens
118 | in the system message, function call, and function return.
119 | max_texts_per_bin: The maximum number of texts per bin. Defaults to None, which
120 | means that there is no limit on the number of texts per bin.
121 | encoding_name: The name of the encoding to use. Defaults to "cl100k_base".
122 | long_text_handling: How to handle texts that are longer than max_tokens_per_bin. Defaults
123 | to "error", which means that an error is raised. Can also be set to
124 | "truncate", which means that the text is truncated to max_tokens_per_bin.
125 | It is possible that more tokens are truncated than absolutely necessary
126 | due to overhead of the formatter function caused by escaping characters.
127 | 
128 | Returns:
129 | A list of lists of texts. The order of the texts is preserved.
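
Example (illustrative sketch; the actual grouping depends on the token
counts of the formatted texts):

    bins = binpack_texts_in_order(
        texts=["first text", "second text", "third text"],
        formatter_function=format_texts_as_json,
        max_tokens_per_bin=40,
    )
    # bins is a list such as [["first text", "second text"], ["third text"]]:
    # consecutive texts are grouped until adding another text would push the
    # formatted token count of the bin above max_tokens_per_bin.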
130 | """ 131 | 132 | if not isinstance(texts, list): 133 | raise ValueError("texts must be a list.") 134 | 135 | if not max_texts_per_bin: 136 | max_texts_per_bin = len(texts) 137 | 138 | if max_texts_per_bin < 1: 139 | raise ValueError( 140 | f"max_texts_per_bin must be at least 1, but got {max_texts_per_bin}" 141 | ) 142 | 143 | encoding = tiktoken.get_encoding(encoding_name) 144 | 145 | # Binpack the texts 146 | # Initialize the first bin 147 | bins = [] 148 | current_bin = [] 149 | 150 | for i, text in enumerate(texts): 151 | if len(current_bin) == max_texts_per_bin: 152 | # Start a new bin 153 | bins.append(current_bin) 154 | current_bin = [] 155 | 156 | # Calculate how many tokens would be in the current bin if we added the text 157 | bin_tokens_with_new_text = len( 158 | encoding.encode(formatter_function(current_bin + [text])) 159 | ) 160 | 161 | if bin_tokens_with_new_text > max_tokens_per_bin: # doesn't fit 162 | if len(current_bin) > 0: 163 | # Start a new bin 164 | bins.append(current_bin) 165 | current_bin = [] 166 | 167 | # Check if the text fits in a bin by itself 168 | tokens_text_with_formatting = len( 169 | encoding.encode(formatter_function([text])) 170 | ) 171 | 172 | if tokens_text_with_formatting > max_tokens_per_bin: # doesn't fit 173 | # Calculate the overhead of the formatter function 174 | tokens_text_raw = len(encoding.encode(text)) 175 | overhead = tokens_text_with_formatting - tokens_text_raw 176 | 177 | if overhead > max_tokens_per_bin: 178 | raise ValueError( 179 | f""" 180 | The formatting function adds {overhead} overhead tokens, 181 | which exceeds the maximum number of tokens ({max_tokens_per_bin}) permitted. 182 | """ 183 | ) 184 | 185 | if bin_tokens_with_new_text > max_tokens_per_bin: 186 | # The formatted text is too long to fit in a bin 187 | if long_text_handling == "error": 188 | raise ValueError( 189 | f""" 190 | The text at index {i} has {tokens_text_with_formatting} tokens, which 191 | is greater than the maximum number of tokens ({max_tokens_per_bin}). 192 | Note that a formatting function added {overhead} tokens to the text. 193 | """ 194 | ) 195 | 196 | elif long_text_handling == "truncate": 197 | # Truncate the text, accounting for overhead 198 | # It's possible that more is truncated than necessary 199 | # in case the overhead was caused by escaping characters 200 | # in the truncated part of the text 201 | text = truncate_text_by_tokens( 202 | text=text, 203 | max_tokens=max_tokens_per_bin - overhead, 204 | encoding=encoding, 205 | ) 206 | 207 | assert ( 208 | len(encoding.encode(formatter_function([text]))) 209 | <= max_tokens_per_bin 210 | ) 211 | 212 | else: 213 | raise ValueError( 214 | f""" 215 | Invalid value for long_text_handling: {long_text_handling}. 216 | Must be one of "error" or "truncate". 217 | """ 218 | ) 219 | 220 | # Add to the current bin 221 | current_bin.append(text) 222 | 223 | # Add the last bin 224 | bins.append(current_bin) 225 | 226 | return bins 227 | 228 | 229 | def hash_dict(d: dict) -> str: 230 | """ 231 | Hashes a dictionary using sha256. 
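
The hash is computed over the JSON serialization of the dictionary, so
identical dictionaries with the same key order always produce the same
digest, which is what allows it to serve as a cache key for requests.
Example (illustrative):

    hash_dict({"model": "gpt-3.5-turbo", "max_tokens": 128})
    # 64-character hexadecimal digest, stable across runs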
232 | """ 233 | return sha256(json.dumps(d).encode("utf-8")).hexdigest() 234 | -------------------------------------------------------------------------------- /texttunnel/chat.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Callable, Dict, List, Optional 3 | 4 | import tiktoken 5 | from jsonschema import Draft7Validator, exceptions 6 | 7 | from texttunnel import utils 8 | from texttunnel.models import Model, Parameters 9 | 10 | FunctionDef = Dict[str, str] 11 | 12 | 13 | def is_valid_function_def(function: FunctionDef) -> bool: 14 | """ 15 | Checks if a function definition is valid for use in a ChatCompletionRequest. 16 | Note that the parameter properties are not validated to allow for custom properties. 17 | 18 | Check the OpenAI API documentation for more information: 19 | https://platform.openai.com/docs/guides/gpt/function-calling 20 | 21 | Args: 22 | function: The function definition to validate. 23 | """ 24 | base_schema = { 25 | "name": {"type": "string"}, 26 | "description": {"type": "string"}, 27 | "parameters": { 28 | "type": "object", 29 | "properties": {"type": "object"}, 30 | }, 31 | "required": ["name", "parameters"], 32 | } 33 | 34 | try: 35 | Draft7Validator(base_schema).validate(function) 36 | except exceptions.ValidationError: 37 | print(f"Validation error: {exceptions.ValidationError}") 38 | return False 39 | 40 | return True 41 | 42 | 43 | class ChatMessage: 44 | """ 45 | A chat message, to be used in a chat. 46 | 47 | Args: 48 | role: The role of the message. Must be one of "system", "user", or "assistant". 49 | content: The content of the message. 50 | """ 51 | 52 | VALID_ROLES = {"system", "user", "assistant"} 53 | 54 | def __init__(self, role: str, content: str): 55 | if role not in self.VALID_ROLES: 56 | raise ValueError(f"Invalid role {role}. Must be one of {self.VALID_ROLES}.") 57 | 58 | self.role = role 59 | self.content = content 60 | 61 | def to_dict(self) -> Dict[str, str]: 62 | """ 63 | Returns a dict representation of the message. 64 | """ 65 | return {"role": self.role, "content": self.content} 66 | 67 | 68 | class Chat: 69 | """ 70 | A chat. Used to prompt a model for a response. 71 | The first message must be from the system, and the last message must be from the user. 72 | 73 | Args: 74 | messages: A list of ChatMessage objects. 75 | """ 76 | 77 | def __init__(self, messages: List[ChatMessage]): 78 | if len(messages) < 2: 79 | raise ValueError("A chat must have at least two messages.") 80 | 81 | if messages[0].role != "system": 82 | raise ValueError("The first message in a chat must be from the system.") 83 | 84 | if messages[-1].role != "user": 85 | raise ValueError("The last message in a chat must be from the user.") 86 | 87 | self.messages = messages 88 | 89 | def __len__(self) -> int: 90 | """ 91 | Returns the number of messages in the chat. 92 | """ 93 | return len(self.messages) 94 | 95 | def add_message(self, message: ChatMessage) -> None: 96 | """ 97 | Adds a message to the end of the chat. 98 | 99 | Args: 100 | message: The message to add. 101 | """ 102 | self.messages.append(message) 103 | 104 | def to_list(self) -> List[Dict[str, str]]: 105 | """ 106 | Returns a list of dictionaries representing the chat messages. 107 | This is the format expected by the OpenAI API. 108 | """ 109 | return [message.to_dict() for message in self.messages] 110 | 111 | def count_tokens(self) -> int: 112 | """ 113 | Return the number of tokens used. 
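
Example (illustrative; the exact count depends on the tokenizer):

    chat = Chat([
        ChatMessage("system", "You are a helpful assistant."),
        ChatMessage("user", "Hello"),
    ])
    chat.count_tokens()  # small integer; grows with message length
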
114 | Note that this depends on the model used. Models that are not versioned 115 | with a date can change over time, causing an inaccurate token count 116 | by this function. 117 | 118 | Args: 119 | model: The name of the model to use. Defaults to "gpt-3.5-turbo-0613". 120 | 121 | Returns: 122 | The number of tokens used. 123 | """ 124 | 125 | # See reference implementation: 126 | # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb 127 | 128 | # Note that the reference implementation uses varying numbers of tokens 129 | # for tokens_per_message depending on model. At the time of writing, 130 | # only gpt-3.5-turbo-0301 differs from the rest by one token per message. 131 | # To allow any OpenAI model to be used, we use 3 tokens per message. 132 | # This causes an underestimation of the token count when using gpt-3.5-turbo-0301. 133 | 134 | encoding = tiktoken.get_encoding("cl100k_base") 135 | tokens_per_message = 3 136 | tokens_per_name = 1 137 | num_tokens = 0 138 | 139 | for message in self.messages: 140 | num_tokens += tokens_per_message 141 | num_tokens += len(encoding.encode(message.content)) 142 | if message.role == "name": 143 | num_tokens += tokens_per_name 144 | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> 145 | 146 | return num_tokens 147 | 148 | 149 | class ChatCompletionRequest: 150 | """ 151 | Defines a request for a chat completion. 152 | 153 | Args: 154 | chat: The chat to which the assistant should respond with a function call. 155 | model: The name of the OpenAI ChatCompletion model to use for completion. 156 | function: The function definition to use for the assistant's response. 157 | Must be a dictionary that describes a valid JSON schema. 158 | See https://platform.openai.com/docs/guides/gpt/function-calling 159 | params: Object of class Parameters. See models.Parameters for details. 160 | """ 161 | 162 | def __init__( 163 | self, 164 | chat: Chat, 165 | model: Model, 166 | function: FunctionDef, 167 | params: Parameters, 168 | ): 169 | self.chat = chat 170 | self.model = model 171 | 172 | if not is_valid_function_def(function): 173 | raise ValueError("Invalid function definition.") 174 | 175 | self.function = function 176 | 177 | # Force the model to use a function call 178 | self.functions = [function] 179 | self.function_call = {"name": function["name"]} 180 | 181 | if params.max_tokens > self.model.context_size: 182 | raise ValueError( 183 | f""" 184 | max_tokens ({params.max_tokens}) exceeds the context 185 | size of the model ({self.model.context_size}). 186 | """ 187 | ) 188 | 189 | self.params = params 190 | 191 | # Check that the inputs fit into the context size and leaves 192 | # enough space for the output 193 | num_prompt_tokens = self.count_prompt_tokens() 194 | num_completion_tokens = self.count_completion_tokens() 195 | num_total_tokens = num_prompt_tokens + num_completion_tokens 196 | 197 | if num_total_tokens > self.model.context_size: 198 | raise ValueError( 199 | f""" 200 | Total number of tokens ({num_total_tokens}) exceeds the context 201 | size of the model ({self.model.context_size}). Input tokens: 202 | {num_prompt_tokens}. Output tokens: {num_completion_tokens}. 203 | Context size: {self.model.context_size}. 204 | """ 205 | ) 206 | 207 | def to_dict(self) -> Dict[str, Any]: 208 | """ 209 | Returns a dictionary representation of the request. Only includes 210 | the elements that are required by the OpenAI API. 
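
The result has roughly this shape (an illustrative sketch; the concrete values
come from the chat, function and params passed to the constructor):

    {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "system", "content": "..."}, ...],
        "functions": [{...}],
        "function_call": {"name": "..."},
        "max_tokens": 256,
        "temperature": 0.0,
        "presence_penalty": 0.0,
        "frequency_penalty": 0.0,
        "seed": 42,
    }
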
Model parameters 211 | are flattened into the top-level dictionary. 212 | """ 213 | return { 214 | "model": self.model.name, 215 | "messages": self.chat.to_list(), 216 | "functions": self.functions, 217 | "function_call": self.function_call, 218 | **self.params.to_dict(), 219 | } 220 | 221 | def get_hash(self) -> str: 222 | """ 223 | Returns the hash of the request. Can be used as a cache key. 224 | """ 225 | return utils.hash_dict(self.to_dict()) 226 | 227 | def count_prompt_tokens(self) -> int: 228 | """ 229 | Counts the number of tokens that will be used as input to the model. 230 | This includes the chat messages and the function call. 231 | """ 232 | chat_tokens = self.chat.count_tokens() 233 | function_tokens = utils.num_tokens_from_text(json.dumps(self.functions[0])) 234 | 235 | return chat_tokens + function_tokens 236 | 237 | def count_completion_tokens(self) -> int: 238 | """ 239 | Counts the number of tokens that will be used as output of the model. 240 | Assumes that the model will return the maximum number of tokens allowed 241 | by the max_tokens parameter. 242 | """ 243 | 244 | return self.params.max_tokens 245 | 246 | def count_total_tokens(self) -> int: 247 | """ 248 | Counts the total number of tokens that will be used as input and output 249 | of the model. Assumes that the model will return the maximum number of 250 | tokens allowed by the max_tokens parameter. 251 | """ 252 | return self.count_prompt_tokens() + self.count_completion_tokens() 253 | 254 | def estimate_cost_usd(self) -> float: 255 | """ 256 | Estimates the cost of the request in USD. Assumes that the model will 257 | return the maximum number of tokens allowed by the max_tokens parameter. 258 | The estimate is the upper bound on the cost, since the model may return 259 | fewer tokens than the maximum allowed. 260 | """ 261 | 262 | input_cost_usd = ( 263 | self.count_prompt_tokens() * self.model.input_token_price_per_1k / 1000 264 | ) 265 | output_cost_usd = ( 266 | self.count_completion_tokens() * self.model.output_token_price_per_1k / 1000 267 | ) 268 | 269 | return input_cost_usd + output_cost_usd 270 | 271 | 272 | def build_binpacked_requests( 273 | model: Model, 274 | function: FunctionDef, 275 | system_message: str, 276 | texts: List[str], 277 | params: Parameters, 278 | max_tokens_per_request: Optional[int] = None, 279 | max_texts_per_request: Optional[int] = None, 280 | binpacking_function: Callable = utils.binpack_texts_in_order, 281 | formatter_function: Callable = utils.format_texts_as_json, 282 | encoding_name: str = "cl100k_base", 283 | long_text_handling: str = "error", 284 | ) -> List[ChatCompletionRequest]: 285 | """ 286 | Builds a list of ChatCompletionRequests from a list of texts. 287 | If possible, multiple texts will be combined into a single ChatCompletionRequest. 288 | This can reduce the number of tokens spent on overheads like the system message 289 | and function definition. 290 | 291 | The requests can then be passed to processor.process_api_requests(). 292 | 293 | Args: 294 | model: The model to use for completion. 295 | function: The function definition to use for the assistant's response. 296 | Must be a dictionary that describes a valid JSON schema. 297 | See https://platform.openai.com/docs/guides/gpt/function-calling 298 | system_message: The message to include at the beginning of each chat. 299 | texts: A list of texts to binpack into chats. Duplicates are not allowed. 300 | params: Object of class Parameters. See models.Parameters for details. 
301 | max_tokens_per_request: The maximum number of tokens allowed in one request. 302 | Defaults to 90% of the model's context size. The 10% buffer makes 303 | sure that mistakes in token counting don't cause the request to fail. 304 | max_texts_per_request: The maximum number of texts allowed in one request. 305 | Defaults to None, which means there is no limit. 306 | binpacking_function: The function to use for binpacking. 307 | Must take a list of texts and return a list of lists of texts. 308 | Defaults to binpack_texts_in_order(). 309 | formatter_function: The function to use for formatting the texts. 310 | Must take a list of texts and return a single string. 311 | Defaults to format_texts_as_json(). 312 | encoding_name: The name of the encoding to use for tokenization. 313 | Defaults to "cl100k_base". 314 | long_text_handling: Passed to the binpacking function. Defaults to 315 | "error", which means that an error will be raised if a text is too 316 | long to fit in a single chat. 317 | 318 | Returns: 319 | A list of ChatCompletionRequests. 320 | """ 321 | if len(set(texts)) != len(texts): 322 | # Downstream code assumes that each request has a unique hash 323 | # Duplicate texts would cause the requests to have the same hash 324 | # Plus it's probably a mistake and would waste money 325 | raise ValueError("Duplicate texts found. Please remove duplicates.") 326 | 327 | if max_tokens_per_request is None: 328 | max_tokens_per_request = int(model.context_size * 0.9) 329 | 330 | # System message and function definition count towards the token limit 331 | overheads = [system_message, json.dumps(function)] 332 | static_tokens = sum([utils.num_tokens_from_text(text) for text in overheads]) 333 | 334 | # Calculate the maximum number of tokens left for the chat, 335 | # after accounting for the overheads and the output tokens 336 | max_tokens_per_chat = max_tokens_per_request - static_tokens - params.max_tokens 337 | 338 | # Binpack the texts into chats 339 | text_bins = binpacking_function( 340 | texts=texts, 341 | max_tokens_per_bin=max_tokens_per_chat, 342 | max_texts_per_bin=max_texts_per_request, 343 | encoding_name=encoding_name, 344 | formatter_function=formatter_function, 345 | long_text_handling=long_text_handling, 346 | ) 347 | 348 | requests = [] 349 | 350 | for text_bin in text_bins: 351 | # Create a chat from the bin 352 | messages = [ChatMessage("system", system_message)] 353 | messages.append(ChatMessage("user", formatter_function(text_bin))) 354 | 355 | chat = Chat(messages) 356 | 357 | request = ChatCompletionRequest( 358 | chat=chat, 359 | model=model, 360 | function=function, 361 | params=params, 362 | ) 363 | 364 | requests.append(request) 365 | 366 | return requests 367 | 368 | 369 | def build_requests( 370 | model: Model, 371 | function: FunctionDef, 372 | system_message: str, 373 | texts: List[str], 374 | params: Parameters, 375 | encoding_name: str = "cl100k_base", 376 | long_text_handling: str = "error", 377 | ) -> List[ChatCompletionRequest]: 378 | """ 379 | Builds a list of ChatCompletionRequests from a list of texts. 380 | The requests can then be passed to processor.process_api_requests(). 381 | 382 | Args: 383 | model: The model to use for completion. 384 | function: The function definition to use for the assistant's response. 385 | Must be a dictionary that describes a valid JSON schema. 386 | See https://platform.openai.com/docs/guides/gpt/function-calling 387 | system_message: The message to include at the beginning of each chat. 
388 | params: Object of class Parameters. See models.Parameters for details. 389 | texts: A list of texts to binpack into chats. Duplicates are not allowed. 390 | encoding_name: The name of the encoding to use for tokenization. 391 | Defaults to "cl100k_base". 392 | long_text_handling: Passed to the binpacking function. Defaults to 393 | "error", which means that an error will be raised if a text is too 394 | long to fit in a single chat. 395 | 396 | Returns: 397 | A list of ChatCompletionRequests. 398 | """ 399 | 400 | return build_binpacked_requests( 401 | model=model, 402 | function=function, 403 | system_message=system_message, 404 | texts=texts, 405 | params=params, 406 | max_tokens_per_request=None, 407 | max_texts_per_request=1, 408 | binpacking_function=utils.binpack_texts_in_order, 409 | formatter_function=utils.format_texts_with_spaces, 410 | encoding_name=encoding_name, 411 | long_text_handling=long_text_handling, 412 | ) 413 | -------------------------------------------------------------------------------- /texttunnel/processor.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- 2 | # This file includes classes and functions adapted from: openai-cookbook 3 | # Original source code: https://github.com/openai/openai-cookbook/blob/c651bfdda64ac049747c2a174cde1c946e2baf1d/examples/api_request_parallel_processor.py 4 | # Copyright (c) 2023 OpenAI 5 | 6 | # MIT License 7 | 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | # imports 27 | import asyncio # for running API calls concurrently 28 | import json # for saving results to a jsonl file 29 | import logging # for logging rate limit warnings and other messages 30 | import os # for reading API key from environment variable 31 | import sys # for checking notebook vs. 
script 32 | import tempfile # for creating a temporary file to save results 33 | import time # for sleeping after rate limit is hit 34 | from dataclasses import dataclass, field 35 | from pathlib import Path # for saving results to a file 36 | from typing import Any, Dict, Generator, List, Optional, Union # for type hints 37 | 38 | import aiohttp 39 | import aiohttp_client_cache 40 | 41 | # for storing API inputs, outputs, and metadata 42 | import jsonschema # for validating API responses 43 | 44 | from texttunnel.chat import ChatCompletionRequest 45 | from texttunnel.models import Model 46 | from texttunnel.utils import hash_dict 47 | 48 | Response = List[Dict[str, Any]] 49 | 50 | logger = logging.getLogger("texttunnel") 51 | 52 | 53 | def prepare_output_filepath( 54 | output_filepath: Optional[Union[str, Path]], keep_file: bool 55 | ) -> Path: 56 | """ 57 | Validates the output_filepath and returns a Path object. Uses a temporary file 58 | if output_filepath is None. 59 | 60 | Args: 61 | output_filepath: The path to save the results to. If None, a temporary file 62 | will be used. 63 | keep_file: Whether to keep the file after the function returns. If True, 64 | output_filepath must not be None. 65 | 66 | Returns: 67 | A Path object representing the output_filepath. 68 | """ 69 | using_tempfile = False 70 | 71 | if output_filepath is None: 72 | output_filepath = tempfile.NamedTemporaryFile(delete=False).name 73 | using_tempfile = True 74 | if keep_file: 75 | raise ValueError( 76 | "keep_file=True is not compatible with output_filepath=None" 77 | ) 78 | 79 | if not isinstance(output_filepath, Path): 80 | output_filepath = Path(output_filepath) 81 | 82 | if output_filepath.exists() and not using_tempfile: 83 | raise ValueError(f"File already exists: {output_filepath}") 84 | 85 | return output_filepath 86 | 87 | 88 | def process_api_requests( 89 | requests: List[ChatCompletionRequest], 90 | output_filepath: Optional[Union[str, Path]] = None, 91 | keep_file: bool = False, 92 | max_attempts: int = 10, 93 | rate_limit_headroom_factor: float = 0.75, 94 | api_key: Optional[str] = None, 95 | cache: Optional[aiohttp_client_cache.CacheBackend] = None, 96 | ) -> List[Response]: 97 | """ 98 | Make requests to OpenAI. This function is a wrapper around 99 | aprocess_api_requests() that executes it within asyncio.run, saving you the 100 | trouble of having to use asyncio directly. 101 | 102 | Note that if you're running this function in a Jupyter notebook, the function 103 | will automatically import nest_asyncio and call nest_asyncio.apply() to allow 104 | a second event loop to run in the same process. This is necessary because 105 | Jupyter notebooks already run an event loop in the background. 106 | 107 | If you require more control over the event loop, use the coroutine 108 | aprocess_api_requests() instead. 109 | 110 | Args: 111 | requests: List[ChatCompletionRequest] 112 | The requests to process, see ChatCompletionRequest class for details. 113 | Duplicate requests are not allowed. 114 | output_filepath: str, optional 115 | Path to the file where the results will be saved 116 | file will be a jsonl file, where each line is an array with the original 117 | request plus the API response e.g., 118 | [{"model": "gpt-4", "messages": "..."}, {...}] 119 | if omitted, the results will be saved to a temporary file. 120 | keep_file: bool, optional 121 | Whether to keep the results file after the script finishes, in addition 122 | to the results being returned by the function. 
123 | Defaults to False, so the file will be deleted after the script finishes. 124 | Setting this to True is not compatible with output_filepath=None. 125 | max_attempts: int, optional 126 | Number of times to retry a failed request before giving up 127 | if omitted, will default to 5 128 | rate_limit_headroom_factor: float, optional 129 | Factor to multiply the rate limit by to guarantee that the script 130 | stays under the limit if omitted, will default to 0.75 131 | (75% of the rate limit) 132 | api_key: str, optional 133 | API key to use. If omitted, the function will attempt to read it 134 | from an environment variable OPENAI_API_KEY. If that fails, an error 135 | will be raised, unless all requests are cached. 136 | cache: aiohttp_client_cache.CacheBackend, optional 137 | If provided, API responses will be served from the cache if available. 138 | New responses will be saved to the cache. 139 | Check the aiohttp_client_cache documentation for details on the 140 | available cache backends and how to configure them. See 141 | https://aiohttp-client-cache.readthedocs.io/en/stable/backends.html. 142 | Each backend requires different dependencies. For example, the SQLite 143 | backend requires the package "aiosqlite" to be installed. 144 | 145 | Returns: 146 | List[Dict[str, Any]]: list where each element consists of two dictionaries: 147 | - the original request 148 | - the API response 149 | """ 150 | 151 | # Handle Notebook environment 152 | if "ipykernel" in sys.modules: 153 | # nest_asyncio is a workaround for running asyncio in Jupyter notebooks 154 | # it's always available when ipykernel is available 155 | import nest_asyncio 156 | 157 | nest_asyncio.apply() 158 | logger.info( 159 | "Running in Jupyter notebook environment. Activated nest_asyncio to allow asyncio to run." 160 | ) 161 | 162 | responses = asyncio.run( 163 | aprocess_api_requests( 164 | requests=requests, 165 | output_filepath=output_filepath, 166 | keep_file=keep_file, 167 | max_attempts=max_attempts, 168 | rate_limit_headroom_factor=rate_limit_headroom_factor, 169 | api_key=api_key, 170 | cache=cache, 171 | ) 172 | ) 173 | 174 | return responses 175 | 176 | 177 | async def fetch_json_response_from_cache( 178 | cache: aiohttp_client_cache.CacheBackend, url: str, request_json: dict 179 | ) -> Optional[dict]: 180 | """ 181 | Fetch a response from the cache if it exists. 182 | 183 | Args: 184 | cache: Cache to fetch from. 185 | url: URL that was requested. 186 | request_json: JSON payload that was sent with the request. 187 | 188 | Returns: 189 | The cached response JSON if it exists, otherwise None. 190 | """ 191 | cache_return_tuple = await cache.request( 192 | method="POST", # ChatCompletion always uses POST requests 193 | url=url, 194 | json=request_json, 195 | ) 196 | 197 | if cache_return_tuple[0] is None: 198 | return None 199 | 200 | cache_response_json = await cache_return_tuple[0].json() 201 | 202 | return cache_response_json 203 | 204 | 205 | async def aprocess_api_requests( 206 | requests: List[ChatCompletionRequest], 207 | output_filepath: Optional[Union[str, Path]] = None, 208 | keep_file: bool = False, 209 | max_attempts: int = 10, 210 | rate_limit_headroom_factor: float = 0.75, 211 | api_key: Optional[str] = None, 212 | cache: Optional[aiohttp_client_cache.CacheBackend] = None, 213 | ) -> List[Response]: 214 | """ 215 | Make asynchronous requests to the OpenAI API while 216 | throttling to stay under rate limits. 
217 | 
218 | Features:
219 | - Makes requests concurrently, to maximize throughput
220 | - Throttles request and token usage, to stay under rate limits
221 | - Retries failed requests up to {max_attempts} times, to avoid missing data
222 | - Logs errors, to diagnose problems with requests
223 | 
224 | 
225 | Args:
226 | requests: List[ChatCompletionRequest]
227 | The requests to process, see ChatCompletionRequest class for details.
228 | Duplicate requests are not allowed.
229 | output_filepath: str, optional
230 | Path to the file where the results will be saved.
231 | The file will be a jsonl file, where each line is an array with the original
232 | request plus the API response, e.g.,
233 | [{"model": "gpt-4", "messages": "..."}, {...}].
234 | If omitted, the results will be saved to a temporary file.
235 | keep_file: bool, optional
236 | Whether to keep the results file after the script finishes, in addition
237 | to the results being returned by the function.
238 | Defaults to False, so the file will be deleted after the script finishes.
239 | Setting this to True is not compatible with output_filepath=None.
240 | max_attempts: int, optional
241 | Number of times to retry a failed request before giving up.
242 | If omitted, defaults to 10.
243 | rate_limit_headroom_factor: float, optional
244 | Factor to multiply the rate limit by to guarantee that the script
245 | stays under the limit. If omitted, defaults to 0.75
246 | (75% of the rate limit).
247 | api_key: str, optional
248 | API key to use. If omitted, the function will attempt to read it
249 | from an environment variable OPENAI_API_KEY. If that fails, an error
250 | will be raised, unless all requests are cached.
251 | cache: aiohttp_client_cache.CacheBackend, optional
252 | If provided, API responses will be served from the cache if available.
253 | New responses will be saved to the cache.
254 | Check the aiohttp_client_cache documentation for details on the
255 | available cache backends and how to configure them. See
256 | https://aiohttp-client-cache.readthedocs.io/en/stable/backends.html.
257 | Each backend has different dependencies. For example, the SQLite
258 | backend requires the package "aiosqlite" to be installed.
259 | 
260 | Returns:
261 | List[Dict[str, Any]]: list where each element consists of two dictionaries:
262 | - the original request
263 | - the API response
264 | """
265 | 
266 | if len(requests) != len(set([request.get_hash() for request in requests])):
267 | # Duplicate requests can cause problems with ordering of results
268 | # Plus it's probably a mistake and would waste money
269 | raise ValueError("Duplicate requests detected. Each request must be unique.")
270 | 
271 | # This function was adapted from openai-cookbook
272 | 
273 | # The function is structured as follows:
274 | # - Initialize things
275 | # - In API processing loop
276 | # - Get next request if one is not already waiting for capacity
277 | # - Update available token & request capacity
278 | # - If enough capacity available, call API.
Responses are written to file 279 | # - The loop pauses if a rate limit error is hit 280 | # - The loop breaks when no tasks remain 281 | # - Fetch results from file 282 | # - Sort results in order of input requests 283 | # - Return results 284 | 285 | output_filepath = prepare_output_filepath(output_filepath, keep_file) 286 | 287 | # Remember the order of the requests so that we can sort the results 288 | # Duplicate requests are not allowed, so the hash of each request is unique 289 | request_order = {request.get_hash(): i for i, request in enumerate(requests)} 290 | 291 | request_url = "https://api.openai.com/v1/chat/completions" 292 | 293 | if cache: 294 | check_cache_settings(cache) 295 | 296 | # Check if requests can be served from the cache 297 | # Build a list of requests that need to be sent to the API 298 | # Handling cached requests separately allows us to avoid allocating 299 | # rate limit capacity to them and provide clearer logging. 300 | logger.debug("Checking cache for requests.") 301 | 302 | # Make asynchronous calls to the cache 303 | tasks = [ 304 | fetch_json_response_from_cache( 305 | cache=cache, 306 | url=request_url, 307 | request_json=request.to_dict(), 308 | ) 309 | for request in requests 310 | ] 311 | 312 | logger.debug("Created cache request tasks.") 313 | 314 | cached_responses = await asyncio.gather(*tasks) 315 | 316 | logger.debug("Gathered cached responses.") 317 | 318 | # Create a list of requests that need to be sent to the API 319 | requests_queue = [] 320 | 321 | # Check cache responses, and add to queue if not found 322 | for request, response in zip(requests, cached_responses): 323 | if response is not None: 324 | # Add to results file 325 | data = [request.to_dict(), response] 326 | append_to_jsonl(data, output_filepath) 327 | else: 328 | requests_queue.append(request) 329 | 330 | request_cache_hits = len(requests) - len(requests_queue) 331 | logger.info( 332 | f"Found {request_cache_hits} out of {len(requests)} requests in cache." 
333 | ) 334 | else: 335 | logger.debug("No cache provided.") 336 | requests_queue = requests.copy() 337 | 338 | logger.debug("Cache check complete.") 339 | 340 | if len(requests_queue) > 0: 341 | await run_request_loop( 342 | requests_queue=requests_queue, 343 | request_url=request_url, 344 | output_filepath=output_filepath, 345 | cache=cache, 346 | max_attempts=max_attempts, 347 | rate_limit_headroom_factor=rate_limit_headroom_factor, 348 | api_key=api_key, 349 | ) 350 | 351 | if cache: 352 | await cache.close() 353 | 354 | with open(output_filepath, "r") as f: 355 | request_response_pairs = [json.loads(line) for line in f] 356 | 357 | # Sort results in order of input requests 358 | # Results is a list of lists, where each sublist is [request, response] 359 | request_response_pairs = sorted( 360 | request_response_pairs, 361 | key=lambda x: request_order[hash_dict(x[0])], 362 | ) 363 | 364 | assert len(request_response_pairs) == len(requests) 365 | 366 | if not keep_file: 367 | output_filepath.unlink() 368 | else: 369 | # Overwrite file with sorted results 370 | with open(output_filepath, "w") as f: 371 | for r in request_response_pairs: 372 | f.write(json.dumps(r) + "\n") 373 | 374 | return request_response_pairs 375 | 376 | 377 | async def run_request_loop( 378 | requests_queue: List[ChatCompletionRequest], 379 | request_url: str, 380 | output_filepath: Path, 381 | cache: Optional[aiohttp_client_cache.CacheBackend] = None, 382 | max_attempts: int = 10, 383 | rate_limit_headroom_factor: float = 0.75, 384 | api_key: Optional[str] = None, 385 | ): 386 | """ 387 | Run the main loop that processes API requests. Save results to a file. 388 | 389 | Args: 390 | requests_queue: A queue of requests to process. 391 | request_url: The URL to send the requests to. 392 | output_filepath: The path to the file where the results will be saved. 393 | cache: A aiohttp_client_cache.CacheBackend object that stores API 394 | responses. If provided, the response will be stored in the cache. 395 | max_attempts: Number of times to retry a failed request before giving up. 396 | rate_limit_headroom_factor: Factor to multiply the rate limit by to 397 | guarantee that the script stays under the limit. 398 | api_key: API key to use. If omitted, the function will attempt to read it 399 | from an environment variable OPENAI_API_KEY. If that fails, an error 400 | will be raised, unless all requests are cached. 401 | 402 | """ 403 | 404 | # Check that all requests use the same model. Otherwise, we can't set 405 | # a single rate limit for all requests. 406 | if len(set([request.model.name for request in requests_queue])) > 1: 407 | raise ValueError("All requests must use the same model.") 408 | 409 | if rate_limit_headroom_factor < 0.01 or rate_limit_headroom_factor > 1: 410 | raise ValueError("rate_limit_headroom_factor must be between 0.01 and 1.") 411 | 412 | # initialize API constants 413 | seconds_to_pause_after_rate_limit_error = 15 414 | seconds_to_sleep_each_loop = ( 415 | 0.001 # 1 ms limits max throughput to 1,000 requests per second 416 | ) 417 | 418 | # initialize API authentication 419 | if api_key is None: 420 | api_key = fetch_api_key() 421 | 422 | request_header = {"Authorization": f"Bearer {api_key}"} 423 | 424 | # initialize trackers 425 | retry_queue = asyncio.Queue() 426 | task_id_generator = ( 427 | task_id_generator_function() 428 | ) # generates integer IDs of 1, 2, 3, ... 
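# The capacity bookkeeping below works like a token bucket: available request
# and token capacity refill continuously in proportion to elapsed time, capped
# at the per-minute maximums, and each dispatched request consumes one request
# slot plus its estimated token count.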
429 | status_tracker = ( 430 | StatusTracker() 431 | ) # single instance to track a collection of variables 432 | next_request = None # variable to hold the next request to call 433 | 434 | # initialize available capacity counts 435 | max_requests_per_minute = ( 436 | requests_queue[0].model.requests_per_minute * rate_limit_headroom_factor 437 | ) 438 | max_tokens_per_minute = ( 439 | requests_queue[0].model.tokens_per_minute * rate_limit_headroom_factor 440 | ) 441 | 442 | available_request_capacity = max_requests_per_minute 443 | available_token_capacity = max_tokens_per_minute 444 | last_update_time = time.time() 445 | 446 | logger.debug("Initialization complete.") 447 | 448 | logger.info( 449 | f"Beginning main requests loop. {len(requests_queue)} requests to make." 450 | ) 451 | 452 | # Main loop that runs until all tasks are finished 453 | last_status_log_timestamp = time.time() 454 | 455 | while True: 456 | # get next request if one is not already waiting for capacity 457 | if next_request is None: 458 | # retry a request if one is waiting in the retry queue 459 | if not retry_queue.empty(): 460 | next_request = retry_queue.get_nowait() 461 | logger.debug(f"Retrying request {next_request.task_id}: {next_request}") 462 | 463 | # send a new request if one is waiting in the requests queue 464 | elif len(requests_queue) > 0: 465 | next_chat_completion = requests_queue.pop(0) 466 | 467 | # get new request 468 | next_request = APIRequest( 469 | task_id=next(task_id_generator), 470 | request=next_chat_completion, 471 | token_consumption=next_chat_completion.count_total_tokens(), 472 | attempts_left=max_attempts, 473 | ) 474 | status_tracker.num_tasks_started += 1 475 | status_tracker.num_tasks_in_progress += 1 476 | logger.debug(f"Reading request {next_request.task_id}: {next_request}") 477 | 478 | # update available capacity 479 | current_time = time.time() 480 | seconds_since_update = current_time - last_update_time 481 | available_request_capacity = min( 482 | available_request_capacity 483 | + max_requests_per_minute * seconds_since_update / 60.0, 484 | max_requests_per_minute, 485 | ) 486 | available_token_capacity = min( 487 | available_token_capacity 488 | + max_tokens_per_minute * seconds_since_update / 60.0, 489 | max_tokens_per_minute, 490 | ) 491 | last_update_time = current_time 492 | 493 | # if enough capacity available, call API 494 | if next_request: 495 | if ( 496 | available_request_capacity >= 1 497 | and available_token_capacity >= next_request.token_consumption 498 | ): 499 | # update counters 500 | available_request_capacity -= 1 501 | available_token_capacity -= next_request.token_consumption 502 | next_request.attempts_left -= 1 503 | 504 | # call API 505 | asyncio.create_task( 506 | next_request.call_api( 507 | request_url=request_url, 508 | request_header=request_header, 509 | retry_queue=retry_queue, 510 | output_filepath=output_filepath, 511 | status_tracker=status_tracker, 512 | cache=cache, 513 | ) 514 | ) 515 | next_request = None # reset next_request to empty 516 | 517 | # if all tasks are finished, break 518 | if status_tracker.num_tasks_in_progress == 0: 519 | break 520 | else: 521 | # Log status every 10 seconds 522 | if time.time() - last_status_log_timestamp > 10: 523 | logger.debug( 524 | "%s tasks in progress. Successful tasks: %s. Failed tasks: %s. " 525 | "Rate limit errors: %s. Other errors: %s. Retry queue length: %s. " 526 | "Tasks not yet tried: %s. 
", 527 | status_tracker.num_tasks_in_progress, 528 | status_tracker.num_tasks_succeeded, 529 | status_tracker.num_tasks_failed, 530 | status_tracker.num_rate_limit_errors, 531 | status_tracker.num_other_errors, 532 | retry_queue.qsize(), 533 | len(requests_queue), 534 | ) 535 | last_status_log_timestamp = time.time() 536 | 537 | # main loop sleeps briefly so concurrent tasks can run 538 | await asyncio.sleep(seconds_to_sleep_each_loop) 539 | 540 | # if a rate limit error was hit recently, pause to cool down 541 | seconds_since_rate_limit_error = ( 542 | time.time() - status_tracker.time_of_last_rate_limit_error 543 | ) 544 | if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error: 545 | remaining_seconds_to_pause = ( 546 | seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error 547 | ) 548 | await asyncio.sleep(remaining_seconds_to_pause) 549 | # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago 550 | logger.warn( 551 | f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}" 552 | ) 553 | 554 | # after finishing, log final status 555 | logger.info("Parallel processing complete.") 556 | if status_tracker.num_tasks_failed > 0: 557 | logger.warning( 558 | f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {output_filepath}." 559 | ) 560 | if status_tracker.num_rate_limit_errors > 0: 561 | logger.warning( 562 | f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate." 563 | ) 564 | 565 | 566 | def check_cache_settings(cache: aiohttp_client_cache.CacheBackend) -> None: 567 | """ 568 | Check that the cache is configured correctly to work with texttunnel. 569 | Raises a ValueError if the cache is not configured correctly. 570 | 571 | Args: 572 | cache: The cache to check. 573 | """ 574 | if "POST" not in cache.allowed_methods: 575 | raise ValueError( 576 | 'cache.allowed_methods must include "POST". Add the argument "allowed_methods=["POST"]" to the cache constructor.' 577 | ) 578 | 579 | if cache.include_headers: 580 | raise ValueError("cache.include_headers must be False to protect the API key.") 581 | 582 | 583 | def fetch_api_key() -> str: 584 | """ 585 | Fetch the API key from the environment variable OPENAI_API_KEY. Raises a 586 | ValueError if the API key is not found. 587 | 588 | Returns: 589 | The API key. 590 | """ 591 | 592 | try: 593 | api_key = os.getenv("OPENAI_API_KEY") 594 | assert api_key is not None 595 | return api_key 596 | except AssertionError: 597 | raise ValueError( 598 | "OPENAI_API_KEY environment variable not found. Please set it and try again." 599 | ) 600 | 601 | 602 | # dataclasses 603 | @dataclass 604 | class StatusTracker: 605 | """Stores metadata about the script's progress. Only one instance is created.""" 606 | 607 | # This class was adapted from openai-cookbook 608 | 609 | num_tasks_started: int = 0 610 | num_tasks_in_progress: int = 0 # script ends when this reaches 0 611 | num_tasks_succeeded: int = 0 612 | num_tasks_failed: int = 0 613 | num_rate_limit_errors: int = 0 614 | num_api_errors: int = 0 # excluding rate limit errors, counted above 615 | num_other_errors: int = 0 616 | time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits 617 | 618 | 619 | @dataclass 620 | class APIRequest: 621 | """Stores an API request's inputs, outputs, and other metadata. 
622 | Contains a method to make an API call.""" 623 | 624 | task_id: int 625 | request: ChatCompletionRequest 626 | token_consumption: int 627 | attempts_left: int 628 | result: list = field(default_factory=list) 629 | 630 | # This class was adapted from openai-cookbook 631 | 632 | async def call_api( 633 | self, 634 | request_url: str, 635 | request_header: dict, 636 | retry_queue: asyncio.Queue, 637 | output_filepath: Path, 638 | status_tracker: StatusTracker, 639 | cache: Optional[aiohttp_client_cache.CacheBackend] = None, 640 | timeout_seconds: int = 120, 641 | ): 642 | """ 643 | Calls the OpenAI API and appends the request and result to a JSONL file. 644 | If a cache provided, the result will be stored in the cache. 645 | The cache key is the hash of the request. 646 | 647 | Args: 648 | request_url: The URL to send the request to. 649 | request_header: The header to send with the request. 650 | retry_queue: A queue of requests that need to be retried. 651 | Will be populated if the request fails. 652 | output_filepath: The path to the file where the results will be saved. 653 | status_tracker: A StatusTracker object that tracks the greater 654 | request loop's progress. 655 | cache: A aiohttp_client_cache.CacheBackend object that stores API 656 | responses. If provided, the response will be stored in the cache. 657 | timeout_seconds: The number of seconds to wait for a response before 658 | timing out. Defaults to 120 seconds. 659 | """ 660 | 661 | error = None 662 | 663 | logger.info(f"Starting request #{self.task_id}") 664 | timeout = aiohttp.ClientTimeout(total=timeout_seconds) 665 | 666 | # Choose the session class based on whether cache is provided 667 | session_class = ( 668 | aiohttp.ClientSession 669 | if cache is None 670 | else aiohttp_client_cache.CachedSession 671 | ) 672 | 673 | session_kwargs = ( 674 | {"timeout": timeout} 675 | if cache is None 676 | else {"cache": cache, "timeout": timeout} 677 | ) 678 | 679 | try: 680 | async with session_class(**session_kwargs) as session: 681 | async with session.post( 682 | url=request_url, 683 | headers=request_header, 684 | json=self.request.to_dict(), 685 | ) as response: 686 | response = await response.json() 687 | 688 | if "error" in response: 689 | # API and rate limit errors don't raise an exception 690 | # They are found in the response JSON 691 | logger.warning( 692 | f"Request {self.task_id} failed with error {response['error']}" 693 | ) 694 | 695 | error = response 696 | if "Rate limit" in response["error"].get("message", ""): 697 | status_tracker.time_of_last_rate_limit_error = int(time.time()) 698 | status_tracker.num_rate_limit_errors += 1 699 | else: 700 | status_tracker.num_api_errors += 1 701 | 702 | except ( 703 | Exception 704 | ) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them 705 | logger.warning(f"Request {self.task_id} failed with Exception {e}") 706 | status_tracker.num_other_errors += 1 707 | error = e 708 | 709 | if error: 710 | self.result.append(error) 711 | if self.attempts_left: 712 | retry_queue.put_nowait(self) 713 | logger.debug( 714 | "Added request #%s to retry queue. Queue length: %s.", 715 | self.task_id, 716 | retry_queue.qsize(), 717 | ) 718 | else: 719 | logger.error( 720 | f"Request {self.request.to_dict()} failed after all attempts. 
Saving errors: {self.result}" 721 | ) 722 | data = [self.request.to_dict(), [str(e) for e in self.result]] 723 | append_to_jsonl(data, output_filepath) 724 | status_tracker.num_tasks_in_progress -= 1 725 | status_tracker.num_tasks_failed += 1 726 | else: # success 727 | data = [self.request.to_dict(), response] 728 | append_to_jsonl(data, output_filepath) 729 | status_tracker.num_tasks_in_progress -= 1 730 | status_tracker.num_tasks_succeeded += 1 731 | logger.debug(f"Request #{self.task_id} saved to {output_filepath}") 732 | 733 | 734 | # functions 735 | def append_to_jsonl(data: Any, filename: Path) -> None: 736 | """ 737 | Append a json payload to the end of a jsonl file. 738 | 739 | Args: 740 | data: The data to append. 741 | filename: The file to append to. 742 | """ 743 | # This function was adapted from openai-cookbook 744 | 745 | json_string = json.dumps(data) 746 | with open(filename, "a") as f: 747 | f.write(json_string + "\n") 748 | 749 | 750 | def task_id_generator_function() -> Generator[int, None, None]: 751 | """ 752 | Generate integers 0, 1, 2, and so on. 753 | 754 | Returns: 755 | A generator that yields integers. 756 | """ 757 | # This function was adapted from openai-cookbook 758 | 759 | task_id = 0 760 | while True: 761 | yield task_id 762 | task_id += 1 763 | 764 | 765 | RESPONSE_SCHEMA = { 766 | "$schema": "http://json-schema.org/draft-07/schema#", 767 | "type": "array", 768 | "items": [ 769 | # Request schema 770 | { 771 | "type": "object", 772 | "properties": { 773 | "model": {"type": "string"}, 774 | "max_tokens": {"type": "integer"}, 775 | "messages": { 776 | "type": "array", 777 | "items": { 778 | "type": "object", 779 | "properties": { 780 | "role": {"type": "string"}, 781 | "content": {"type": "string"}, 782 | }, 783 | "required": ["role", "content"], 784 | }, 785 | }, 786 | "functions": { 787 | "type": "array", 788 | "properties": { 789 | "name": {"type": "string"}, 790 | "parameters": {"type": "object"}, 791 | }, 792 | "required": ["name", "parameters"], 793 | }, 794 | }, 795 | "required": [ 796 | "model", 797 | "max_tokens", 798 | "messages", 799 | "functions", 800 | ], 801 | }, 802 | # Response schema 803 | { 804 | "type": "object", 805 | "properties": { 806 | "id": {"type": "string"}, 807 | "object": {"type": "string"}, 808 | "created": {"type": "integer"}, 809 | "model": {"type": "string"}, 810 | "choices": { 811 | "type": "array", 812 | "items": { 813 | "type": "object", 814 | "properties": { 815 | "index": {"type": "integer"}, 816 | "message": { 817 | "type": "object", 818 | "properties": { 819 | "role": {"type": "string"}, 820 | "content": {"type": ["string", "null"]}, 821 | "function_call": { 822 | "type": "object", 823 | "properties": { 824 | "name": {"type": "string"}, 825 | "arguments": {"type": "string"}, 826 | }, 827 | "required": ["name", "arguments"], 828 | }, 829 | }, 830 | "required": ["role", "function_call"], 831 | }, 832 | "finish_reason": {"type": "string"}, 833 | }, 834 | "required": ["index", "message", "finish_reason"], 835 | }, 836 | }, 837 | "usage": { 838 | "type": "object", 839 | "properties": { 840 | "prompt_tokens": {"type": "integer"}, 841 | "completion_tokens": {"type": "integer"}, 842 | "total_tokens": {"type": "integer"}, 843 | }, 844 | "required": ["prompt_tokens", "completion_tokens", "total_tokens"], 845 | }, 846 | }, 847 | "required": ["id", "object", "created", "model", "choices", "usage"], 848 | }, 849 | ], 850 | } 851 | 852 | 853 | def is_valid_response(response: Response, print_errors=False) -> bool: 854 | """ 855 | 
Check if a response conforms to the response JSON schema. 856 | """ 857 | try: 858 | jsonschema.validate(response, RESPONSE_SCHEMA) 859 | return True 860 | except jsonschema.exceptions.ValidationError as e: 861 | if print_errors: 862 | print(e) 863 | return False 864 | 865 | 866 | def parse_arguments(response: Response) -> Dict[str, Any]: 867 | """ 868 | Extract the function call arguments from a response. 869 | 870 | Args: 871 | response: The response to parse. It should be a list of length 2, where the 872 | first element is the request and the second element is the response. 873 | 874 | Returns: 875 | The function call arguments. 876 | """ 877 | 878 | if not is_valid_response(response): 879 | raise ValueError("Response is not valid.") 880 | 881 | return json.loads( 882 | response[1]["choices"][0]["message"]["function_call"]["arguments"] 883 | ) 884 | 885 | 886 | def parse_token_usage(response: Response) -> Dict[str, Any]: 887 | """ 888 | Extract the token usage from a response. 889 | 890 | Args: 891 | response: The response to parse. It should be a list of length 2, where the 892 | first element is the request and the second element is the response. 893 | 894 | Returns: 895 | The token usage. 896 | """ 897 | if not is_valid_response(response): 898 | raise ValueError("Response is not valid.") 899 | 900 | return response[1]["usage"] 901 | 902 | 903 | def usage_to_cost(usage: Dict, model: Model): 904 | """ 905 | Convert token usage to cost in USD. 906 | 907 | Args: 908 | usage: The token usage. Retrieve it with parse_token_usage(). 909 | model: The model used to generate the response. 910 | 911 | Returns: 912 | The cost in USD. 913 | """ 914 | input_cost = model.input_token_price_per_1k * usage["prompt_tokens"] / 1000 915 | output_cost = model.output_token_price_per_1k * usage["completion_tokens"] / 1000 916 | total_cost = input_cost + output_cost 917 | return total_cost 918 | --------------------------------------------------------------------------------
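Putting the modules together, a minimal end-to-end usage sketch (illustrative only: the function schema, the texts, and the model constant GPT_3_5_TURBO are assumptions for the example, not taken from the files above; an OPENAI_API_KEY environment variable is required):

from texttunnel import chat, models, processor

# Hypothetical JSON-schema function definition for the example.
function = {
    "name": "classify_sentiment",
    "parameters": {
        "type": "object",
        "properties": {
            "sentiment": {"type": "string", "enum": ["positive", "negative", "neutral"]},
        },
        "required": ["sentiment"],
    },
}

requests = chat.build_requests(
    model=models.GPT_3_5_TURBO,  # assumed model constant from models.py
    function=function,
    system_message="Classify the sentiment of the text.",
    texts=["I love this!", "This is terrible."],
    params=models.Parameters(max_tokens=128),
)

responses = processor.process_api_requests(requests)
answers = [processor.parse_arguments(response) for response in responses]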