├── .github └── workflows │ ├── quality.yaml │ ├── release.yaml │ └── test.yaml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── example.py ├── py_txi ├── __init__.py ├── inference_server.py ├── text_embedding_inference.py ├── text_generation_inference.py └── utils.py ├── pyproject.toml ├── setup.py └── tests └── test_txi.py /.github/workflows/quality.yaml: -------------------------------------------------------------------------------- 1 | name: quality 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | check_quality: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout code 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: "3.10" 26 | 27 | - name: Install quality requirements 28 | run: | 29 | pip install --upgrade pip 30 | pip install -e .[quality] 31 | 32 | - name: Check quality 33 | run: | 34 | make quality 35 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | release: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout code 12 | uses: actions/checkout@v4 13 | 14 | - name: Set up Python 3.10 15 | uses: actions/setup-python@v5 16 | with: 17 | python-version: "3.10" 18 | 19 | - name: Install release requirements 20 | run: | 21 | pip install --upgrade pip 22 | pip install setuptools wheel twine 23 | 24 | - name: Build and publish release 25 | env: 26 | TWINE_USERNAME: __token__ 27 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 28 | run: | 29 | python setup.py sdist bdist_wheel 30 | twine upload --repository pypi dist/* 31 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | run_tests: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout code 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: "3.10" 26 | 27 | - name: Install testing requirements 28 | run: | 29 | pip install --upgrade pip 30 | pip install -e .[testing] 31 | 32 | - name: Run test 33 | run: | 34 | make test_cpu 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before 
PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # List of targets that are not associated with files
2 | .PHONY: quality style install test_cpu test_gpu
3 | 
4 | quality:
5 | 	ruff check .
6 | 	ruff format --check .
7 | 
8 | style:
9 | 	ruff format .
10 | 	ruff check --fix .
11 | 
12 | test_cpu:
13 | 	pytest tests/ -s -x -k "cpu"
14 | 
15 | test_gpu:
16 | 	pytest tests/ -s -x -k "gpu"
17 | 
18 | install:
19 | 	pip install -e .
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Py-TXI
2 | 
3 | [![PyPI version](https://badge.fury.io/py/py-txi.svg)](https://badge.fury.io/py/py-txi)
4 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/py-txi)](https://pypi.org/project/py-txi/)
5 | [![PyPI - Format](https://img.shields.io/pypi/format/py-txi)](https://pypi.org/project/py-txi/)
6 | [![Downloads](https://pepy.tech/badge/py-txi)](https://pepy.tech/project/py-txi)
7 | [![PyPI - License](https://img.shields.io/pypi/l/py-txi)](https://pypi.org/project/py-txi/)
8 | [![Test](https://github.com/IlyasMoutawwakil/py-txi/actions/workflows/test.yaml/badge.svg)](https://github.com/IlyasMoutawwakil/py-txi/actions/workflows/test.yaml)
9 | 
10 | Py-TXI is a Python wrapper around [Text-Generation-Inference](https://github.com/huggingface/text-generation-inference) and [Text-Embedding-Inference](https://github.com/huggingface/text-embeddings-inference) that enables creating and running TGI/TEI instances through the awesome `docker-py` in a similar style to the Transformers API.
11 | 
12 | ## Installation
13 | 
14 | ```bash
15 | pip install py-txi
16 | ```
17 | 
18 | Py-TXI is designed to be used in a similar way to the Transformers API. We use `docker-py` (instead of a dirty `subprocess` solution) so that the containers you run are linked to the main process and are stopped automatically when your code finishes or fails.
19 | 
20 | ## Advantages
21 | 
22 | - **Easy to use**: Py-TXI is designed to be used in a similar way to the Transformers API.
23 | - **Automatic cleanup**: Py-TXI stops the Docker container when your code finishes or fails.
24 | - **Batched inference**: Py-TXI supports sending a batch of inputs to the server for inference. 25 | - **Automatic port allocation**: Py-TXI automatically allocates a free port for the Inference server. 26 | - **Configurable**: Py-TXI allows you to configure the Inference servers using a simple configuration object. 27 | - **Verbose**: Py-TXI streams the logs of the underlying Docker container to the main process so you can debug easily. 28 | 29 | ## Usage 30 | 31 | Here's an example of how to use it: 32 | 33 | ```python 34 | from py_txi import TGI, TGIConfig 35 | 36 | llm = TGI(config=TGIConfig(model_id="bigscience/bloom-560m", gpus="0")) 37 | output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"]) 38 | print("LLM:", output) 39 | llm.close() 40 | ``` 41 | 42 | Output: ```LLM: [' student. I have a problem with the following code. I have a class that has a method that', '"\n\n"I\'m fine," said the girl, "but I don\'t want to be alone.']``` 43 | 44 | ```python 45 | from py_txi import TEI, TEIConfig 46 | 47 | embed = TEI(config=TEIConfig(model_id="BAAI/bge-base-en-v1.5")) 48 | output = embed.encode(["Hi, I'm an embedding model", "I'm fine, how are you?"]) 49 | print("Embed:", output) 50 | embed.close() 51 | ``` 52 | 53 | Output: ```[array([[ 0.01058742, -0.01588806, -0.03487622, ..., -0.01613717, 54 | 0.01772875, -0.02237891]], dtype=float32), array([[ 0.02815401, -0.02892136, -0.0536355 , ..., 0.01225784, 55 | -0.00241452, -0.02836569]], dtype=float32)]``` 56 | 57 | That's it! Now you can write your Python scripts using the power of TGI and TEI without having to worry about the underlying Docker containers. 58 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from py_txi.text_embedding_inference import TEI, TEIConfig 2 | from py_txi.text_generation_inference import TGI, TGIConfig 3 | 4 | for gpus in [None, "1", "1,2"]: 5 | llm = TGI(config=TGIConfig(model_id="gpt2", gpus=gpus)) 6 | output = llm.generate(["Hi, I'm a language model", "I'm fine, how are you?"]) 7 | print(len(output)) 8 | print("LLM:", output) 9 | llm.close() 10 | 11 | embed = TEI(config=TEIConfig(model_id="BAAI/bge-base-en-v1.5", gpus=gpus)) 12 | output = embed.encode(["Hi, I'm an embedding model", "I'm fine, how are you?"]) 13 | print(len(output)) 14 | print("Embed:", output) 15 | embed.close() 16 | -------------------------------------------------------------------------------- /py_txi/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_embedding_inference import TEI, TEIConfig # noqa 2 | from .text_generation_inference import TGI, TGIConfig # noqa 3 | from .utils import is_nvidia_system, is_rocm_system, get_free_port # noqa 4 | -------------------------------------------------------------------------------- /py_txi/inference_server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import os 4 | import re 5 | import time 6 | from abc import ABC 7 | from dataclasses import asdict, dataclass, field 8 | from logging import getLogger 9 | from typing import Any, Dict, List, Optional, Union 10 | 11 | import docker 12 | import docker.errors 13 | import docker.types 14 | from huggingface_hub import AsyncInferenceClient 15 | from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE 16 | 17 | from .utils import get_free_port, styled_logs 18 | 19 | 
DOCKER = docker.from_env() 20 | LOGGER = getLogger("Inference-Server") 21 | logging.basicConfig(level=logging.INFO) 22 | 23 | 24 | @dataclass 25 | class InferenceServerConfig: 26 | # Common options 27 | model_id: Optional[str] = None 28 | revision: Optional[str] = None 29 | 30 | # Image to use for the container 31 | image: Optional[str] = None 32 | # Shared memory size for the container 33 | shm_size: Optional[str] = None 34 | # List of custom devices to forward to the container e.g. ["/dev/kfd", "/dev/dri"] for ROCm 35 | devices: Optional[List[str]] = None 36 | # NVIDIA-docker GPU device options e.g. "all" (all) or "0,1,2,3" (ids) or 4 (count) 37 | gpus: Optional[Union[str, int]] = None 38 | ports: Dict[str, Any] = field( 39 | default_factory=lambda: {"80/tcp": ("0.0.0.0", 0)}, 40 | metadata={"help": "Dictionary of ports to expose from the container."}, 41 | ) 42 | volumes: Dict[str, Any] = field( 43 | default_factory=lambda: {HUGGINGFACE_HUB_CACHE: {"bind": "/data", "mode": "rw"}}, 44 | metadata={"help": "Dictionary of volumes to mount inside the container."}, 45 | ) 46 | environment: List[str] = field( 47 | default_factory=lambda: ["HF_TOKEN", "HF_API_TOKEN", "HF_HUB_TOKEN"], 48 | metadata={"help": "List of environment variables to forward to the container."}, 49 | ) 50 | 51 | # first connection/request 52 | connection_timeout: int = 60 53 | first_request_timeout: int = 60 54 | 55 | def __post_init__(self) -> None: 56 | if self.ports["80/tcp"][1] == 0: 57 | LOGGER.info("\t+ Getting a free port for the server") 58 | self.ports["80/tcp"] = (self.ports["80/tcp"][0], get_free_port()) 59 | LOGGER.info(f"\t+ Using port {self.ports['80/tcp'][0]}:{self.ports['80/tcp'][1]} for the server") 60 | 61 | if self.shm_size is None: 62 | LOGGER.warning("\t+ Shared memory size not provided. 
Defaulting to '1g'.") 63 | self.shm_size = "1g" 64 | 65 | 66 | class InferenceServer(ABC): 67 | NAME: str = "Inference-Server" 68 | SUCCESS_SENTINEL: str = "Success" 69 | FAILURE_SENTINEL: str = "Failure" 70 | 71 | def __init__(self, config: InferenceServerConfig) -> None: 72 | self.config = config 73 | 74 | try: 75 | LOGGER.info("\t+ Checking if server image is available locally") 76 | DOCKER.images.get(self.config.image) 77 | LOGGER.info("\t+ Server image found locally") 78 | except docker.errors.ImageNotFound: 79 | LOGGER.info("\t+ Server image not found locally, pulling it from Docker Hub") 80 | DOCKER.images.pull(self.config.image) 81 | 82 | if self.config.gpus is not None and isinstance(self.config.gpus, str) and self.config.gpus == "all": 83 | LOGGER.info("\t+ Using all GPU(s)") 84 | self.device_requests = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] 85 | elif self.config.gpus is not None and isinstance(self.config.gpus, int): 86 | LOGGER.info(f"\t+ Using {self.config.gpus} GPU(s)") 87 | self.device_requests = [docker.types.DeviceRequest(count=self.config.gpus, capabilities=[["gpu"]])] 88 | elif ( 89 | self.config.gpus is not None 90 | and isinstance(self.config.gpus, str) 91 | and re.match(r"^\d+(,\d+)*$", self.config.gpus) 92 | ): 93 | LOGGER.info(f"\t+ Using GPU(s) {self.config.gpus}") 94 | self.device_requests = [docker.types.DeviceRequest(device_ids=[self.config.gpus], capabilities=[["gpu"]])] 95 | else: 96 | LOGGER.info("\t+ Not using any GPU(s)") 97 | self.device_requests = None 98 | 99 | self.command = [] 100 | LOGGER.info("\t+ Building launcher command") 101 | if self.config.model_id is not None: 102 | self.command.append(f"--model-id={self.config.model_id}") 103 | if self.config.revision is not None: 104 | self.command.append(f"--revision={self.config.revision}") 105 | 106 | for k, v in asdict(self.config).items(): 107 | if k in InferenceServerConfig.__annotations__: 108 | continue 109 | if v is not None: 110 | if isinstance(v, bool): 111 | self.command.append(f"--{k.replace('_', '-')}") 112 | else: 113 | self.command.append(f"--{k.replace('_', '-')}={str(v).lower()}") 114 | 115 | self.command.append("--json-output") 116 | 117 | self.environment = {} 118 | LOGGER.info("\t+ Building server environment") 119 | for key in self.config.environment: 120 | if key in os.environ: 121 | LOGGER.info(f"\t+ Forwarding environment variable {key}") 122 | self.environment[key] = os.environ[key] 123 | 124 | LOGGER.info("\t+ Running server container") 125 | self.container = DOCKER.containers.run( 126 | image=self.config.image, 127 | ports=self.config.ports, 128 | volumes=self.config.volumes, 129 | devices=self.config.devices, 130 | shm_size=self.config.shm_size, 131 | environment=self.environment, 132 | device_requests=self.device_requests, 133 | command=self.command, 134 | auto_remove=True, 135 | detach=True, 136 | ) 137 | 138 | LOGGER.info("\t+ Streaming server logs") 139 | for line in self.container.logs(stream=True): 140 | log = line.decode("utf-8").strip() 141 | log = styled_logs(log) 142 | 143 | if self.SUCCESS_SENTINEL.lower() in log.lower(): 144 | LOGGER.info(f"\t+ {log}") 145 | break 146 | elif self.FAILURE_SENTINEL.lower() in log.lower(): 147 | raise Exception(f"Server failed to start with failure message: {log}") 148 | else: 149 | LOGGER.info(f"\t+ {log}") 150 | 151 | address, port = self.config.ports["80/tcp"] 152 | self.url = f"http://{address}:{port}" 153 | 154 | try: 155 | asyncio.set_event_loop(asyncio.get_event_loop()) 156 | except RuntimeError: 157 | 
asyncio.set_event_loop(asyncio.new_event_loop()) 158 | 159 | self.semaphore = asyncio.Semaphore(self.config.max_concurrent_requests) 160 | 161 | start_time = time.time() 162 | LOGGER.info("\t+ Trying to connect to server") 163 | while time.time() - start_time < self.config.connection_timeout: 164 | try: 165 | self.client = AsyncInferenceClient(model=self.url) 166 | LOGGER.info("\t+ Connected to server successfully") 167 | break 168 | except Exception: 169 | LOGGER.info("\t+ Failed to connect to server, waiting for 1 second and retrying") 170 | time.sleep(1) 171 | 172 | start_time = time.time() 173 | LOGGER.info("\t+ Trying to run a first request") 174 | while time.time() - start_time < self.config.first_request_timeout: 175 | try: 176 | asyncio.run(self.single_client_call("Hello server!")) 177 | LOGGER.info("\t+ Ran first request successfully") 178 | break 179 | except Exception: 180 | LOGGER.info("\t+ Failed to run first request, waiting for 1 second and retrying") 181 | time.sleep(1) 182 | 183 | async def single_client_call(self, *args, **kwargs) -> Any: 184 | raise NotImplementedError 185 | 186 | async def batch_client_call(self, *args, **kwargs) -> Any: 187 | raise NotImplementedError 188 | 189 | def close(self) -> None: 190 | if hasattr(self, "container"): 191 | container = DOCKER.containers.get(self.container.id) 192 | if container.status == "running": 193 | LOGGER.info("\t+ Stoping Docker container") 194 | container.stop() 195 | container.wait() 196 | LOGGER.info("\t+ Docker container stopped") 197 | del self.container 198 | 199 | if hasattr(self, "semaphore"): 200 | if self.semaphore.locked(): 201 | self.semaphore.release() 202 | del self.semaphore 203 | 204 | if hasattr(self, "client"): 205 | del self.client 206 | 207 | def __del__(self) -> None: 208 | self.close() 209 | -------------------------------------------------------------------------------- /py_txi/text_embedding_inference.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass 3 | from logging import getLogger 4 | from typing import List, Literal, Optional, Union 5 | 6 | import numpy as np 7 | 8 | from .inference_server import InferenceServer, InferenceServerConfig 9 | from .utils import is_nvidia_system 10 | 11 | LOGGER = getLogger("Text-Embedding-Inference") 12 | 13 | 14 | Pooling_Literal = Literal["cls", "mean"] 15 | DType_Literal = Literal["float32", "float16"] 16 | 17 | 18 | @dataclass 19 | class TEIConfig(InferenceServerConfig): 20 | # Launcher options 21 | dtype: Optional[DType_Literal] = None 22 | pooling: Optional[Pooling_Literal] = None 23 | 24 | # Concurrency options 25 | max_concurrent_requests: int = 512 26 | 27 | def __post_init__(self) -> None: 28 | super().__post_init__() 29 | 30 | if self.image is None: 31 | if is_nvidia_system() and self.gpus is not None: 32 | LOGGER.info("\t+ Using image version cuda-latest for Text-Embedding-Inference") 33 | self.image = "ghcr.io/huggingface/text-embeddings-inference:cuda-latest" 34 | else: 35 | LOGGER.info("\t+ Using image version cpu-1.4 for Text-Embedding-Inference") 36 | self.image = "ghcr.io/huggingface/text-embeddings-inference:cpu-1.4" 37 | 38 | if is_nvidia_system() and "cpu" in self.image: 39 | LOGGER.warning("\t+ You are running on a NVIDIA GPU system but using a CPU image.") 40 | 41 | if self.pooling is None: 42 | LOGGER.warning("\t+ Pooling strategy not provided. 
Defaulting to 'cls' pooling.") 43 | self.pooling = "cls" 44 | 45 | 46 | class TEI(InferenceServer): 47 | NAME: str = "Text-Embedding-Inference" 48 | SUCCESS_SENTINEL: str = "Ready" 49 | FAILURE_SENTINEL: str = "Error" 50 | 51 | def __init__(self, config: TEIConfig) -> None: 52 | super().__init__(config) 53 | 54 | async def single_client_call(self, text: str, **kwargs) -> np.ndarray: 55 | async with self.semaphore: 56 | output = await self.client.feature_extraction(text=text, **kwargs) 57 | return output 58 | 59 | async def batch_client_call(self, text: List[str], **kwargs) -> List[np.ndarray]: 60 | output = await asyncio.gather(*[self.single_client_call(t, **kwargs) for t in text]) 61 | return output 62 | 63 | def encode(self, text: Union[str, List[str]], **kwargs) -> Union[np.ndarray, List[np.ndarray]]: 64 | if isinstance(text, str): 65 | output = asyncio.run(self.single_client_call(text, **kwargs)) 66 | return output 67 | elif isinstance(text, list): 68 | output = asyncio.run(self.batch_client_call(text, **kwargs)) 69 | return output 70 | else: 71 | raise ValueError(f"Unsupported input type: {type(text)}") 72 | -------------------------------------------------------------------------------- /py_txi/text_generation_inference.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass 3 | from logging import getLogger 4 | from typing import Literal, Optional, Union 5 | 6 | from .inference_server import InferenceServer, InferenceServerConfig 7 | from .utils import is_nvidia_system, is_rocm_system 8 | 9 | LOGGER = getLogger("Text-Generation-Inference") 10 | 11 | Shareded_Literal = Literal["true", "false"] 12 | DType_Literal = Literal["float32", "float16", "bfloat16"] 13 | Quantize_Literal = Literal["bitsandbytes-nf4", "bitsandbytes-fp4", "gptq", "awq", "eetq", "fp8"] 14 | 15 | 16 | @dataclass 17 | class TGIConfig(InferenceServerConfig): 18 | # Launcher options 19 | num_shard: Optional[int] = None 20 | speculate: Optional[int] = None 21 | cuda_graphs: Optional[int] = None 22 | dtype: Optional[DType_Literal] = None 23 | trust_remote_code: Optional[bool] = None 24 | sharded: Optional[Shareded_Literal] = None 25 | quantize: Optional[Quantize_Literal] = None 26 | disable_custom_kernels: Optional[bool] = None 27 | 28 | # Concurrency options 29 | max_concurrent_requests: int = 128 30 | 31 | def __post_init__(self) -> None: 32 | super().__post_init__() 33 | 34 | if self.image is None: 35 | if is_nvidia_system() and self.gpus is not None: 36 | LOGGER.info("\t+ Using image version 3.2.3 for Text-Generation-Inference") 37 | self.image = "ghcr.io/huggingface/text-generation-inference:3.2.3" 38 | elif is_rocm_system() and self.devices is not None: 39 | LOGGER.info("\t+ Using image version 3.2.3-rocm for Text-Generation-Inference") 40 | self.image = "ghcr.io/huggingface/text-generation-inference:3.2.3-rocm" 41 | else: 42 | LOGGER.info("\t+ Using image version 3.2.3-intel-cpu for Text-Generation-Inference") 43 | self.image = "ghcr.io/huggingface/text-generation-inference:3.2.3-intel-cpu" 44 | 45 | 46 | class TGI(InferenceServer): 47 | NAME: str = "Text-Generation-Inference" 48 | SUCCESS_SENTINEL: str = "Connected" 49 | FAILURE_SENTINEL: str = "Traceback" 50 | 51 | def __init__(self, config: TGIConfig) -> None: 52 | super().__init__(config) 53 | 54 | async def single_client_call(self, prompt: str, **kwargs) -> str: 55 | async with self.semaphore: 56 | output = await self.client.text_generation(prompt=prompt, **kwargs) 57 | 
return output 58 | 59 | async def batch_client_call(self, prompt: list, **kwargs) -> list: 60 | output = await asyncio.gather(*[self.single_client_call(prompt=p, **kwargs) for p in prompt]) 61 | return output 62 | 63 | def generate(self, prompt: Union[str, list], **kwargs) -> Union[str, list]: 64 | if isinstance(prompt, str): 65 | output = asyncio.run(self.single_client_call(prompt, **kwargs)) 66 | return output 67 | elif isinstance(prompt, list): 68 | output = asyncio.run(self.batch_client_call(prompt, **kwargs)) 69 | return output 70 | else: 71 | raise ValueError(f"Unsupported input type: {type(prompt)}") 72 | -------------------------------------------------------------------------------- /py_txi/utils.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import subprocess 3 | from datetime import datetime 4 | from json import loads 5 | 6 | 7 | def get_free_port() -> int: 8 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 9 | s.bind(("", 0)) 10 | return s.getsockname()[1] 11 | 12 | 13 | def is_rocm_system() -> bool: 14 | try: 15 | subprocess.check_output(["rocm-smi"]) 16 | return True 17 | except FileNotFoundError: 18 | return False 19 | 20 | 21 | def is_nvidia_system() -> bool: 22 | try: 23 | subprocess.check_output(["nvidia-smi"]) 24 | return True 25 | except FileNotFoundError: 26 | return False 27 | 28 | 29 | LEVEL_TO_MESSAGE_STYLE = { 30 | "DEBUG": "\033[37m", 31 | "INFO": "\033[37m", 32 | "WARN": "\033[33m", 33 | "WARNING": "\033[33m", 34 | "ERROR": "\033[31m", 35 | "CRITICAL": "\033[31m", 36 | } 37 | TIMESTAMP_STYLE = "\033[32m" 38 | TARGET_STYLE = "\033[0;38m" 39 | LEVEL_STYLE = "\033[1;30m" 40 | 41 | 42 | def color_text(text: str, color: str) -> str: 43 | return f"{color}{text}\033[0m" 44 | 45 | 46 | def styled_logs(log: str) -> str: 47 | try: 48 | dict_log = loads(log) 49 | except Exception: 50 | return log 51 | 52 | fields = dict_log.get("fields", {}) 53 | level = dict_log.get("level", "could not parse level") 54 | target = dict_log.get("target", "could not parse target") 55 | timestamp = dict_log.get("timestamp", "could not parse timestamp") 56 | message = fields.get("message", dict_log.get("message", "could not parse message")) 57 | timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ").strftime("%Y-%m-%d %H:%M:%S") 58 | 59 | message = color_text(message, LEVEL_TO_MESSAGE_STYLE.get(level, "\033[37m")) 60 | timestamp = color_text(timestamp, TIMESTAMP_STYLE) 61 | target = color_text(target, TARGET_STYLE) 62 | level = color_text(level, LEVEL_STYLE) 63 | 64 | return f"[{timestamp}][{target}][{level}] - {message}" 65 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | line-length = 120 3 | lint.ignore = ["C901", "E501"] 4 | lint.select = ["C", "E", "F", "I", "W", "I001"] 5 | 6 | [tool.ruff.format] 7 | line-ending = "auto" 8 | quote-style = "double" 9 | indent-style = "space" 10 | skip-magic-trailing-comma = false 11 | 12 | [tool.pytest.ini_options] 13 | log_cli = true 14 | log_cli_level = "INFO" 15 | log_cli_format = "[%(asctime)s][%(name)s][%(levelname)s] - %(message)s" 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import find_packages, setup 4 | 5 | PY_TXI_VERSION = 
"0.10.0" 6 | 7 | common_setup_kwargs = { 8 | "author": "Ilyas Moutawwakil", 9 | "author_email": "ilyas.moutawwakil@gmail.com", 10 | "description": "A Python wrapper around TGI and TEI servers", 11 | "keywords": ["tgi", "llm", "tei", "embedding", "huggingface", "docker", "python"], 12 | "url": "https://github.com/IlyasMoutawwakil/py-txi", 13 | "long_description_content_type": "text/markdown", 14 | "long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"), 15 | "platforms": ["linux", "windows", "macos"], 16 | "classifiers": [ 17 | "Programming Language :: Python :: 3", 18 | "Natural Language :: English", 19 | ], 20 | } 21 | 22 | 23 | setup( 24 | name="py-txi", 25 | version=PY_TXI_VERSION, 26 | packages=find_packages(), 27 | install_requires=["docker", "huggingface-hub", "numpy", "aiohttp"], 28 | extras_require={"quality": ["ruff"], "testing": ["pytest"]}, 29 | **common_setup_kwargs, 30 | ) 31 | -------------------------------------------------------------------------------- /tests/test_txi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from py_txi import TEI, TGI, TEIConfig, TGIConfig 4 | 5 | 6 | def test_cpu_tei(): 7 | embed = TEI(config=TEIConfig(model_id="BAAI/bge-base-en-v1.5")) 8 | output = embed.encode("Hi, I'm a language model") 9 | assert isinstance(output, np.ndarray) 10 | output = embed.encode(["Hi, I'm a language model", "I'm fine, how are you?"]) 11 | assert isinstance(output, list) and all(isinstance(x, np.ndarray) for x in output) 12 | embed.close() 13 | 14 | 15 | def test_cpu_tgi(): 16 | llm = TGI(config=TGIConfig(model_id="gpt2")) 17 | output = llm.generate("Hi, I'm a sanity test", max_new_tokens=2) 18 | assert isinstance(output, str) 19 | output = llm.generate(["Hi, I'm a sanity test", "I'm a second sentence"], max_new_tokens=2) 20 | assert isinstance(output, list) and all(isinstance(x, str) for x in output) 21 | llm.close() 22 | --------------------------------------------------------------------------------