├── tests ├── __init__.py ├── integ │ ├── __init__.py │ ├── test_pytorch_local_inf2.py │ ├── test_pytorch_remote_cpu.py │ ├── test_pytorch_remote_gpu.py │ ├── test_pytorch_local_gpu.py │ ├── test_pytorch_local_cpu.py │ ├── utils.py │ ├── conftest.py │ ├── config.py │ └── helpers.py ├── unit │ ├── __init__.py │ ├── conftest.py │ ├── test_diffusers.py │ ├── test_const.py │ ├── test_serializer.py │ ├── test_vertex_ai_utils.py │ ├── test_optimum_utils.py │ ├── test_handler.py │ ├── test_sentence_transformers.py │ └── test_utils.py └── resources │ ├── custom_handler │ ├── custom_utils.py │ └── pipeline.py │ ├── audio │ ├── sample.amr │ ├── sample.m4a │ ├── sample1.mp3 │ ├── sample1.ogg │ ├── sample1.wav │ ├── sample1.flac │ ├── sample1.webm │ └── long_sample.mp3 │ └── image │ ├── tiger.bmp │ ├── tiger.gif │ ├── tiger.jpeg │ ├── tiger.png │ ├── tiger.tiff │ └── tiger.webp ├── src └── huggingface_inference_toolkit │ ├── __init__.py │ ├── serialization │ ├── audio_utils.py │ ├── __init__.py │ ├── image_utils.py │ ├── json_utils.py │ └── base.py │ ├── async_utils.py │ ├── const.py │ ├── logging.py │ ├── env_utils.py │ ├── vertex_ai_utils.py │ ├── diffusers_utils.py │ ├── sentence_transformers_utils.py │ ├── optimum_utils.py │ ├── webservice_starlette.py │ ├── handler.py │ └── utils.py ├── MANIFEST.in ├── .dockerignore ├── scripts ├── inf2_entrypoint.sh ├── entrypoint.sh └── inf2_env.py ├── setup.cfg ├── .github └── workflows │ ├── quality.yaml │ ├── build-container.yaml │ ├── unit-test.yaml │ ├── integration-test.yaml │ ├── docker-build-action.yaml │ └── integration-test-action.yaml ├── pyproject.toml ├── Makefile ├── dockerfiles └── pytorch │ ├── Dockerfile │ └── Dockerfile.inf2 ├── setup.py ├── .gitignore ├── LICENSE └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integ/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | 4 | recursive-exclude * __pycache__ 5 | recursive-exclude * *.py[co] 6 | -------------------------------------------------------------------------------- /tests/resources/custom_handler/custom_utils.py: -------------------------------------------------------------------------------- 1 | def test_method(input): 2 | """reverse string""" 3 | return input[::-1] 4 | -------------------------------------------------------------------------------- /tests/resources/audio/sample.amr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/audio/sample.amr -------------------------------------------------------------------------------- /tests/resources/audio/sample.m4a: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/audio/sample.m4a -------------------------------------------------------------------------------- /tests/resources/audio/sample1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/audio/sample1.mp3 -------------------------------------------------------------------------------- /tests/resources/audio/sample1.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/audio/sample1.ogg -------------------------------------------------------------------------------- /tests/resources/audio/sample1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/audio/sample1.wav -------------------------------------------------------------------------------- /tests/resources/image/tiger.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/image/tiger.bmp -------------------------------------------------------------------------------- /tests/resources/image/tiger.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/image/tiger.gif -------------------------------------------------------------------------------- /tests/resources/image/tiger.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/image/tiger.jpeg -------------------------------------------------------------------------------- /tests/resources/image/tiger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/image/tiger.png -------------------------------------------------------------------------------- /tests/resources/image/tiger.tiff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/image/tiger.tiff -------------------------------------------------------------------------------- /tests/resources/image/tiger.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/image/tiger.webp -------------------------------------------------------------------------------- /tests/resources/audio/sample1.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/audio/sample1.flac -------------------------------------------------------------------------------- /tests/resources/audio/sample1.webm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/audio/sample1.webm 
-------------------------------------------------------------------------------- /tests/resources/audio/long_sample.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/huggingface-inference-toolkit/HEAD/tests/resources/audio/long_sample.mp3 -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .github 3 | .pytest_cache 4 | .ruff_cache 5 | .tox 6 | .venv 7 | .gitignore 8 | Makefile 9 | __pycache__ 10 | tests 11 | .vscode 12 | -------------------------------------------------------------------------------- /tests/unit/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture(scope = "session") 7 | def cache_test_dir(): 8 | yield os.environ.get("CACHE_TEST_DIR", "./tests") 9 | -------------------------------------------------------------------------------- /tests/resources/custom_handler/pipeline.py: -------------------------------------------------------------------------------- 1 | from custom_utils import test_method 2 | 3 | 4 | class PreTrainedPipeline: 5 | def __init__(self, path): 6 | self.path = path 7 | 8 | def __call__(self, data): 9 | res = test_method(data) 10 | return res 11 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/serialization/audio_utils.py: -------------------------------------------------------------------------------- 1 | class Audioer: 2 | @staticmethod 3 | def deserialize(body): 4 | return {"inputs": bytes(body)} 5 | 6 | @staticmethod 7 | def serialize(body, accept=None): 8 | raise NotImplementedError("Audio serialization not implemented") 9 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/serialization/__init__.py: -------------------------------------------------------------------------------- 1 | from huggingface_inference_toolkit.serialization.audio_utils import Audioer # noqa: F401 2 | from huggingface_inference_toolkit.serialization.image_utils import Imager # noqa: F401 3 | from huggingface_inference_toolkit.serialization.json_utils import Jsoner # noqa: F401 4 | -------------------------------------------------------------------------------- /scripts/inf2_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e -o pipefail -u 3 | 4 | export ENV_FILEPATH=$(mktemp) 5 | 6 | trap "rm -f ${ENV_FILEPATH}" EXIT 7 | 8 | touch $ENV_FILEPATH 9 | 10 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 11 | 12 | ${SCRIPT_DIR}/inf2_env.py $@ 13 | 14 | source $ENV_FILEPATH 15 | 16 | rm -f $ENV_FILEPATH 17 | 18 | exec ${SCRIPT_DIR}/entrypoint.sh $@ -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = huggingface_inference_toolkit 7 | known_third_party = 8 | transformers 9 | huggingface_hub 10 | datasets 11 | tensorflow 12 | torch 13 | 14 | line_length = 119 15 | lines_after_imports = 2 16 | multi_line_output = 3 17 | use_parentheses = True 18 | 19 
| [flake8] 20 | ignore = E203, E501, E741, W503, W605 21 | max-line-length = 119 22 | -------------------------------------------------------------------------------- /.github/workflows/quality.yaml: -------------------------------------------------------------------------------- 1 | name: Quality Check 2 | 3 | on: 4 | push: 5 | paths-ignore: 6 | - 'README.md' 7 | branches: 8 | - main 9 | pull_request: 10 | workflow_dispatch: 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | quality: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python 3.11 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: 3.11 25 | - name: Install Python dependencies 26 | run: pip install -e .[quality] 27 | - name: Run Quality check 28 | run: make quality -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/serialization/image_utils.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | from PIL import Image 4 | 5 | 6 | class Imager: 7 | @staticmethod 8 | def deserialize(body): 9 | image = Image.open(BytesIO(body)).convert("RGB") 10 | return {"inputs": image} 11 | 12 | @staticmethod 13 | def serialize(image, accept=None): 14 | if isinstance(image, Image.Image): 15 | img_byte_arr = BytesIO() 16 | image.save(img_byte_arr, format=accept.split("/")[-1].upper()) 17 | img_byte_arr = img_byte_arr.getvalue() 18 | return img_byte_arr 19 | else: 20 | raise ValueError(f"Can only serialize PIL.Image.Image, got {type(image)}") 21 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/async_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any, Callable, Dict, TypeVar 3 | 4 | import anyio 5 | from anyio import Semaphore 6 | from typing_extensions import ParamSpec 7 | 8 | # To not have too many threads running (which could happen on too many concurrent 9 | # requests, we limit it with a semaphore. 
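# Illustrative usage of the helper defined below (hypothetical caller shown only as a
# sketch; the actual routes presumably live in webservice_starlette.py, not included here):
#   result = await async_handler_call(handler, {"inputs": "Hello world"})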
10 | MAX_CONCURRENT_THREADS = 1 11 | MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS) 12 | T = TypeVar("T") 13 | P = ParamSpec("P") 14 | 15 | 16 | # moves blocking call to asyncio threadpool limited to 1 to not overload the system 17 | # REF: https://stackoverflow.com/a/70929141 18 | async def async_handler_call(handler: Callable[P, T], body: Dict[str, Any]) -> T: 19 | async with MAX_THREADS_GUARD: 20 | return await anyio.to_thread.run_sync(functools.partial(handler, body)) 21 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/const.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from huggingface_inference_toolkit.env_utils import strtobool 5 | 6 | HF_MODEL_DIR = os.environ.get("HF_MODEL_DIR", "/opt/huggingface/model") 7 | HF_MODEL_ID = os.environ.get("HF_MODEL_ID", None) 8 | HF_TASK = os.environ.get("HF_TASK", None) 9 | HF_FRAMEWORK = os.environ.get("HF_FRAMEWORK", None) 10 | HF_REVISION = os.environ.get("HF_REVISION", None) 11 | HF_HUB_TOKEN = os.environ.get("HF_HUB_TOKEN", None) 12 | HF_TRUST_REMOTE_CODE = strtobool(os.environ.get("HF_TRUST_REMOTE_CODE", "0")) 13 | # custom handler consts 14 | HF_DEFAULT_PIPELINE_NAME = os.environ.get("HF_DEFAULT_PIPELINE_NAME", "handler.py") 15 | # default is pipeline.PreTrainedPipeline 16 | HF_MODULE_NAME = os.environ.get( 17 | "HF_MODULE_NAME", f"{Path(HF_DEFAULT_PIPELINE_NAME).stem}.EndpointHandler" 18 | ) 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.mypy] 2 | ignore_missing_imports = true 3 | no_implicit_optional = true 4 | scripts_are_modules = true 5 | 6 | [tool.ruff] 7 | # Same as Black. 8 | line-length = 119 9 | # Assume Python 3.11 10 | target-version = "py311" 11 | 12 | [tool.ruff.lint] 13 | select = [ 14 | "E", # pycodestyle errors 15 | "W", # pycodestyle warnings 16 | "F", # pyflakes 17 | "I", # isort 18 | "C", # flake8-comprehensions 19 | "B", # flake8-bugbear 20 | ] 21 | ignore = [ 22 | "E501", # Line length (handled by ruff-format) 23 | "B008", # do not perform function calls in argument defaults 24 | "C901", # too complex 25 | ] 26 | # Allow unused variables when underscore-prefixed. 
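# (Illustrative: names such as `_`, `__`, or `_unused` match the pattern below, while a plain `unused` does not.)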
27 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 28 | per-file-ignores = { "__init__.py" = ["F401"] } 29 | 30 | [tool.isort] 31 | profile = "black" 32 | known_third_party = ["transformers", "starlette", "huggingface_hub"] 33 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | 5 | def setup_logging(): 6 | # Remove all existing handlers 7 | for handler in logging.root.handlers[:]: 8 | logging.root.removeHandler(handler) 9 | 10 | # Configure the root logger 11 | logging.basicConfig( 12 | level=logging.INFO, 13 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 14 | datefmt="%Y-%m-%d %H:%M:%S", 15 | stream=sys.stdout, 16 | ) 17 | 18 | # Remove Uvicorn loggers 19 | logging.getLogger("uvicorn").handlers.clear() 20 | logging.getLogger("uvicorn.access").handlers.clear() 21 | logging.getLogger("uvicorn.error").handlers.clear() 22 | 23 | # Create a logger for your application 24 | logger = logging.getLogger("huggingface_inference_toolkit") 25 | return logger 26 | 27 | 28 | # Create and configure the logger 29 | logger = setup_logging() 30 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/env_utils.py: -------------------------------------------------------------------------------- 1 | def strtobool(val: str) -> bool: 2 | """Convert a string representation of truth to True or False booleans. 3 | True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values 4 | are 'n', 'no', 'f', 'false', 'off', and '0'. 5 | 6 | Raises: 7 | ValueError: if 'val' is anything else. 8 | 9 | Note: 10 | Function `strtobool` copied and adapted from `distutils`, as it's deprecated from Python 3.10 onwards. 11 | 12 | References: 13 | - https://github.com/python/cpython/blob/48f9d3e3faec5faaa4f7c9849fecd27eae4da213/Lib/distutils/util.py#L308-L321 14 | """ 15 | val = val.lower() 16 | if val in ("y", "yes", "t", "true", "on", "1"): 17 | return True 18 | if val in ("n", "no", "f", "false", "off", "0"): 19 | return False 20 | raise ValueError( 21 | f"Invalid truth value, it should be a string but {val} was provided instead." 
22 | ) 23 | -------------------------------------------------------------------------------- /tests/integ/test_pytorch_local_inf2.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers.testing_utils import require_torch 3 | 4 | from huggingface_inference_toolkit.optimum_utils import is_optimum_neuron_available 5 | from tests.integ.helpers import verify_task 6 | 7 | require_inferentia = pytest.mark.skipif( 8 | not is_optimum_neuron_available(), 9 | reason="Skipping tests, since optimum neuron is not available or not running on inf2 instances.", 10 | ) 11 | 12 | 13 | class TestPytorchLocal: 14 | @require_torch 15 | @require_inferentia 16 | @pytest.mark.parametrize( 17 | "task", 18 | [ 19 | "feature-extraction", 20 | "fill-mask", 21 | "question-answering", 22 | "text-classification", 23 | "token-classification", 24 | ], 25 | ) 26 | @pytest.mark.parametrize("device", ["inf2"]) 27 | @pytest.mark.parametrize("framework", ["pytorch"]) 28 | @pytest.mark.parametrize("repository_id", [""]) 29 | @pytest.mark.usefixtures("local_container") 30 | def test_pt_container_local_model(self, local_container, task) -> None: 31 | 32 | verify_task(task=task, port=local_container[1]) 33 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style unit-test integ-test inference-pytorch-gpu inference-pytorch-cpu inference-pytorch-inf2 stop-all 2 | 3 | check_dirs := src tests 4 | 5 | # Check that source code meets quality standards 6 | quality: 7 | ruff check $(check_dirs) 8 | 9 | # Format source code automatically 10 | style: 11 | ruff check $(check_dirs) --fix 12 | 13 | # Run unit tests 14 | unit-test: 15 | RUN_SLOW=True python3 -m pytest -s -v tests/unit -n 10 --log-cli-level='ERROR' 16 | 17 | # Run integration tests 18 | integ-test: 19 | python3 -m pytest -s -v tests/integ/ 20 | 21 | # Build Docker image for PyTorch on GPU 22 | inference-pytorch-gpu: 23 | docker build -f dockerfiles/pytorch/Dockerfile -t integration-test-pytorch:gpu . 24 | 25 | # Build Docker image for PyTorch on CPU 26 | inference-pytorch-cpu: 27 | docker build --build-arg="BASE_IMAGE=ubuntu:22.04" -f dockerfiles/pytorch/Dockerfile -t integration-test-pytorch:cpu . 28 | 29 | # Build Docker image for PyTorch on AWS Inferentia2 30 | inference-pytorch-inf2: 31 | docker build -f dockerfiles/pytorch/Dockerfile.inf2 -t integration-test-pytorch:inf2 . 
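# Illustrative only: a built image can be smoke-tested with the defaults from scripts/entrypoint.sh, e.g.
#   docker run -p 5000:5000 -e HF_MODEL_ID=<hub-repo-id> -e HF_TASK=text-classification integration-test-pytorch:cpu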
32 | 33 | # Stop all and prune/clean the Docker Containers 34 | stop-all: 35 | docker stop $$(docker ps -a -q) && docker container prune --force 36 | -------------------------------------------------------------------------------- /.github/workflows/build-container.yaml: -------------------------------------------------------------------------------- 1 | name: "Build applications images" 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "src/**" 9 | - "dockerfiles/**" 10 | - "scripts/**" 11 | workflow_dispatch: 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.ref }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | starlette-pytorch-cpu: 19 | uses: ./.github/workflows/docker-build-action.yaml 20 | with: 21 | image: inference-pytorch-cpu 22 | dockerfile: dockerfiles/pytorch/Dockerfile 23 | build_args: "BASE_IMAGE=ubuntu:22.04" 24 | secrets: 25 | REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }} 26 | REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }} 27 | starlette-pytorch-gpu: 28 | uses: ./.github/workflows/docker-build-action.yaml 29 | with: 30 | image: inference-pytorch-gpu 31 | dockerfile: dockerfiles/pytorch/Dockerfile 32 | secrets: 33 | REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }} 34 | REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }} 35 | starlette-pytorch-inf2: 36 | uses: ./.github/workflows/docker-build-action.yaml 37 | with: 38 | image: inference-pytorch-inf2 39 | dockerfile: dockerfiles/pytorch/Dockerfile.inf2 40 | secrets: 41 | REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }} 42 | REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }} -------------------------------------------------------------------------------- /tests/integ/test_pytorch_remote_cpu.py: -------------------------------------------------------------------------------- 1 | import docker 2 | import pytest 3 | import tenacity 4 | 5 | from tests.integ.helpers import verify_task 6 | 7 | 8 | class TestPytorchRemote: 9 | 10 | @tenacity.retry( 11 | retry=tenacity.retry_if_exception(docker.errors.APIError), 12 | stop=tenacity.stop_after_attempt(5), 13 | reraise=True, 14 | ) 15 | @pytest.mark.parametrize("device", ["cpu"]) 16 | @pytest.mark.parametrize( 17 | "task", 18 | [ 19 | "text-classification", 20 | "zero-shot-classification", 21 | "question-answering", 22 | "fill-mask", 23 | "summarization", 24 | "token-classification", 25 | "translation_xx_to_yy", 26 | "text2text-generation", 27 | "text-generation", 28 | "feature-extraction", 29 | "image-classification", 30 | "automatic-speech-recognition", 31 | "audio-classification", 32 | "object-detection", 33 | "image-segmentation", 34 | "table-question-answering", 35 | "conversational", 36 | "sentence-similarity", 37 | "sentence-embeddings", 38 | "sentence-ranking", 39 | "text-to-image", 40 | ], 41 | ) 42 | @pytest.mark.parametrize("framework", ["pytorch"]) 43 | @pytest.mark.usefixtures("remote_container") 44 | def test_inference_remote(self, remote_container, task, framework, device): 45 | 46 | verify_task(task=task, port=remote_container[1]) 47 | -------------------------------------------------------------------------------- /tests/integ/test_pytorch_remote_gpu.py: -------------------------------------------------------------------------------- 1 | import docker 2 | import pytest 3 | import tenacity 4 | 5 | from tests.integ.helpers import verify_task 6 | 7 | 8 | class TestPytorchRemote: 9 | 10 | @tenacity.retry( 11 | retry=tenacity.retry_if_exception(docker.errors.APIError), 12 | stop=tenacity.stop_after_attempt(5), 13 | 
reraise=True, 14 | ) 15 | @pytest.mark.parametrize("device", ["gpu"]) 16 | @pytest.mark.parametrize( 17 | "task", 18 | [ 19 | "text-classification", 20 | "zero-shot-classification", 21 | "question-answering", 22 | "fill-mask", 23 | "summarization", 24 | "token-classification", 25 | "translation_xx_to_yy", 26 | "text2text-generation", 27 | "text-generation", 28 | "feature-extraction", 29 | "image-classification", 30 | "automatic-speech-recognition", 31 | "audio-classification", 32 | "object-detection", 33 | "image-segmentation", 34 | "table-question-answering", 35 | "conversational", 36 | "sentence-similarity", 37 | "sentence-embeddings", 38 | "sentence-ranking", 39 | "text-to-image", 40 | ], 41 | ) 42 | @pytest.mark.parametrize("framework", ["pytorch"]) 43 | @pytest.mark.usefixtures("remote_container") 44 | def test_inference_remote(self, remote_container, task, framework, device): 45 | 46 | verify_task(task=task, port=remote_container[1]) 47 | -------------------------------------------------------------------------------- /.github/workflows/unit-test.yaml: -------------------------------------------------------------------------------- 1 | name: Run Unit-Tests 2 | 3 | on: 4 | push: 5 | paths-ignore: 6 | - 'README.md' 7 | branches: 8 | - main 9 | pull_request: 10 | workflow_dispatch: 11 | 12 | env: 13 | ACTIONS_RUNNER_DEBUG: true 14 | ACTIONS_STEP_DEBUG: true 15 | 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 18 | cancel-in-progress: true 19 | 20 | jobs: 21 | pytorch-unit-test: 22 | runs-on: 23 | group: aws-g4dn-2xlarge-cache 24 | env: 25 | AWS_REGION: us-east-1 26 | CACHE_TEST_DIR: /mnt/hf_cache/hf-inference-toolkit-tests 27 | RUN_SLOW: True 28 | steps: 29 | - uses: actions/checkout@v4.1.1 30 | - name: Copy unit tests to cache mount 31 | run: | 32 | sudo rm -rf ${{ env.CACHE_TEST_DIR }} && \ 33 | sudo mkdir ${{ env.CACHE_TEST_DIR }} && \ 34 | sudo chown -R runner ${{ env.CACHE_TEST_DIR }} 35 | cp -r tests ${{ env.CACHE_TEST_DIR }} 36 | - name: Docker Setup Buildx 37 | uses: docker/setup-buildx-action@v3.0.0 38 | - name: Docker Build 39 | run: make inference-pytorch-gpu 40 | - name: Run unit tests 41 | run: | 42 | docker run \ 43 | -e RUN_SLOW='${{ env.RUN_SLOW }}' \ 44 | --gpus all \ 45 | -e CACHE_TEST_DIR='${{ env.CACHE_TEST_DIR }}' \ 46 | -v ./tests:${{ env.CACHE_TEST_DIR }} \ 47 | --entrypoint /bin/bash \ 48 | integration-test-pytorch:gpu \ 49 | -c "pip install '.[test,st,diffusers,google]' && pytest ${{ env.CACHE_TEST_DIR }}/unit" 50 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/serialization/json_utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from io import BytesIO 3 | 4 | import orjson 5 | from PIL import Image 6 | 7 | 8 | def default(obj): 9 | if isinstance(obj, Image.Image): 10 | with BytesIO() as out: 11 | obj.save(out, format="PNG") 12 | png_string = out.getvalue() 13 | return base64.b64encode(png_string).decode("utf-8") 14 | raise TypeError 15 | 16 | 17 | class Jsoner: 18 | @staticmethod 19 | def deserialize(body): 20 | return orjson.loads(body) 21 | 22 | @staticmethod 23 | def serialize(body, accept=None): 24 | return orjson.dumps(body, option=orjson.OPT_SERIALIZE_NUMPY, default=default) 25 | 26 | 27 | # class _JSONEncoder(json.JSONEncoder): 28 | # """ 29 | # custom `JSONEncoder` to make sure float and int64 ar converted 30 | # """ 31 | 32 | # def default(self, obj): 33 | # if isinstance(obj, 
np.integer): 34 | # return int(obj) 35 | # elif isinstance(obj, np.floating): 36 | # return float(obj) 37 | # elif isinstance(obj, np.ndarray): 38 | # return obj.tolist() 39 | # elif isinstance(obj, datetime.datetime): 40 | # return obj.__str__() 41 | # elif isinstance(obj, Image.Image): 42 | # with BytesIO() as out: 43 | # obj.save(out, format="PNG") 44 | # png_string = out.getvalue() 45 | # return base64.b64encode(png_string).decode("utf-8") 46 | # else: 47 | # return super(_JSONEncoder, self).default(obj) 48 | -------------------------------------------------------------------------------- /.github/workflows/integration-test.yaml: -------------------------------------------------------------------------------- 1 | name: Run Integration Tests 2 | 3 | on: 4 | push: 5 | paths-ignore: 6 | - 'README.md' 7 | - '.github/workflows/unit-test.yaml' 8 | - '.github/workflows/quality.yaml' 9 | branches: 10 | - main 11 | pull_request: 12 | workflow_dispatch: 13 | 14 | concurrency: 15 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 16 | cancel-in-progress: true 17 | 18 | jobs: 19 | pytorch-integration-local-gpu: 20 | name: Local Integration Tests - GPU 21 | uses: ./.github/workflows/integration-test-action.yaml 22 | with: 23 | test_path: "tests/integ/test_pytorch_local_gpu.py" 24 | build_img_cmd: "make inference-pytorch-gpu" 25 | test_parallelism: "1" 26 | pytorch-integration-remote-gpu: 27 | name: Remote Integration Tests - GPU 28 | uses: ./.github/workflows/integration-test-action.yaml 29 | with: 30 | test_path: "tests/integ/test_pytorch_remote_gpu.py" 31 | build_img_cmd: "make inference-pytorch-gpu" 32 | pytorch-integration-remote-cpu: 33 | name: Remote Integration Tests - CPU 34 | uses: ./.github/workflows/integration-test-action.yaml 35 | with: 36 | test_path: "tests/integ/test_pytorch_remote_cpu.py" 37 | build_img_cmd: "make inference-pytorch-cpu" 38 | pytorch-integration-local-cpu: 39 | name: Local Integration Tests - CPU 40 | uses: ./.github/workflows/integration-test-action.yaml 41 | with: 42 | test_path: "tests/integ/test_pytorch_local_cpu.py" 43 | build_img_cmd: "make inference-pytorch-cpu" 44 | test_parallelism: "1" 45 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/vertex_ai_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | from huggingface_inference_toolkit.logging import logger 6 | 7 | GCS_URI_PREFIX = "gs://" 8 | 9 | 10 | # copied from https://github.com/googleapis/python-aiplatform/blob/94d838d8cfe1599bc2d706e66080c05108821986/google/cloud/aiplatform/utils/prediction_utils.py#L121 11 | def _load_repository_from_gcs( 12 | artifact_uri: str, target_dir: Union[str, Path] = "/tmp" 13 | ) -> str: 14 | """ 15 | Load files from GCS path to target_dir 16 | """ 17 | from google.cloud import storage 18 | 19 | logger.info(f"Loading model artifacts from {artifact_uri} to {target_dir}") 20 | if isinstance(target_dir, str): 21 | target_dir = Path(target_dir) 22 | 23 | if artifact_uri.startswith(GCS_URI_PREFIX): 24 | matches = re.match(f"{GCS_URI_PREFIX}(.*?)/(.*)", artifact_uri) 25 | bucket_name, prefix = matches.groups() 26 | 27 | gcs_client = storage.Client() 28 | blobs = gcs_client.list_blobs(bucket_name, prefix=prefix) 29 | for blob in blobs: 30 | name_without_prefix = blob.name[len(prefix) :] 31 | name_without_prefix = ( 32 | name_without_prefix[1:] 33 | if 
name_without_prefix.startswith("/") 34 | else name_without_prefix 35 | ) 36 | file_split = name_without_prefix.split("/") 37 | directory = target_dir / Path(*file_split[0:-1]) 38 | directory.mkdir(parents=True, exist_ok=True) 39 | if name_without_prefix and not name_without_prefix.endswith("/"): 40 | blob.download_to_filename(target_dir / name_without_prefix) 41 | 42 | return str(target_dir.absolute()) 43 | -------------------------------------------------------------------------------- /dockerfiles/pytorch/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE=nvidia/cuda:12.1.0-devel-ubuntu22.04 2 | 3 | FROM $BASE_IMAGE AS base 4 | SHELL ["/bin/bash", "-c"] 5 | 6 | LABEL maintainer="Hugging Face" 7 | 8 | ENV DEBIAN_FRONTEND=noninteractive 9 | 10 | WORKDIR /app 11 | 12 | RUN apt-get update && \ 13 | apt-get install software-properties-common -y && \ 14 | add-apt-repository ppa:deadsnakes/ppa && \ 15 | apt-get -y upgrade --only-upgrade systemd openssl cryptsetup && \ 16 | apt-get install -y \ 17 | build-essential \ 18 | bzip2 \ 19 | cmake \ 20 | curl \ 21 | git \ 22 | git-lfs \ 23 | tar \ 24 | gcc \ 25 | g++ \ 26 | libprotobuf-dev \ 27 | protobuf-compiler \ 28 | python3.11 \ 29 | python3.11-dev \ 30 | libsndfile1-dev \ 31 | ffmpeg \ 32 | && apt-get clean autoremove --yes \ 33 | && rm -rf /var/lib/{apt,dpkg,cache,log} 34 | 35 | # Copying only necessary files as filtered by .dockerignore 36 | COPY . . 37 | 38 | # Set Python 3.11 as the default python version 39 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 && \ 40 | ln -sf /usr/bin/python3.11 /usr/bin/python 41 | 42 | # Install pip from source 43 | RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ 44 | python get-pip.py && \ 45 | rm get-pip.py 46 | 47 | # Upgrade pip 48 | RUN pip install --no-cache-dir --upgrade pip 49 | 50 | # Install wheel and setuptools 51 | RUN pip install --no-cache-dir --upgrade pip ".[torch,st,diffusers]" 52 | 53 | # Copy application 54 | COPY src/huggingface_inference_toolkit huggingface_inference_toolkit 55 | COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starlette.py 56 | 57 | # Copy entrypoint and change permissions 58 | COPY --chmod=0755 scripts/entrypoint.sh entrypoint.sh 59 | 60 | ENTRYPOINT ["bash", "-c", "./entrypoint.sh"] 61 | -------------------------------------------------------------------------------- /tests/unit/test_diffusers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tempfile 3 | 4 | from PIL import Image 5 | from transformers.testing_utils import require_torch, slow 6 | 7 | from huggingface_inference_toolkit.diffusers_utils import IEAutoPipelineForText2Image 8 | from huggingface_inference_toolkit.utils import _load_repository_from_hf, get_pipeline 9 | 10 | logging.basicConfig(level="DEBUG") 11 | 12 | @require_torch 13 | def test_get_diffusers_pipeline(): 14 | with tempfile.TemporaryDirectory() as tmpdirname: 15 | storage_dir = _load_repository_from_hf( 16 | "echarlaix/tiny-random-stable-diffusion-xl", 17 | tmpdirname, 18 | framework="pytorch" 19 | ) 20 | pipe = get_pipeline("text-to-image", storage_dir.as_posix()) 21 | assert isinstance(pipe, IEAutoPipelineForText2Image) 22 | 23 | 24 | @slow 25 | @require_torch 26 | def test_pipe_on_gpu(): 27 | with tempfile.TemporaryDirectory() as tmpdirname: 28 | storage_dir = _load_repository_from_hf( 29 | "echarlaix/tiny-random-stable-diffusion-xl", 30 | tmpdirname, 31 
| framework="pytorch" 32 | ) 33 | pipe = get_pipeline( 34 | "text-to-image", 35 | storage_dir.as_posix() 36 | ) 37 | logging.error(f"Pipe: {pipe.pipeline}") 38 | assert pipe.pipeline.device.type == "cuda" 39 | 40 | 41 | @require_torch 42 | def test_text_to_image_task(): 43 | with tempfile.TemporaryDirectory() as tmpdirname: 44 | storage_dir = _load_repository_from_hf( 45 | "echarlaix/tiny-random-stable-diffusion-xl", 46 | tmpdirname, 47 | framework="pytorch" 48 | ) 49 | pipe = get_pipeline("text-to-image", storage_dir.as_posix()) 50 | res = pipe("Lets create an embedding") 51 | assert isinstance(res, Image.Image) 52 | -------------------------------------------------------------------------------- /.github/workflows/docker-build-action.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | workflow_call: 3 | inputs: 4 | context: 5 | type: string 6 | required: false 7 | default: "./" 8 | repository: 9 | type: string 10 | required: false 11 | default: "registry.internal.huggingface.tech/hf-endpoints" 12 | image: 13 | type: string 14 | required: true 15 | build_args: 16 | type: string 17 | required: false 18 | default: "" 19 | dockerfile: 20 | type: string 21 | required: false 22 | default: "Dockerfile" 23 | secrets: 24 | REGISTRY_USERNAME: 25 | required: true 26 | REGISTRY_PASSWORD: 27 | required: true 28 | 29 | jobs: 30 | buildx: 31 | runs-on: 32 | group: aws-highmemory-32-plus-priv 33 | steps: 34 | - name: Check out 35 | uses: actions/checkout@v3 36 | 37 | - name: Set up Docker Buildx 38 | uses: docker/setup-buildx-action@v2.0.0 39 | with: 40 | install: true 41 | 42 | - name: Login to container registry 43 | uses: docker/login-action@v2.0.0 44 | with: 45 | registry: ${{ inputs.repository }} 46 | username: ${{ secrets.REGISTRY_USERNAME }} 47 | password: ${{ secrets.REGISTRY_PASSWORD }} 48 | 49 | - name: Inject slug/short variables 50 | uses: rlespinasse/github-slug-action@v4 51 | 52 | - name: Build and push image to container registry 53 | uses: docker/build-push-action@v3.0.0 54 | with: 55 | push: true 56 | context: ${{ inputs.context }} 57 | build-args: ${{ inputs.build_args }} 58 | target: base 59 | outputs: type=image,compression=zstd,force-compression=true,push=true 60 | file: ${{ inputs.context }}/${{ inputs.dockerfile }} 61 | tags: ${{ inputs.repository }}/${{ inputs.image }}:sha-${{ env.GITHUB_SHA_SHORT }},${{ inputs.repository }}/${{ inputs.image }}:latest 62 | -------------------------------------------------------------------------------- /tests/unit/test_const.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def test_if_provided(): 4 | pass 5 | # with mock.patch.dict( 6 | # os.environ, 7 | # { 8 | # "HF_MODEL_DIR": "provided", 9 | # "HF_MODEL_ID": "mymodel/test", 10 | # "HF_TASK": "text-classification", 11 | # "HF_DEFAULT_PIPELINE_NAME": "endpoint.py", 12 | # "HF_FRAMEWORK": "tf", 13 | # "HF_REVISION": "12312", 14 | # "HF_HUB_TOKEN": "hf_x", 15 | # }, 16 | # ): 17 | # from huggingface_inference_toolkit.const import ( 18 | # HF_DEFAULT_PIPELINE_NAME, 19 | # HF_FRAMEWORK, 20 | # HF_HUB_TOKEN, 21 | # HF_MODEL_DIR, 22 | # HF_MODEL_ID, 23 | # HF_MODULE_NAME, 24 | # HF_REVISION, 25 | # HF_TASK, 26 | # ) 27 | 28 | # assert HF_MODEL_DIR == "provided" 29 | # assert HF_MODEL_ID == "mymodel/test" 30 | # assert HF_TASK == "text-classification" 31 | # assert HF_DEFAULT_PIPELINE_NAME == "endpoint.py" 32 | # assert HF_MODULE_NAME == "endpoint.PreTrainedPipeline" 33 | # assert HF_FRAMEWORK == "tf" 34 | # 
assert HF_REVISION == "12312" 35 | # assert HF_HUB_TOKEN == "hf_x" 36 | 37 | 38 | # def test_default(): 39 | # os.environ = {} 40 | # from huggingface_inference_toolkit.const import ( 41 | # HF_DEFAULT_PIPELINE_NAME, 42 | # HF_FRAMEWORK, 43 | # HF_HUB_TOKEN, 44 | # HF_MODEL_DIR, 45 | # HF_MODEL_ID, 46 | # HF_MODULE_NAME, 47 | # HF_REVISION, 48 | # HF_TASK, 49 | # ) 50 | 51 | # assert os.environ == {} 52 | # assert HF_MODEL_DIR == "/opt/huggingface/model" 53 | # assert HF_MODEL_ID is None 54 | # assert HF_TASK is None 55 | # assert HF_DEFAULT_PIPELINE_NAME == "pipeline.py" 56 | # assert HF_MODULE_NAME == "pipeline.PreTrainedPipeline" 57 | # assert HF_FRAMEWORK == "pytorch" 58 | # assert HF_REVISION is None 59 | # assert HF_HUB_TOKEN is None 60 | -------------------------------------------------------------------------------- /tests/unit/test_serializer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | from PIL import Image 6 | 7 | from huggingface_inference_toolkit.serialization import Audioer, Imager, Jsoner 8 | 9 | 10 | def test_json_serialization(): 11 | t = {"res": np.array([2.0]), "text": "I like you.", "float": 1.2} 12 | assert b'{"res":[2.0],"text":"I like you.","float":1.2}' == Jsoner.serialize(t) 13 | 14 | 15 | def test_json_image_serialization(): 16 | t = [ 17 | {"label": "refrigerator", "mask": Image.new("RGB", (60, 30), color="red"), "score": 0.9803440570831299}, 18 | {"label": "LABEL_200", "mask": Image.new("RGB", (60, 30), color="red"), "score": 0.9631735682487488}, 19 | {"label": "cat", "mask": Image.new("RGB", (60, 30), color="red"), "score": 0.9995332956314087}, 20 | ] 21 | Jsoner.serialize(t) 22 | 23 | 24 | def test_image_serialization(): 25 | image = Image.new("RGB", (60, 30), color="red") 26 | Imager.serialize(image, accept="image/png") 27 | 28 | 29 | def test_json_deserialization(): 30 | raw_content = b'{\n\t"inputs": "i like you"\n}' 31 | assert {"inputs": "i like you"} == Jsoner.deserialize(raw_content) 32 | 33 | @pytest.mark.usefixtures('cache_test_dir') 34 | def test_image_deserialization(cache_test_dir): 35 | 36 | image_files_path = f"{cache_test_dir}/resources/image" 37 | 38 | for image_file in os.listdir(image_files_path): 39 | image_bytes = open(os.path.join(image_files_path, image_file), "rb").read() 40 | decoded_data = Imager.deserialize(bytearray(image_bytes)) 41 | 42 | assert isinstance(decoded_data, dict) 43 | assert isinstance(decoded_data["inputs"], Image.Image) 44 | 45 | @pytest.mark.usefixtures('cache_test_dir') 46 | def test_audio_deserialization(cache_test_dir): 47 | 48 | audio_files_path = f"{cache_test_dir}/resources/audio" 49 | 50 | for audio_file in os.listdir(audio_files_path): 51 | audio_bytes = open(os.path.join(audio_files_path, audio_file), "rb").read() 52 | decoded_data = Audioer.deserialize(bytearray(audio_bytes)) 53 | 54 | assert {"inputs": audio_bytes} == decoded_data 55 | -------------------------------------------------------------------------------- /.github/workflows/integration-test-action.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | workflow_call: 3 | inputs: 4 | region: 5 | type: string 6 | required: false 7 | default: "us-east-1" 8 | hf_home: 9 | required: false 10 | type: string 11 | default: "/mnt/hf_cache/" 12 | hf_hub_cache: 13 | required: false 14 | type: string 15 | default: "/mnt/hf_cache/hub" 16 | run_slow: 17 | required: false 18 | type: string 19 | default: "True" 20 | 
test_path: 21 | type: string 22 | required: true 23 | test_parallelism: 24 | type: string 25 | required: false 26 | default: "4" 27 | build_img_cmd: 28 | type: string 29 | required: false 30 | default: "make inference-pytorch-gpu" 31 | log_level: 32 | type: string 33 | required: false 34 | default: "ERROR" 35 | log_format: 36 | type: string 37 | required: false 38 | default: "%(asctime)s %(levelname)s %(module)s:%(lineno)d %(message)s" 39 | runs_on: 40 | type: string 41 | required: false 42 | default: 'aws-g4dn-2xlarge-cache' 43 | 44 | jobs: 45 | pytorch-integration-tests: 46 | runs-on: 47 | group: ${{ inputs.runs_on }} 48 | env: 49 | AWS_REGION: ${{ inputs.region }} 50 | HF_HOME: ${{ inputs.hf_home }} 51 | HF_HUB_CACHE: ${{ inputs.hf_hub_cache }} 52 | RUN_SLOW: ${{ inputs.run_slow }} 53 | steps: 54 | - uses: actions/checkout@v4.1.1 55 | - name: Docker Setup Buildx 56 | uses: docker/setup-buildx-action@v3.0.0 57 | - name: Docker Build 58 | run: ${{ inputs.build_img_cmd }} 59 | - name: Set up Python 3.11 60 | uses: actions/setup-python@v2 61 | with: 62 | python-version: 3.11 63 | - name: Install dependencies 64 | run: pip install ".[torch, test]" 65 | - name: Run local integration tests 66 | run: | 67 | python -m pytest \ 68 | ${{ inputs.test_path }} -n ${{ inputs.test_parallelism }} \ 69 | --log-cli-level='${{ inputs.log_level }}' \ 70 | --log-format='${{ inputs.log_format }}' 71 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/serialization/base.py: -------------------------------------------------------------------------------- 1 | from huggingface_inference_toolkit.serialization.audio_utils import Audioer 2 | from huggingface_inference_toolkit.serialization.image_utils import Imager 3 | from huggingface_inference_toolkit.serialization.json_utils import Jsoner 4 | 5 | content_type_mapping = { 6 | "application/json": Jsoner, 7 | "application/json; charset=UTF-8": Jsoner, 8 | "text/csv": None, 9 | "text/plain": None, 10 | # image types 11 | "image/png": Imager, 12 | "image/jpeg": Imager, 13 | "image/jpg": Imager, 14 | "image/tiff": Imager, 15 | "image/bmp": Imager, 16 | "image/gif": Imager, 17 | "image/webp": Imager, 18 | "image/x-image": Imager, 19 | # audio types 20 | "audio/x-flac": Audioer, 21 | "audio/flac": Audioer, 22 | "audio/mpeg": Audioer, 23 | "audio/x-mpeg-3": Audioer, 24 | "audio/wave": Audioer, 25 | "audio/wav": Audioer, 26 | "audio/x-wav": Audioer, 27 | "audio/ogg": Audioer, 28 | "audio/x-audio": Audioer, 29 | "audio/webm": Audioer, 30 | "audio/webm;codecs=opus": Audioer, 31 | "audio/AMR": Audioer, 32 | "audio/amr": Audioer, 33 | "audio/AMR-WB": Audioer, 34 | "audio/AMR-WB+": Audioer, 35 | "audio/m4a": Audioer, 36 | "audio/x-m4a": Audioer, 37 | } 38 | 39 | 40 | class ContentType: 41 | @staticmethod 42 | def get_deserializer(content_type): 43 | if content_type in content_type_mapping: 44 | return content_type_mapping[content_type] 45 | else: 46 | message = f""" 47 | Content type "{content_type}" not supported. 48 | Supported content types are: 49 | {", ".join(list(content_type_mapping.keys()))} 50 | """ 51 | raise Exception(message) 52 | 53 | @staticmethod 54 | def get_serializer(accept): 55 | if accept in content_type_mapping: 56 | return content_type_mapping[accept] 57 | else: 58 | message = f""" 59 | Accept type "{accept}" not supported. 
60 | Supported accept types are: 61 | {", ".join(list(content_type_mapping.keys()))} 62 | """ 63 | raise Exception(message) 64 | -------------------------------------------------------------------------------- /tests/unit/test_vertex_ai_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | def test__load_repository_from_gcs(): 5 | """Tests the `_load_repository_from_gcs` function against a public artifact URI. But the 6 | function is overridden since the client needs to be anonymous temporarily, as we're testing 7 | against a publicly accessible artifact. 8 | 9 | References: 10 | - https://cloud.google.com/storage/docs/public-datasets/era5 11 | - https://console.cloud.google.com/storage/browser/gcp-public-data-arco-era5/raw/date-variable-static/2021/12/31/soil_type?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22)) 12 | """ 13 | 14 | public_artifact_uri = ( 15 | "gs://gcp-public-data-arco-era5/raw/date-variable-static/2021/12/31/soil_type" 16 | ) 17 | 18 | def _load_repository_from_gcs(artifact_uri: str, target_dir: Path) -> str: 19 | """Temporarily override of the `_load_repository_from_gcs` function.""" 20 | import re 21 | 22 | from google.cloud import storage 23 | 24 | from huggingface_inference_toolkit.vertex_ai_utils import GCS_URI_PREFIX 25 | 26 | if isinstance(target_dir, str): 27 | target_dir = Path(target_dir) 28 | 29 | if artifact_uri.startswith(GCS_URI_PREFIX): 30 | matches = re.match(f"{GCS_URI_PREFIX}(.*?)/(.*)", artifact_uri) 31 | bucket_name, prefix = matches.groups() # type: ignore 32 | 33 | gcs_client = storage.Client.create_anonymous_client() 34 | blobs = gcs_client.list_blobs(bucket_name, prefix=prefix) 35 | for blob in blobs: 36 | name_without_prefix = blob.name[len(prefix) :] 37 | name_without_prefix = ( 38 | name_without_prefix[1:] 39 | if name_without_prefix.startswith("/") 40 | else name_without_prefix 41 | ) 42 | file_split = name_without_prefix.split("/") 43 | directory = target_dir / Path(*file_split[0:-1]) 44 | directory.mkdir(parents=True, exist_ok=True) 45 | if name_without_prefix and not name_without_prefix.endswith("/"): 46 | blob.download_to_filename(target_dir / name_without_prefix) 47 | 48 | return str(target_dir.absolute()) 49 | 50 | target_dir = Path.cwd() / "target" 51 | target_dir_path = _load_repository_from_gcs( 52 | artifact_uri=public_artifact_uri, target_dir=target_dir 53 | ) 54 | 55 | assert target_dir == Path(target_dir_path) 56 | assert Path(target_dir_path).exists() 57 | assert (Path(target_dir_path) / "static.nc").exists() 58 | -------------------------------------------------------------------------------- /tests/integ/test_pytorch_local_gpu.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers.testing_utils import require_torch 3 | 4 | from tests.integ.helpers import verify_task 5 | 6 | 7 | class TestPytorchLocal: 8 | 9 | @require_torch 10 | @pytest.mark.parametrize( 11 | "task", 12 | [ 13 | "text-classification", 14 | "zero-shot-classification", 15 | "token-classification", 16 | "question-answering", 17 | "fill-mask", 18 | "summarization", 19 | "translation_xx_to_yy", 20 | "text2text-generation", 21 | "text-generation", 22 | "feature-extraction", 23 | "image-classification", 24 | "automatic-speech-recognition", 25 | "audio-classification", 26 | "object-detection", 27 | "image-segmentation", 28 | "table-question-answering", 29 | "conversational", 30 | "sentence-similarity", 31 
| "sentence-embeddings", 32 | "sentence-ranking", 33 | "text-to-image", 34 | ], 35 | ) 36 | @pytest.mark.parametrize("device", ["gpu"]) 37 | @pytest.mark.parametrize("framework", ["pytorch"]) 38 | @pytest.mark.parametrize("repository_id", [""]) 39 | @pytest.mark.usefixtures("local_container") 40 | def test_pt_container_local_model( 41 | self, local_container, task, framework, device 42 | ) -> None: 43 | 44 | verify_task(task=task, port=local_container[1]) 45 | 46 | @require_torch 47 | @pytest.mark.parametrize( 48 | "repository_id", 49 | ["philschmid/custom-handler-test", "philschmid/custom-handler-distilbert"], 50 | ) 51 | @pytest.mark.parametrize("device", ["gpu"]) 52 | @pytest.mark.parametrize("framework", ["pytorch"]) 53 | @pytest.mark.parametrize("task", ["custom"]) 54 | @pytest.mark.usefixtures("local_container") 55 | def test_pt_container_custom_handler( 56 | self, local_container, task, device, repository_id 57 | ) -> None: 58 | 59 | verify_task( 60 | task=task, 61 | port=local_container[1], 62 | ) 63 | 64 | @require_torch 65 | @pytest.mark.parametrize( 66 | "repository_id", 67 | ["philschmid/custom-pipeline-text-classification"], 68 | ) 69 | @pytest.mark.parametrize("device", ["gpu"]) 70 | @pytest.mark.parametrize("framework", ["pytorch"]) 71 | @pytest.mark.parametrize("task", ["custom"]) 72 | @pytest.mark.usefixtures("local_container") 73 | def test_pt_container_legacy_custom_pipeline( 74 | self, local_container, repository_id, device, task 75 | ) -> None: 76 | 77 | verify_task(task=task, port=local_container[1]) 78 | -------------------------------------------------------------------------------- /scripts/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set the default port 4 | PORT=5000 5 | 6 | # Check if AIP_MODE is set and adjust the port for Vertex AI 7 | if [[ ! -z "${AIP_MODE}" ]]; then 8 | PORT=${AIP_HTTP_PORT} 9 | fi 10 | 11 | # Check that only one of HF_MODEL_ID or HF_MODEL_DIR is provided 12 | if [[ ! -z "${HF_MODEL_ID}" && ! -z "${HF_MODEL_DIR}" ]]; then 13 | echo "Error: Both HF_MODEL_ID and HF_MODEL_DIR are set. Please provide only one." 14 | exit 1 15 | elif [[ -z "${HF_MODEL_ID}" && -z "${HF_MODEL_DIR}" ]]; then 16 | echo "Error: Neither HF_MODEL_ID nor HF_MODEL_DIR is set. Please provide one of them." 17 | exit 1 18 | fi 19 | 20 | # If HF_MODEL_ID is provided, download handler.py and requirements.txt if available 21 | if [[ ! -z "${HF_MODEL_ID}" ]]; then 22 | filename=${HF_DEFAULT_PIPELINE_NAME:-handler.py} 23 | revision=${HF_REVISION:-main} 24 | 25 | echo "Downloading $filename for model ${HF_MODEL_ID}" 26 | huggingface-cli download ${HF_MODEL_ID} "$filename" --revision "$revision" --local-dir /tmp 27 | 28 | # Check if handler.py was downloaded successfully 29 | if [ -f "/tmp/$filename" ]; then 30 | echo "$filename downloaded successfully, checking if there's a requirements.txt file..." 31 | rm /tmp/$filename 32 | 33 | # Attempt to download requirements.txt 34 | echo "Downloading requirements.txt for model ${HF_MODEL_ID}" 35 | huggingface-cli download "${HF_MODEL_ID}" requirements.txt --revision "$revision" --local-dir /tmp 36 | 37 | # Check if requirements.txt was downloaded successfully 38 | if [ -f "/tmp/requirements.txt" ]; then 39 | echo "requirements.txt downloaded successfully, now installing the dependencies..." 
40 | 41 | # Install dependencies 42 | pip install -r /tmp/requirements.txt --no-cache-dir 43 | rm /tmp/requirements.txt 44 | else 45 | echo "${HF_MODEL_ID} with revision $revision contains a custom handler at $filename but doesn't contain a requirements.txt file, so skipping downloading and installing extra requirements from it." 46 | fi 47 | else 48 | echo "${HF_MODEL_ID} with revision $revision doesn't contain a $filename file, so skipping download." 49 | fi 50 | fi 51 | 52 | # If HF_MODEL_DIR is provided, check for requirements.txt and install dependencies if available 53 | if [[ ! -z "${HF_MODEL_DIR}" ]]; then 54 | # Check if requirements.txt exists and if so install dependencies 55 | if [ -f "${HF_MODEL_DIR}/requirements.txt" ]; then 56 | echo "Installing custom dependencies from ${HF_MODEL_DIR}/requirements.txt" 57 | pip install -r ${HF_MODEL_DIR}/requirements.txt --no-cache-dir 58 | fi 59 | fi 60 | 61 | # Start the server 62 | exec uvicorn webservice_starlette:app --host 0.0.0.0 --port ${PORT} 63 | -------------------------------------------------------------------------------- /dockerfiles/pytorch/Dockerfile.inf2: -------------------------------------------------------------------------------- 1 | # Build based on https://github.com/aws/deep-learning-containers/blob/master/huggingface/pytorch/inference/docker/2.1/py3/sdk2.18.0/Dockerfile.neuronx 2 | FROM ubuntu:22.04 AS base 3 | 4 | LABEL maintainer="Hugging Face" 5 | 6 | ARG NEURONX_COLLECTIVES_LIB_VERSION=2.22.33.0-d2128d1aa 7 | ARG NEURONX_RUNTIME_LIB_VERSION=2.22.19.0-5856c0b42 8 | ARG NEURONX_TOOLS_VERSION=2.19.0.0 9 | 10 | # HF ARGS 11 | ARG OPTIMUM_NEURON_VERSION=0.0.28 12 | 13 | # See http://bugs.python.org/issue19846 14 | ENV LANG=C.UTF-8 15 | ENV LD_LIBRARY_PATH=/opt/aws/neuron/lib:/lib/x86_64-linux-gnu:/opt/conda/lib/:$LD_LIBRARY_PATH 16 | ENV PATH=/opt/aws/neuron/bin:$PATH 17 | 18 | RUN apt-get update \ 19 | && apt-get upgrade -y \ 20 | && apt-get install -y --no-install-recommends software-properties-common \ 21 | && apt-get update \ 22 | && apt-get install -y --no-install-recommends \ 23 | build-essential \ 24 | apt-transport-https \ 25 | ca-certificates \ 26 | cmake \ 27 | curl \ 28 | git \ 29 | jq \ 30 | wget \ 31 | unzip \ 32 | zlib1g-dev \ 33 | gpg-agent 34 | 35 | RUN echo "deb https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list 36 | RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add - 37 | 38 | # Install Neuronx tools 39 | RUN apt-get update \ 40 | && apt-get install -y \ 41 | aws-neuronx-tools=$NEURONX_TOOLS_VERSION \ 42 | aws-neuronx-collectives=$NEURONX_COLLECTIVES_LIB_VERSION \ 43 | aws-neuronx-runtime-lib=$NEURONX_RUNTIME_LIB_VERSION 44 | 45 | RUN apt-get install -y \ 46 | python3 \ 47 | python3-pip \ 48 | python-is-python3 49 | 50 | RUN rm -rf /var/lib/apt/lists/* \ 51 | && rm -rf /tmp/tmp* \ 52 | && apt-get clean 53 | 54 | RUN pip install --no-cache-dir "protobuf>=3.18.3,<4" setuptools==69.5.1 packaging 55 | 56 | WORKDIR / 57 | 58 | # install Hugging Face libraries and its dependencies 59 | RUN pip install --extra-index-url https://pip.repos.neuron.amazonaws.com --no-cache-dir optimum-neuron[neuronx]==${OPTIMUM_NEURON_VERSION} \ 60 | && pip install --no-deps --no-cache-dir -U torchvision==0.16.* 61 | 62 | # FIXME 63 | RUN pip install --extra-index-url https://pip.repos.neuron.amazonaws.com git+https://github.com/huggingface/optimum-neuron.git@5237fb0ada643ba471f60ed3a5d2eef3b66e8e59 64 | 65 | COPY . . 
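# Install the toolkit itself; the ".[st]" extra below presumably maps to the sentence-transformers dependencies declared in setup.py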
66 | 67 | RUN pip install --no-cache-dir -U pip ".[st]" 68 | 69 | # copy application 70 | COPY src/huggingface_inference_toolkit huggingface_inference_toolkit 71 | COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starlette.py 72 | 73 | # copy entrypoint and change permissions 74 | COPY --chmod=0755 scripts/entrypoint.sh entrypoint.sh 75 | COPY --chmod=0755 scripts/inf2_env.py inf2_env.py 76 | COPY --chmod=0755 scripts/inf2_entrypoint.sh inf2_entrypoint.sh 77 | 78 | ENTRYPOINT ["bash", "-c", "./inf2_entrypoint.sh"] 79 | -------------------------------------------------------------------------------- /tests/integ/test_pytorch_local_cpu.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tenacity 3 | from transformers.testing_utils import require_torch 4 | 5 | from tests.integ.helpers import verify_task 6 | 7 | 8 | class TestPytorchLocal: 9 | @tenacity.retry( 10 | stop=tenacity.stop_after_attempt(5), 11 | reraise=True, 12 | ) 13 | @require_torch 14 | @pytest.mark.parametrize( 15 | "task", 16 | [ 17 | "text-classification", 18 | "zero-shot-classification", 19 | "token-classification", 20 | "question-answering", 21 | "fill-mask", 22 | "summarization", 23 | "translation_xx_to_yy", 24 | "text2text-generation", 25 | "text-generation", 26 | "feature-extraction", 27 | "image-classification", 28 | "automatic-speech-recognition", 29 | "audio-classification", 30 | "object-detection", 31 | "image-segmentation", 32 | "table-question-answering", 33 | "conversational", 34 | "sentence-similarity", 35 | "sentence-embeddings", 36 | "sentence-ranking", 37 | "text-to-image", 38 | ], 39 | ) 40 | @pytest.mark.parametrize("device", ["cpu"]) 41 | @pytest.mark.parametrize("framework", ["pytorch"]) 42 | @pytest.mark.parametrize("repository_id", [""]) 43 | @pytest.mark.usefixtures("local_container") 44 | def test_pt_container_local_model( 45 | self, local_container, task, framework, device 46 | ) -> None: 47 | 48 | verify_task(task=task, port=local_container[1]) 49 | 50 | @tenacity.retry( 51 | stop=tenacity.stop_after_attempt(5), 52 | reraise=True, 53 | ) 54 | @require_torch 55 | @pytest.mark.parametrize( 56 | "repository_id", 57 | ["philschmid/custom-handler-test", "philschmid/custom-handler-distilbert"], 58 | ) 59 | @pytest.mark.parametrize("device", ["cpu"]) 60 | @pytest.mark.parametrize("framework", ["pytorch"]) 61 | @pytest.mark.parametrize("task", ["custom"]) 62 | @pytest.mark.usefixtures("local_container") 63 | def test_pt_container_custom_handler( 64 | self, local_container, task, device, repository_id 65 | ) -> None: 66 | 67 | verify_task( 68 | task=task, 69 | port=local_container[1], 70 | ) 71 | 72 | @tenacity.retry( 73 | stop=tenacity.stop_after_attempt(5), 74 | reraise=True, 75 | ) 76 | @require_torch 77 | @pytest.mark.parametrize( 78 | "repository_id", 79 | ["philschmid/custom-pipeline-text-classification"], 80 | ) 81 | @pytest.mark.parametrize("device", ["cpu"]) 82 | @pytest.mark.parametrize("framework", ["pytorch"]) 83 | @pytest.mark.parametrize("task", ["custom"]) 84 | @pytest.mark.usefixtures("local_container") 85 | def test_pt_container_legacy_custom_pipeline( 86 | self, local_container, repository_id, device, task 87 | ) -> None: 88 | 89 | verify_task(task=task, port=local_container[1]) 90 | -------------------------------------------------------------------------------- /tests/integ/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def 
validate_classification(result=None, snapshot=None): 5 | for idx, _ in enumerate(result): 6 | assert result[idx].keys() == snapshot[idx].keys() 7 | return True 8 | 9 | 10 | def validate_conversational(result=None, snapshot=None): 11 | assert len(result[0]["generated_text"]) >= len(snapshot) 12 | 13 | 14 | def validate_zero_shot_classification(result=None, snapshot=None): 15 | logging.info(f"Result: {result}") 16 | logging.info(f"Snapshot: {snapshot}") 17 | assert result.keys() == snapshot.keys() 18 | # assert result["labels"] == snapshot["labels"] 19 | # assert result["sequence"] == snapshot["sequence"] 20 | # for idx in range(len(result["scores"])): 21 | # assert result["scores"][idx] >= snapshot["scores"][idx] 22 | return True 23 | 24 | 25 | def validate_ner(result=None, snapshot=None): 26 | assert result[0].keys() == snapshot[0].keys() 27 | # for idx, _ in enumerate(result): 28 | # assert result[idx]["score"] >= snapshot[idx]["score"] 29 | # assert result[idx]["entity"] == snapshot[idx]["entity"] 30 | # assert result[idx]["entity"] == snapshot[idx]["entity"] 31 | return True 32 | 33 | 34 | def validate_question_answering(result=None, snapshot=None): 35 | assert result.keys() == snapshot.keys() 36 | # assert result["answer"] == snapshot["answer"] 37 | # assert result["score"] >= snapshot["score"] 38 | return True 39 | 40 | 41 | def validate_summarization(result=None, snapshot=None): 42 | assert result is not None 43 | return True 44 | 45 | 46 | def validate_text2text_generation(result=None, snapshot=None): 47 | assert result is not None 48 | return True 49 | 50 | 51 | def validate_translation(result=None, snapshot=None): 52 | assert result is not None 53 | return True 54 | 55 | 56 | def validate_text_generation(result=None, snapshot=None): 57 | assert result is not None 58 | return True 59 | 60 | 61 | def validate_feature_extraction(result=None, snapshot=None): 62 | assert result is not None 63 | return True 64 | 65 | 66 | def validate_fill_mask(result=None, snapshot=None): 67 | assert result is not None 68 | return True 69 | 70 | 71 | def validate_automatic_speech_recognition(result=None, snapshot=None): 72 | assert result is not None 73 | assert "text" in result 74 | return True 75 | 76 | 77 | def validate_object_detection(result=None, snapshot=None): 78 | assert result[0].keys() == snapshot[0].keys() 79 | return True 80 | 81 | 82 | def validate_text_to_image(result=None, snapshot=None): 83 | assert isinstance(result, snapshot) 84 | return True 85 | 86 | 87 | def validate_image_text_to_text(result=None, snapshot=None): 88 | assert isinstance(result, list) 89 | assert all(isinstance(d, dict) and d.keys() == {"input_text", "generated_text"} for d in result) 90 | return True 91 | 92 | 93 | def validate_custom(result=None, snapshot=None): 94 | logging.info(f"Validate custom task - result: {result}, snapshot: {snapshot}") 95 | assert result == snapshot 96 | return True 97 | -------------------------------------------------------------------------------- /tests/unit/test_optimum_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import pytest 5 | from transformers.testing_utils import require_torch 6 | 7 | from huggingface_inference_toolkit.optimum_utils import ( 8 | get_input_shapes, 9 | get_optimum_neuron_pipeline, 10 | is_optimum_neuron_available, 11 | ) 12 | from huggingface_inference_toolkit.utils import _load_repository_from_hf 13 | 14 | require_inferentia = pytest.mark.skipif( 15 | not 
is_optimum_neuron_available(), 16 | reason="Skipping tests, since optimum neuron is not available or not running on inf2 instances.", 17 | ) 18 | 19 | 20 | REMOTE_NOT_CONVERTED_MODEL = "hf-internal-testing/tiny-random-BertModel" 21 | REMOTE_CONVERTED_MODEL = "optimum/tiny_random_bert_neuron" 22 | TASK = "text-classification" 23 | 24 | 25 | @require_torch 26 | @require_inferentia 27 | def test_not_supported_task(): 28 | os.environ["HF_TASK"] = "not-supported-task" 29 | with pytest.raises(Exception): # noqa 30 | get_optimum_neuron_pipeline(task=TASK, target_dir=os.getcwd()) 31 | 32 | 33 | @require_torch 34 | @require_inferentia 35 | def test_get_input_shapes_from_file(): 36 | with tempfile.TemporaryDirectory() as tmpdirname: 37 | storage_folder = _load_repository_from_hf( 38 | repository_id=REMOTE_CONVERTED_MODEL, 39 | target_dir=tmpdirname, 40 | ) 41 | input_shapes = get_input_shapes(model_dir=storage_folder) 42 | assert input_shapes["batch_size"] == 1 43 | assert input_shapes["sequence_length"] == 32 44 | 45 | 46 | @require_torch 47 | @require_inferentia 48 | def test_get_input_shapes_from_env(): 49 | os.environ["HF_OPTIMUM_BATCH_SIZE"] = "4" 50 | os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "32" 51 | with tempfile.TemporaryDirectory() as tmpdirname: 52 | storage_folder = _load_repository_from_hf( 53 | repository_id=REMOTE_NOT_CONVERTED_MODEL, 54 | target_dir=tmpdirname, 55 | ) 56 | input_shapes = get_input_shapes(model_dir=storage_folder) 57 | assert input_shapes["batch_size"] == 4 58 | assert input_shapes["sequence_length"] == 32 59 | 60 | 61 | @require_torch 62 | @require_inferentia 63 | def test_get_optimum_neuron_pipeline_from_converted_model(): 64 | with tempfile.TemporaryDirectory() as tmpdirname: 65 | os.system( 66 | f"optimum-cli export neuron --model philschmid/tiny-distilbert-classification --sequence_length 32 --batch_size 1 {tmpdirname}" 67 | ) 68 | pipe = get_optimum_neuron_pipeline(task=TASK, model_dir=tmpdirname) 69 | r = pipe("This is a test") 70 | 71 | assert r[0]["score"] > 0.0 72 | assert isinstance(r[0]["label"], str) 73 | 74 | 75 | @require_torch 76 | @require_inferentia 77 | def test_get_optimum_neuron_pipeline_from_non_converted_model(): 78 | os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "32" 79 | with tempfile.TemporaryDirectory() as tmpdirname: 80 | storage_folder = _load_repository_from_hf( 81 | repository_id=REMOTE_NOT_CONVERTED_MODEL, 82 | target_dir=tmpdirname, 83 | ) 84 | pipe = get_optimum_neuron_pipeline(task=TASK, model_dir=storage_folder) 85 | r = pipe("This is a test") 86 | 87 | assert r[0]["score"] > 0.0 88 | assert isinstance(r[0]["label"], str) 89 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from setuptools import find_packages, setup 4 | 5 | # We don't declare our dependency on transformers here because we build with 6 | # different packages for different variants 7 | 8 | VERSION = "0.5.6" 9 | 10 | # Ubuntu packages 11 | # libsndfile1-dev: torchaudio requires the development version of the libsndfile package which can be installed via a system package manager. On Ubuntu it can be installed as follows: apt install libsndfile1-dev 12 | # ffmpeg: ffmpeg is required for audio processing. 
On Ubuntu it can be installed as follows: apt install ffmpeg 13 | # libavcodec-extra : libavcodec-extra includes additional codecs for ffmpeg 14 | 15 | install_requires = [ 16 | # Due to an error affecting kenlm and cmake (see https://github.com/kpu/kenlm/pull/464) 17 | # Also see the transformers patch for it https://github.com/huggingface/transformers/pull/37091 18 | "kenlm@git+https://github.com/kpu/kenlm@ba83eafdce6553addd885ed3da461bb0d60f8df7", 19 | "transformers[sklearn,sentencepiece,audio,vision]==4.51.3", 20 | "huggingface_hub[hf_transfer]==0.30.2", 21 | # vision 22 | "Pillow", 23 | "librosa", 24 | # speech + torchaudio 25 | "pyctcdecode>=0.3.0", 26 | "phonemizer", 27 | "ffmpeg", 28 | # web api 29 | "starlette", 30 | "uvicorn", 31 | "pandas", 32 | "orjson", 33 | "einops", 34 | ] 35 | 36 | extras = {} 37 | 38 | extras["st"] = ["sentence_transformers==4.0.2"] 39 | extras["diffusers"] = ["diffusers==0.33.1", "accelerate==1.6.0"] 40 | # Includes `peft` as PEFT requires `torch` so having `peft` as a core dependency 41 | # means that `torch` will be installed even if the `torch` extra is not specified. 42 | extras["torch"] = ["torch==2.5.1", "torchvision", "torchaudio", "peft==0.15.1"] 43 | extras["test"] = [ 44 | "pytest==7.2.1", 45 | "pytest-xdist", 46 | "parameterized", 47 | "psutil", 48 | "datasets", 49 | "pytest-sugar", 50 | "mock==2.0.0", 51 | "docker", 52 | "requests", 53 | "tenacity", 54 | ] 55 | extras["quality"] = ["isort", "ruff"] 56 | extras["inf2"] = ["optimum-neuron"] 57 | extras["google"] = ["google-cloud-storage", "crcmod==1.7"] 58 | 59 | setup( 60 | name="huggingface-inference-toolkit", 61 | version=VERSION, 62 | author="Hugging Face", 63 | description="Hugging Face Inference Toolkit is for serving 🤗 Transformers models in containers.", 64 | url="https://github.com/huggingface/huggingface-inference-toolkit", 65 | package_dir={"": "src"}, 66 | packages=find_packages(where="src"), 67 | install_requires=install_requires, 68 | extras_require=extras, 69 | entry_points={"console_scripts": "serve=sagemaker_huggingface_inference_toolkit.serving:main"}, 70 | python_requires=">=3.9", 71 | license="Apache License 2.0", 72 | classifiers=[ 73 | "Development Status :: 5 - Production/Stable", 74 | "Intended Audience :: Developers", 75 | "Intended Audience :: Education", 76 | "Intended Audience :: Science/Research", 77 | "License :: OSI Approved :: Apache Software License", 78 | "Operating System :: OS Independent", 79 | "Programming Language :: Python :: 3", 80 | "Programming Language :: Python :: 3.9", 81 | "Programming Language :: Python :: 3.10", 82 | "Programming Language :: Python :: 3.11", 83 | "Programming Language :: Python :: 3.12", 84 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 85 | ], 86 | ) 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Docker project generated files to ignore 2 | # if you want to ignore files created by your editor/tools, 3 | # please consider a global .gitignore https://help.github.com/articles/ignoring-files 4 | .gitignore 5 | .egg-info 6 | .ruff_cache 7 | .vagrant* 8 | .hcl 9 | .terraform.lock.hcl 10 | .terraform 11 | pip-unpack-* 12 | __pycache__ 13 | bin 14 | docker/docker 15 | .*.swp 16 | a.out 17 | *.orig 18 | build_src 19 | .flymake* 20 | .idea 21 | .DS_Store 22 | docs/_build 23 | docs/_static 24 | docs/_templates 25 | .gopath/ 26 | .dotcloud 27 | *.test 28 | bundles/ 29 | .hg/ 30 | .git/ 31 
| vendor/pkg/ 32 | pyenv 33 | Vagrantfile 34 | # Byte-compiled / optimized / DLL files 35 | __pycache__/ 36 | *.py[cod] 37 | *$py.class 38 | .vscode 39 | .make 40 | tox.ini 41 | 42 | # C extensions 43 | *.so 44 | 45 | # Distribution / packaging 46 | .Python 47 | build/ 48 | develop-eggs/ 49 | dist/ 50 | downloads/ 51 | eggs/ 52 | .eggs/ 53 | lib/ 54 | lib64/ 55 | parts/ 56 | sdist/ 57 | var/ 58 | wheels/ 59 | share/python-wheels/ 60 | *.egg-info/ 61 | .installed.cfg 62 | *.egg 63 | MANIFEST 64 | 65 | # PyInstaller 66 | # Usually these files are written by a python script from a template 67 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 68 | *.manifest 69 | *.spec 70 | 71 | # Installer logs 72 | pip-log.txt 73 | pip-delete-this-directory.txt 74 | 75 | # Unit test / coverage reports 76 | htmlcov/ 77 | .tox/ 78 | .nox/ 79 | .coverage 80 | .coverage.* 81 | .cache 82 | nosetests.xml 83 | coverage.xml 84 | *.cover 85 | *.py,cover 86 | .hypothesis/ 87 | .pytest_cache/ 88 | cover/ 89 | 90 | # Translations 91 | *.mo 92 | *.pot 93 | 94 | # Django stuff: 95 | *.log 96 | local_settings.py 97 | db.sqlite3 98 | db.sqlite3-journal 99 | 100 | # Flask stuff: 101 | instance/ 102 | .webassets-cache 103 | 104 | # Scrapy stuff: 105 | .scrapy 106 | 107 | # Sphinx documentation 108 | docs/_build/ 109 | 110 | # PyBuilder 111 | .pybuilder/ 112 | target/ 113 | 114 | # Jupyter Notebook 115 | .ipynb_checkpoints 116 | 117 | # IPython 118 | profile_default/ 119 | ipython_config.py 120 | 121 | # pyenv 122 | # For a library or package, you might want to ignore these files since the code is 123 | # intended to run in multiple environments; otherwise, check them in: 124 | # .python-version 125 | 126 | # pipenv 127 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 128 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 129 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 130 | # install all needed dependencies. 131 | #Pipfile.lock 132 | 133 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 134 | __pypackages__/ 135 | 136 | # Celery stuff 137 | celerybeat-schedule 138 | celerybeat.pid 139 | 140 | # SageMath parsed files 141 | *.sage.py 142 | 143 | # Environments 144 | .env 145 | .venv 146 | env/ 147 | venv/ 148 | ENV/ 149 | env.bak/ 150 | venv.bak/ 151 | 152 | # Spyder project settings 153 | .spyderproject 154 | .spyproject 155 | 156 | # Rope project settings 157 | .ropeproject 158 | 159 | # mkdocs documentation 160 | /site 161 | 162 | # mypy 163 | .mypy_cache/ 164 | .dmypy.json 165 | dmypy.json 166 | 167 | # Pyre type checker 168 | .pyre/ 169 | 170 | # pytype static type analyzer 171 | .pytype/ 172 | 173 | # Cython debug symbols 174 | cython_debug/ 175 | 176 | .vscode/settings.json 177 | .sagemaker 178 | model 179 | tests/tmp 180 | tmp/ 181 | act.sh 182 | .act 183 | tmp* 184 | log-* -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/diffusers_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | from typing import Union 3 | 4 | from transformers.utils.import_utils import is_torch_bf16_gpu_available 5 | 6 | from huggingface_inference_toolkit.logging import logger 7 | 8 | _diffusers = importlib.util.find_spec("diffusers") is not None 9 | 10 | 11 | def is_diffusers_available(): 12 | return _diffusers 13 | 14 | 15 | if is_diffusers_available(): 16 | import torch 17 | from diffusers import ( 18 | AutoPipelineForText2Image, 19 | DPMSolverMultistepScheduler, 20 | StableDiffusionPipeline, 21 | ) 22 | 23 | 24 | class IEAutoPipelineForText2Image: 25 | def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs): # needs "cuda" for GPU 26 | dtype = torch.float32 27 | if device == "cuda": 28 | dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16 29 | device_map = "balanced" if device == "cuda" else None 30 | 31 | self.pipeline = AutoPipelineForText2Image.from_pretrained( 32 | model_dir, torch_dtype=dtype, device_map=device_map, **kwargs 33 | ) 34 | # try to use DPMSolverMultistepScheduler 35 | if isinstance(self.pipeline, StableDiffusionPipeline): 36 | try: 37 | self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config) 38 | except Exception: 39 | pass 40 | 41 | def __call__( 42 | self, 43 | prompt, 44 | **kwargs, 45 | ): 46 | if "prompt" in kwargs: 47 | logger.warning( 48 | "prompt has been provided twice, both via arg and kwargs, so the `prompt` arg will be used " 49 | "instead, and the `prompt` in kwargs will be discarded." 50 | ) 51 | kwargs.pop("prompt") 52 | 53 | # diffusers doesn't support seed but rather the generator kwarg 54 | # see: https://github.com/huggingface/api-inference-community/blob/8e577e2d60957959ba02f474b2913d84a9086b82/docker_images/diffusers/app/pipelines/text_to_image.py#L172-L176 55 | if "seed" in kwargs: 56 | seed = int(kwargs["seed"]) 57 | generator = torch.Generator().manual_seed(seed) 58 | kwargs["generator"] = generator 59 | kwargs.pop("seed") 60 | 61 | # TODO: add support for more images (Reason is correct output) 62 | if "num_images_per_prompt" in kwargs: 63 | kwargs.pop("num_images_per_prompt") 64 | logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. 
Using default value 1.") 65 | 66 | if "target_size" in kwargs: 67 | kwargs["height"] = kwargs["target_size"].pop("height") 68 | kwargs["width"] = kwargs["target_size"].pop("width") 69 | kwargs.pop("target_size") 70 | 71 | if kwargs.get("height") != kwargs.get("width"): 72 | raise ValueError( 73 | f"Provided `height={kwargs.get('height')}` and `width={kwargs.get('width')}`, but either both must have a value or both must be None (or not provided)." 74 | ) 75 | 76 | if "output_type" in kwargs and kwargs["output_type"] != "pil": 77 | kwargs.pop("output_type") 78 | logger.warning("The `output_type` cannot be modified, and PIL will be used by default instead.") 79 | 80 | # Call pipeline with parameters 81 | out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs) 82 | 83 | return out.images[0] 84 | 85 | 86 | DIFFUSERS_TASKS = { 87 | "text-to-image": IEAutoPipelineForText2Image, 88 | } 89 | 90 | 91 | def get_diffusers_pipeline(task=None, model_dir=None, device=-1, **kwargs): 92 | """Get a pipeline for Diffusers models.""" 93 | device = "cuda" if device == 0 else "cpu" 94 | pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device, **kwargs) 95 | return pipeline 96 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/sentence_transformers_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | from typing import Any, Dict, List, Tuple, Union 3 | 4 | try: 5 | from typing import Literal 6 | except ImportError: 7 | from typing_extensions import Literal 8 | 9 | _sentence_transformers = importlib.util.find_spec("sentence_transformers") is not None 10 | 11 | 12 | def is_sentence_transformers_available(): 13 | return _sentence_transformers 14 | 15 | 16 | if is_sentence_transformers_available(): 17 | from sentence_transformers import CrossEncoder, SentenceTransformer, util 18 | 19 | 20 | class SentenceSimilarityPipeline: 21 | def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: Any) -> None: 22 | # `device` needs to be set to "cuda" for GPU 23 | self.model = SentenceTransformer(model_dir, device=device, **kwargs) 24 | 25 | def __call__(self, source_sentence: str, sentences: List[str]) -> Dict[str, float]: 26 | embeddings1 = self.model.encode(source_sentence, convert_to_tensor=True) 27 | embeddings2 = self.model.encode(sentences, convert_to_tensor=True) 28 | similarities = util.pytorch_cos_sim(embeddings1, embeddings2).tolist()[0] 29 | return {"similarities": similarities} 30 | 31 | 32 | class SentenceEmbeddingPipeline: 33 | def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: Any) -> None: 34 | # `device` needs to be set to "cuda" for GPU 35 | self.model = SentenceTransformer(model_dir, device=device, **kwargs) 36 | 37 | def __call__(self, sentences: Union[str, List[str]]) -> Dict[str, List[float]]: 38 | embeddings = self.model.encode(sentences).tolist() 39 | return {"embeddings": embeddings} 40 | 41 | 42 | class SentenceRankingPipeline: 43 | def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: Any) -> None: 44 | # `device` needs to be set to "cuda" for GPU 45 | self.model = CrossEncoder(model_dir, device=device, **kwargs) 46 | 47 | def __call__( 48 | self, 49 | sentences: Union[Tuple[str, str], List[str], List[List[str]], List[Tuple[str, str]], None] = None, 50 | query: Union[str, None] = None, 51 | texts: Union[List[str], None] = None, 52 | return_documents: bool = False, 53 | ) -> 
Union[Dict[str, List[float]], List[Dict[Literal["index", "score", "text"], Any]]]: 54 | if all(x is not None for x in [sentences, query, texts]): 55 | raise ValueError( 56 | f"The provided payload contains {sentences=} (i.e. 'inputs'), {query=}, and {texts=}" 57 | " but all of those cannot be provided, you should provide either only 'sentences' i.e. 'inputs'" 58 | " of both 'query' and 'texts' to run the ranking task." 59 | ) 60 | 61 | if all(x is None for x in [sentences, query, texts]): 62 | raise ValueError( 63 | "No inputs have been provided within the input payload, make sure that the input payload" 64 | " contains either 'sentences' i.e. 'inputs', or both 'query' and 'texts' to run the ranking task." 65 | ) 66 | 67 | if sentences is not None: 68 | scores = self.model.predict(sentences).tolist() 69 | return {"scores": scores} 70 | 71 | if query is None or not isinstance(query, str): 72 | raise ValueError(f"Provided {query=} but a non-empty string should be provided instead.") 73 | 74 | if texts is None or not isinstance(texts, list) or not all(isinstance(text, str) for text in texts): 75 | raise ValueError(f"Provided {texts=}, but a list of non-empty strings should be provided instead.") 76 | 77 | scores = self.model.rank(query, texts, return_documents=return_documents) 78 | # rename "corpus_id" key to "index" for all scores to match TEI 79 | for score in scores: 80 | score["index"] = score.pop("corpus_id") # type: ignore 81 | return scores # type: ignore 82 | 83 | 84 | SENTENCE_TRANSFORMERS_TASKS = { 85 | "sentence-similarity": SentenceSimilarityPipeline, 86 | "sentence-embeddings": SentenceEmbeddingPipeline, 87 | "sentence-ranking": SentenceRankingPipeline, 88 | } 89 | 90 | 91 | def get_sentence_transformers_pipeline(task=None, model_dir=None, device=-1, **kwargs): 92 | device = "cuda" if device == 0 else "cpu" 93 | 94 | kwargs.pop("tokenizer", None) 95 | kwargs.pop("framework", None) 96 | 97 | if task not in SENTENCE_TRANSFORMERS_TASKS: 98 | raise ValueError(f"Unknown task {task}. 
Available tasks are: {', '.join(SENTENCE_TRANSFORMERS_TASKS.keys())}") 99 | return SENTENCE_TRANSFORMERS_TASKS[task](model_dir=model_dir, device=device, **kwargs) 100 | -------------------------------------------------------------------------------- /tests/unit/test_handler.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from typing import Dict 3 | 4 | import pytest 5 | from transformers.testing_utils import require_tf, require_torch 6 | 7 | from huggingface_inference_toolkit.handler import ( 8 | HuggingFaceHandler, 9 | get_inference_handler_either_custom_or_default_handler, 10 | ) 11 | from huggingface_inference_toolkit.utils import ( 12 | _is_gpu_available, 13 | _load_repository_from_hf, 14 | ) 15 | 16 | TASK = "text-classification" 17 | MODEL = "hf-internal-testing/tiny-random-distilbert" 18 | 19 | 20 | # defined as fixture because it's modified on `pop` 21 | @pytest.fixture 22 | def input_data(): 23 | return {"inputs": "My name is Wolfgang and I live in Berlin"} 24 | 25 | 26 | @require_torch 27 | def test_pt_get_device() -> None: 28 | import torch 29 | 30 | with tempfile.TemporaryDirectory() as tmpdirname: 31 | # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py 32 | storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") 33 | h = HuggingFaceHandler(model_dir=str(storage_dir), task=TASK) 34 | if torch.cuda.is_available(): 35 | assert h.pipeline.model.device == torch.device(type="cuda", index=0) 36 | else: 37 | assert h.pipeline.model.device == torch.device(type="cpu") 38 | 39 | 40 | @require_torch 41 | def test_pt_predict_call(input_data: Dict[str, str]) -> None: 42 | with tempfile.TemporaryDirectory() as tmpdirname: 43 | # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py 44 | storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") 45 | h = HuggingFaceHandler(model_dir=str(storage_dir), task=TASK) 46 | 47 | prediction = h(input_data) 48 | assert "label" in prediction[0] 49 | assert "score" in prediction[0] 50 | 51 | 52 | @require_torch 53 | def test_pt_custom_pipeline(input_data: Dict[str, str]) -> None: 54 | with tempfile.TemporaryDirectory() as tmpdirname: 55 | storage_dir = _load_repository_from_hf( 56 | "philschmid/custom-pipeline-text-classification", 57 | tmpdirname, 58 | framework="pytorch", 59 | ) 60 | h = get_inference_handler_either_custom_or_default_handler(str(storage_dir), task="custom") 61 | assert h(input_data) == input_data 62 | 63 | 64 | @require_torch 65 | def test_pt_sentence_transformers_pipeline(input_data: Dict[str, str]) -> None: 66 | with tempfile.TemporaryDirectory() as tmpdirname: 67 | storage_dir = _load_repository_from_hf( 68 | "sentence-transformers/all-MiniLM-L6-v2", tmpdirname, framework="pytorch" 69 | ) 70 | h = get_inference_handler_either_custom_or_default_handler(str(storage_dir), task="sentence-embeddings") 71 | pred = h(input_data) 72 | assert isinstance(pred["embeddings"], list) 73 | 74 | 75 | @require_tf 76 | def test_tf_get_device(): 77 | with tempfile.TemporaryDirectory() as tmpdirname: 78 | # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py 79 | storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="tensorflow") 80 | h = HuggingFaceHandler(model_dir=str(storage_dir), task=TASK) 81 | if _is_gpu_available(): 82 | assert h.pipeline.device == 0 83 | else: 84 | assert h.pipeline.device == -1 85 | 86 | 87 | @require_tf 88 | def 
test_tf_predict_call(input_data: Dict[str, str]) -> None: 89 | with tempfile.TemporaryDirectory() as tmpdirname: 90 | # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py 91 | storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="tensorflow") 92 | handler = HuggingFaceHandler(model_dir=str(storage_dir), task=TASK, framework="tf") 93 | 94 | prediction = handler(input_data) 95 | assert "label" in prediction[0] 96 | assert "score" in prediction[0] 97 | 98 | 99 | @require_tf 100 | def test_tf_custom_pipeline(input_data: Dict[str, str]) -> None: 101 | with tempfile.TemporaryDirectory() as tmpdirname: 102 | storage_dir = _load_repository_from_hf( 103 | "philschmid/custom-pipeline-text-classification", 104 | tmpdirname, 105 | framework="tensorflow", 106 | ) 107 | h = get_inference_handler_either_custom_or_default_handler(str(storage_dir), task="custom") 108 | assert h(input_data) == input_data 109 | 110 | 111 | @require_tf 112 | def test_tf_sentence_transformers_pipeline(): 113 | # TODO should fail! because TF is not supported yet 114 | with tempfile.TemporaryDirectory() as tmpdirname: 115 | storage_dir = _load_repository_from_hf( 116 | "sentence-transformers/all-MiniLM-L6-v2", tmpdirname, framework="tensorflow" 117 | ) 118 | with pytest.raises(Exception) as _exc_info: 119 | get_inference_handler_either_custom_or_default_handler(str(storage_dir), task="sentence-embeddings") 120 | -------------------------------------------------------------------------------- /tests/unit/test_sentence_transformers.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import pytest 4 | from transformers.testing_utils import require_torch 5 | 6 | from huggingface_inference_toolkit.sentence_transformers_utils import ( 7 | SentenceEmbeddingPipeline, 8 | get_sentence_transformers_pipeline, 9 | ) 10 | from huggingface_inference_toolkit.utils import ( 11 | _load_repository_from_hf, 12 | get_pipeline, 13 | ) 14 | 15 | 16 | @require_torch 17 | def test_get_sentence_transformers_pipeline(): 18 | with tempfile.TemporaryDirectory() as tmpdirname: 19 | storage_dir = _load_repository_from_hf("sentence-transformers/all-MiniLM-L6-v2", tmpdirname) 20 | pipe = get_pipeline("sentence-embeddings", storage_dir.as_posix()) 21 | assert isinstance(pipe, SentenceEmbeddingPipeline) 22 | 23 | 24 | @require_torch 25 | def test_sentence_embedding_task(): 26 | with tempfile.TemporaryDirectory() as tmpdirname: 27 | storage_dir = _load_repository_from_hf("sentence-transformers/all-MiniLM-L6-v2", tmpdirname) 28 | pipe = get_sentence_transformers_pipeline("sentence-embeddings", storage_dir.as_posix()) 29 | res = pipe(sentences="Lets create an embedding") 30 | assert isinstance(res["embeddings"], list) 31 | res = pipe(sentences=["Lets create an embedding", "Lets create another embedding"]) 32 | assert isinstance(res["embeddings"], list) 33 | assert len(res["embeddings"]) == 2 34 | 35 | 36 | @require_torch 37 | def test_sentence_similarity(): 38 | with tempfile.TemporaryDirectory() as tmpdirname: 39 | storage_dir = _load_repository_from_hf("sentence-transformers/all-MiniLM-L6-v2", tmpdirname) 40 | pipe = get_sentence_transformers_pipeline("sentence-similarity", storage_dir.as_posix()) 41 | res = pipe(source_sentence="Lets create an embedding", sentences=["Lets create an embedding"]) 42 | assert isinstance(res["similarities"], list) 43 | 44 | 45 | @require_torch 46 | def test_sentence_ranking(): 47 | with tempfile.TemporaryDirectory() as tmpdirname: 48 | 
storage_dir = _load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname) 49 | pipe = get_sentence_transformers_pipeline("sentence-ranking", storage_dir.as_posix()) 50 | res = pipe( 51 | sentences=[ 52 | ["Lets create an embedding", "Lets create another embedding"], 53 | ["Lets create an embedding", "Lets create another embedding"], 54 | ] 55 | ) 56 | assert isinstance(res["scores"], list) 57 | res = pipe(sentences=["Lets create an embedding", "Lets create an embedding"]) 58 | assert isinstance(res["scores"], float) 59 | 60 | 61 | @require_torch 62 | def test_sentence_ranking_tei(): 63 | with tempfile.TemporaryDirectory() as tmpdirname: 64 | storage_dir = _load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname, framework="pytorch") 65 | pipe = get_sentence_transformers_pipeline("sentence-ranking", storage_dir.as_posix()) 66 | res = pipe( 67 | query="Lets create an embedding", 68 | texts=["Lets create an embedding", "I like noodles"], 69 | ) 70 | assert isinstance(res, list) 71 | assert all(r.keys() == {"index", "score"} for r in res) 72 | 73 | res = pipe( 74 | query="Lets create an embedding", 75 | texts=["Lets create an embedding", "I like noodles"], 76 | return_documents=True, 77 | ) 78 | assert isinstance(res, list) 79 | assert all(r.keys() == {"index", "score", "text"} for r in res) 80 | 81 | 82 | @require_torch 83 | def test_sentence_ranking_validation_errors(): 84 | with tempfile.TemporaryDirectory() as tmpdirname: 85 | storage_dir = _load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname, framework="pytorch") 86 | pipe = get_sentence_transformers_pipeline("sentence-ranking", storage_dir.as_posix()) 87 | 88 | with pytest.raises( 89 | ValueError, 90 | match=( 91 | "you should provide either only 'sentences' i.e. 'inputs' " 92 | "of both 'query' and 'texts' to run the ranking task." 93 | ), 94 | ): 95 | pipe( 96 | sentences="Lets create an embedding", 97 | query="Lets create an embedding", 98 | texts=["Lets create an embedding", "I like noodles"], 99 | ) 100 | 101 | with pytest.raises( 102 | ValueError, 103 | match=( 104 | "No inputs have been provided within the input payload, make sure that the input payload " 105 | "contains either 'sentences' i.e. 'inputs', or both 'query' and 'texts'" 106 | ), 107 | ): 108 | pipe(sentences=None, query=None, texts=None) 109 | 110 | with pytest.raises( 111 | ValueError, 112 | match=("Provided texts=None, but a list of non-empty strings should be provided instead."), 113 | ): 114 | pipe(sentences=None, query="Lets create an embedding", texts=None) 115 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/optimum_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import os 3 | 4 | from huggingface_inference_toolkit.logging import logger 5 | 6 | _optimum_neuron = False 7 | if importlib.util.find_spec("optimum") is not None: 8 | if importlib.util.find_spec("optimum.neuron") is not None: 9 | _optimum_neuron = True 10 | 11 | 12 | def is_optimum_neuron_available(): 13 | return _optimum_neuron 14 | 15 | 16 | def get_input_shapes(model_dir): 17 | """Method to get input shapes from model config file. 
If config file is not present, default values are returned.""" 18 | from transformers import AutoConfig 19 | 20 | input_shapes = {} 21 | input_shapes_available = False 22 | # try to get input shapes from config file 23 | try: 24 | config = AutoConfig.from_pretrained(model_dir) 25 | if hasattr(config, "neuron"): 26 | # check if static batch size and sequence length are available 27 | if config.neuron.get("static_batch_size", None) and config.neuron.get( 28 | "static_sequence_length", None 29 | ): 30 | input_shapes["batch_size"] = config.neuron["static_batch_size"] 31 | input_shapes["sequence_length"] = config.neuron[ 32 | "static_sequence_length" 33 | ] 34 | input_shapes_available = True 35 | logger.info( 36 | f"Input shapes found in config file. Using input shapes from config with batch size {input_shapes['batch_size']} and sequence length {input_shapes['sequence_length']}" 37 | ) 38 | else: 39 | # Add warning if environment variables are set but will be ignored 40 | if os.environ.get("HF_OPTIMUM_BATCH_SIZE", None) is not None: 41 | logger.warning( 42 | "HF_OPTIMUM_BATCH_SIZE environment variable is set. Environment variable will be ignored and input shapes from config file will be used." 43 | ) 44 | if os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None) is not None: 45 | logger.warning( 46 | "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is set. Environment variable will be ignored and input shapes from config file will be used." 47 | ) 48 | except Exception: 49 | input_shapes_available = False 50 | 51 | # return input shapes if available 52 | if input_shapes_available: 53 | return input_shapes 54 | 55 | # extract input shapes from environment variables 56 | sequence_length = os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None) 57 | if sequence_length is None: 58 | raise ValueError( 59 | "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is not set. Please set HF_OPTIMUM_SEQUENCE_LENGTH to a positive integer." 60 | ) 61 | 62 | if not int(sequence_length) > 0: 63 | raise ValueError( 64 | f"HF_OPTIMUM_SEQUENCE_LENGTH must be set to a positive integer. Current value is {sequence_length}" 65 | ) 66 | batch_size = os.environ.get("HF_OPTIMUM_BATCH_SIZE", 1) 67 | logger.info( 68 | f"Using input shapes from environment variables with batch size {batch_size} and sequence length {sequence_length}" 69 | ) 70 | return {"batch_size": int(batch_size), "sequence_length": int(sequence_length)} 71 | 72 | 73 | def get_optimum_neuron_pipeline(task, model_dir): 74 | """Method to get optimum neuron pipeline for a given task. Method checks if task is supported by optimum neuron and if required environment variables are set, in case model is not converted. If all checks pass, optimum neuron pipeline is returned. If checks fail, an error is raised.""" 75 | logger.info("Getting optimum neuron pipeline.") 76 | from optimum.neuron.pipelines.transformers.base import ( 77 | NEURONX_SUPPORTED_TASKS, 78 | pipeline, 79 | ) 80 | from optimum.neuron.utils import NEURON_FILE_NAME 81 | 82 | # convert from os.path or path 83 | if not isinstance(model_dir, str): 84 | model_dir = str(model_dir) 85 | 86 | # check if task is sentence-embeddings and convert to feature-extraction, as sentence-embeddings is supported in feature-extraction pipeline 87 | if task == "sentence-embeddings": 88 | task = "feature-extraction" 89 | 90 | # check task support 91 | if task not in NEURONX_SUPPORTED_TASKS: 92 | raise ValueError( 93 | f"Task {task} is not supported by optimum neuron and inf2. 
Supported tasks are: {list(NEURONX_SUPPORTED_TASKS.keys())}" 94 | ) 95 | 96 | # check if model is already converted and has input shapes available 97 | export = True 98 | if NEURON_FILE_NAME in os.listdir(model_dir): 99 | export = False 100 | if export: 101 | logger.info( 102 | "Model is not converted. Checking if required environment variables are set and converting model." 103 | ) 104 | 105 | # get static input shapes to run inference 106 | input_shapes = get_input_shapes(model_dir) 107 | # set NEURON_RT_NUM_CORES to 1 to avoid conflicts with multiple HTTP workers 108 | # TODO: Talk to optimum team what are the best options for encoder models to run on 2 neuron cores 109 | # os.environ["NEURON_RT_NUM_CORES"] = "1" 110 | # get optimum neuron pipeline 111 | neuron_pipe = pipeline( 112 | task, model=model_dir, export=export, input_shapes=input_shapes 113 | ) 114 | return neuron_pipe 115 | -------------------------------------------------------------------------------- /tests/integ/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import socket 5 | import time 6 | 7 | import docker 8 | import pytest 9 | import tenacity 10 | from transformers.testing_utils import _run_slow_tests 11 | 12 | from huggingface_inference_toolkit.utils import _load_repository_from_hf 13 | from tests.integ.config import task2model 14 | 15 | HF_HUB_CACHE = os.environ.get("HF_HUB_CACHE", "/home/ubuntu/.cache/huggingface/hub") 16 | IS_GPU = _run_slow_tests 17 | DEVICE = "gpu" if IS_GPU else "cpu" 18 | 19 | 20 | @tenacity.retry( 21 | retry=tenacity.retry_if_exception(docker.errors.APIError), 22 | stop=tenacity.stop_after_attempt(10), 23 | ) 24 | @pytest.fixture(scope="function") 25 | def remote_container(device, task, framework): 26 | time.sleep(random.randint(1, 5)) 27 | # client = docker.DockerClient(base_url='unix://var/run/docker.sock') 28 | client = docker.from_env() 29 | container_name = f"integration-test-{framework}-{task}-{device}" 30 | container_image = f"integration-test-{framework}:{device}" 31 | port = random.randint(5000, 9000) 32 | model = task2model[task][framework] 33 | 34 | # check if port is already open 35 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 36 | while sock.connect_ex(("localhost", port)) == 0: 37 | logging.debug(f"Port {port} is already being used; getting a new one...") 38 | port = random.randint(5000, 9000) 39 | 40 | logging.debug(f"Image: {container_image}") 41 | logging.debug(f"Port: {port}") 42 | 43 | device_request = ( 44 | [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] 45 | if device == "gpu" 46 | else [] 47 | ) 48 | 49 | yield client.containers.run( 50 | image=container_image, 51 | name=container_name, 52 | ports={"5000": port}, 53 | environment={"HF_MODEL_ID": model, "HF_TASK": task, "CUDA_LAUNCH_BLOCKING": 1}, 54 | detach=True, 55 | # GPU 56 | device_requests=device_request, 57 | ), port 58 | 59 | # Teardown 60 | previous = client.containers.get(container_name) 61 | logs = previous.logs().decode("utf-8") 62 | logging.info(f"Container logs:\n{logs}") 63 | previous.stop() 64 | previous.remove() 65 | 66 | 67 | @tenacity.retry(stop=tenacity.stop_after_attempt(10), reraise=True) 68 | @pytest.fixture(scope="function") 69 | def local_container(device, task, repository_id, framework): 70 | try: 71 | time.sleep(random.randint(1, 5)) 72 | if not (task == "custom"): 73 | model = task2model[task][framework] 74 | id = task 75 | else: 76 | model = repository_id 77 | id 
= random.randint(1, 1000) 78 | 79 | env = { 80 | "HF_MODEL_DIR": "/opt/huggingface/model", 81 | "HF_TASK": task, 82 | } 83 | 84 | logging.info(f"Starting container with model: {model}") 85 | 86 | if not model: 87 | message = f"No model supported for {framework}" 88 | logging.error(message) 89 | raise ValueError(message) 90 | 91 | logging.info(f"Starting container with Model = {model}") 92 | client = docker.from_env() 93 | container_name = f"integration-test-{framework}-{id}-{device}" 94 | container_image = f"integration-test-{framework}:{device}" 95 | 96 | port = random.randint(5000, 9000) 97 | 98 | # check if port is already open 99 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 100 | while sock.connect_ex(("localhost", port)) == 0: 101 | logging.debug(f"Port {port} is already being used; getting a new one...") 102 | port = random.randint(5000, 9000) 103 | 104 | logging.debug(f"Image: {container_image}") 105 | logging.debug(f"Port: {port}") 106 | 107 | device_request = ( 108 | [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] 109 | if device == "gpu" 110 | else None 111 | ) 112 | if device == "inf2": 113 | devices = { 114 | "/dev/neuron0": { 115 | "PathInContainer": "/dev/neuron0", 116 | "CgroupPermissions": "rwm", 117 | } 118 | } 119 | env["HF_OPTIMUM_BATCH_SIZE"] = 1 120 | env["HF_OPTIMUM_SEQUENCE_LENGTH"] = 128 121 | else: 122 | devices = None 123 | 124 | object_id = model.replace("/", "--") 125 | model_dir = f"{HF_HUB_CACHE}/{object_id}" 126 | 127 | _storage_dir = _load_repository_from_hf( 128 | repository_id=model, target_dir=model_dir 129 | ) 130 | 131 | yield client.containers.run( 132 | container_image, 133 | name=container_name, 134 | ports={"5000": port}, 135 | environment=env, 136 | volumes={model_dir: {"bind": "/opt/huggingface/model", "mode": "ro"}}, 137 | detach=True, 138 | # GPU 139 | device_requests=device_request, 140 | # INF2 141 | devices=devices, 142 | ), port 143 | 144 | # Teardown 145 | previous = client.containers.get(container_name) 146 | time.sleep(5) 147 | logs = previous.logs().decode("utf-8") 148 | logging.info(f"Container logs:\n{logs}") 149 | previous.stop() 150 | previous.remove() 151 | except Exception as exception: 152 | logging.error(f"Error starting container: {str(exception)}") 153 | raise exception 154 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/webservice_starlette.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | from pathlib import Path 4 | from time import perf_counter 5 | 6 | import orjson 7 | from starlette.applications import Starlette 8 | from starlette.responses import PlainTextResponse, Response 9 | from starlette.routing import Route 10 | 11 | from huggingface_inference_toolkit.async_utils import MAX_CONCURRENT_THREADS, MAX_THREADS_GUARD, async_handler_call 12 | from huggingface_inference_toolkit.const import ( 13 | HF_FRAMEWORK, 14 | HF_HUB_TOKEN, 15 | HF_MODEL_DIR, 16 | HF_MODEL_ID, 17 | HF_REVISION, 18 | HF_TASK, 19 | ) 20 | from huggingface_inference_toolkit.handler import ( 21 | get_inference_handler_either_custom_or_default_handler, 22 | ) 23 | from huggingface_inference_toolkit.logging import logger 24 | from huggingface_inference_toolkit.serialization.base import ContentType 25 | from huggingface_inference_toolkit.serialization.json_utils import Jsoner 26 | from huggingface_inference_toolkit.utils import ( 27 | _load_repository_from_hf, 28 | 
convert_params_to_int_or_bool, 29 | ) 30 | from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs 31 | 32 | 33 | async def prepare_model_artifacts(): 34 | global inference_handler 35 | # 1. check if model artifacts available in HF_MODEL_DIR 36 | if len(list(Path(HF_MODEL_DIR).glob("**/*"))) <= 0: 37 | # 2. if not available, try to load from HF_MODEL_ID 38 | if HF_MODEL_ID is not None: 39 | _load_repository_from_hf( 40 | repository_id=HF_MODEL_ID, 41 | target_dir=HF_MODEL_DIR, 42 | framework=HF_FRAMEWORK, 43 | revision=HF_REVISION, 44 | hf_hub_token=HF_HUB_TOKEN, 45 | ) 46 | # 3. check if in Vertex AI environment and load from GCS 47 | # If artifactUri not on Model Creation not set returns an empty string 48 | elif len(os.environ.get("AIP_STORAGE_URI", "")) > 0: 49 | _load_repository_from_gcs( 50 | os.environ["AIP_STORAGE_URI"], target_dir=HF_MODEL_DIR 51 | ) 52 | # 4. if not available, raise error 53 | else: 54 | raise ValueError( 55 | f"""Can't initialize model. 56 | Please set env HF_MODEL_DIR or provider a HF_MODEL_ID. 57 | Provided values are: 58 | HF_MODEL_DIR: {HF_MODEL_DIR} and HF_MODEL_ID:{HF_MODEL_ID}""" 59 | ) 60 | 61 | logger.info(f"Initializing model from directory:{HF_MODEL_DIR}") 62 | # 2. determine correct inference handler 63 | inference_handler = get_inference_handler_either_custom_or_default_handler( 64 | HF_MODEL_DIR, task=HF_TASK 65 | ) 66 | logger.info("Model initialized successfully") 67 | 68 | 69 | async def health(request): 70 | return PlainTextResponse("Ok") 71 | 72 | 73 | # Report Prometheus metrics 74 | # inf_batch_current_size: Current number of requests being processed 75 | # inf_queue_size: Number of requests waiting in the queue 76 | async def metrics(request): 77 | batch_current_size = MAX_CONCURRENT_THREADS - MAX_THREADS_GUARD.value 78 | queue_size = MAX_THREADS_GUARD.statistics().tasks_waiting 79 | return PlainTextResponse( 80 | f"inf_batch_current_size {batch_current_size}\n" + 81 | f"inf_queue_size {queue_size}\n" 82 | ) 83 | 84 | 85 | async def predict(request): 86 | try: 87 | # extracts content from request 88 | content_type = request.headers.get("content-Type", None) 89 | # try to deserialize payload 90 | deserialized_body = ContentType.get_deserializer(content_type).deserialize( 91 | await request.body() 92 | ) 93 | # checks if input schema is correct 94 | if "inputs" not in deserialized_body and "instances" not in deserialized_body: 95 | raise ValueError( 96 | f"Body needs to provide a inputs key, received: {orjson.dumps(deserialized_body)}" 97 | ) 98 | 99 | # Decode base64 audio inputs before running inference 100 | if "parameters" in deserialized_body and HF_TASK in { 101 | "automatic-speech-recognition", 102 | "audio-classification", 103 | }: 104 | # Be more strict on base64 decoding, the provided string should valid base64 encoded data 105 | deserialized_body["inputs"] = base64.b64decode( 106 | deserialized_body["inputs"], validate=True 107 | ) 108 | 109 | # check for query parameter and add them to the body 110 | if request.query_params and "parameters" not in deserialized_body: 111 | deserialized_body["parameters"] = convert_params_to_int_or_bool( 112 | dict(request.query_params) 113 | ) 114 | 115 | # tracks request time 116 | start_time = perf_counter() 117 | # run async not blocking call 118 | pred = await async_handler_call(inference_handler, deserialized_body) 119 | # log request time 120 | logger.info( 121 | f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms" 122 | ) 123 | 124 | # 
response extracts content from request 125 | accept = request.headers.get("accept", None) 126 | if accept is None or accept == "*/*": 127 | accept = "application/json" 128 | # deserialized and resonds with json 129 | serialized_response_body = ContentType.get_serializer(accept).serialize( 130 | pred, accept 131 | ) 132 | return Response(serialized_response_body, media_type=accept) 133 | except Exception as e: 134 | logger.error(e) 135 | return Response( 136 | Jsoner.serialize({"error": str(e)}), 137 | status_code=400, 138 | media_type="application/json", 139 | ) 140 | 141 | 142 | # Create app based on which cloud environment is used 143 | if os.getenv("AIP_MODE", None) == "PREDICTION": 144 | logger.info("Running in Vertex AI environment") 145 | # extract routes from environment variables 146 | _predict_route = os.getenv("AIP_PREDICT_ROUTE", None) 147 | _health_route = os.getenv("AIP_HEALTH_ROUTE", None) 148 | if _predict_route is None or _health_route is None: 149 | raise ValueError( 150 | "AIP_PREDICT_ROUTE and AIP_HEALTH_ROUTE need to be set in Vertex AI environment" 151 | ) 152 | 153 | app = Starlette( 154 | debug=False, 155 | routes=[ 156 | Route(_health_route, health, methods=["GET"]), 157 | Route(_predict_route, predict, methods=["POST"]), 158 | ], 159 | on_startup=[prepare_model_artifacts], 160 | ) 161 | else: 162 | app = Starlette( 163 | debug=False, 164 | routes=[ 165 | Route("/", health, methods=["GET"]), 166 | Route("/health", health, methods=["GET"]), 167 | Route("/", predict, methods=["POST"]), 168 | Route("/predict", predict, methods=["POST"]), 169 | Route("/metrics", metrics, methods=["GET"]), 170 | ], 171 | on_startup=[prepare_model_artifacts], 172 | ) 173 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any, Dict, Literal, Optional, Union 4 | 5 | from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE 6 | from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS 7 | from huggingface_inference_toolkit.utils import ( 8 | check_and_register_custom_pipeline_from_directory, 9 | get_pipeline, 10 | ) 11 | 12 | 13 | class HuggingFaceHandler: 14 | """ 15 | A Default Hugging Face Inference Handler which works with all 16 | Transformers, Diffusers, Sentence Transformers and Optimum pipelines. 17 | """ 18 | 19 | def __init__( 20 | self, model_dir: Union[str, Path], task: Union[str, None] = None, framework: Literal["pt"] = "pt" 21 | ) -> None: 22 | self.pipeline = get_pipeline( 23 | model_dir=model_dir, # type: ignore 24 | task=task, # type: ignore 25 | framework=framework, 26 | trust_remote_code=HF_TRUST_REMOTE_CODE, 27 | ) 28 | 29 | def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: 30 | """ 31 | Handles an inference request with input data and makes a prediction. 32 | Args: 33 | :data: (obj): the raw request body data. 
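                e.g. {"inputs": "I like you", "parameters": {"top_k": 2}} for a text-classification task.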
34 | :return: prediction output 35 | """ 36 | inputs = data.pop("inputs", data) 37 | parameters = data.pop("parameters", {}) 38 | 39 | # diffusers and sentence transformers pipelines do not have the `task` arg 40 | if not hasattr(self.pipeline, "task"): 41 | # sentence transformers parameters not supported yet 42 | if any(isinstance(self.pipeline, v) for v in SENTENCE_TRANSFORMERS_TASKS.values()): 43 | return ( # type: ignore 44 | self.pipeline(**inputs) if isinstance(inputs, dict) else self.pipeline(inputs) 45 | ) 46 | # diffusers does support kwargs 47 | return ( # type: ignore 48 | self.pipeline(**inputs, **parameters) 49 | if isinstance(inputs, dict) 50 | else self.pipeline(inputs, **parameters) 51 | ) 52 | 53 | if self.pipeline.task == "question-answering": 54 | if not isinstance(inputs, dict): 55 | raise ValueError(f"inputs must be a dict, but a `{type(inputs)}` was provided instead.") 56 | if not all(k in inputs for k in {"question", "context"}): 57 | raise ValueError( 58 | f"{self.pipeline.task} expects `inputs` to be a dict containing both `question` and " 59 | "`context` as the keys, both of them being either a `str` or a `List[str]`." 60 | ) 61 | 62 | if self.pipeline.task == "table-question-answering": 63 | if not isinstance(inputs, dict): 64 | raise ValueError(f"inputs must be a dict, but a `{type(inputs)}` was provided instead.") 65 | if "question" in inputs: 66 | inputs["query"] = inputs.pop("question") 67 | if not all(k in inputs for k in {"table", "query"}): 68 | raise ValueError( 69 | f"{self.pipeline.task} expects `inputs` to be a dict containing the keys `table` and " 70 | "either `question` or `query`." 71 | ) 72 | 73 | if self.pipeline.task.__contains__("translation") or self.pipeline.task in { 74 | "text-generation", 75 | "image-to-text", 76 | "automatic-speech-recognition", 77 | "text-to-audio", 78 | "text-to-speech", 79 | }: 80 | # `generate_kwargs` needs to be a dict, `generation_parameters` is here for forward compatibility 81 | if "generation_parameters" in parameters: 82 | parameters["generate_kwargs"] = parameters.pop("generation_parameters") 83 | 84 | if self.pipeline.task.__contains__("translation") or self.pipeline.task in {"text-generation"}: 85 | # flatten the values of `generate_kwargs` as it's not supported as is, but via top-level parameters 86 | generate_kwargs = parameters.pop("generate_kwargs", {}) 87 | for key, value in generate_kwargs.items(): 88 | parameters[key] = value 89 | 90 | if self.pipeline.task.__contains__("zero-shot-classification"): 91 | if "candidateLabels" in parameters: 92 | parameters["candidate_labels"] = parameters.pop("candidateLabels") 93 | if not isinstance(inputs, dict): 94 | inputs = {"sequences": inputs} 95 | if "text" in inputs: 96 | inputs["sequences"] = inputs.pop("text") 97 | if not all(k in inputs for k in {"sequences"}) or not all(k in parameters for k in {"candidate_labels"}): 98 | raise ValueError( 99 | f"{self.pipeline.task} expects `inputs` to be either a string or a dict containing the " 100 | "key `text` or `sequences`, and `parameters` to be a dict containing either `candidate_labels` " 101 | "or `candidateLabels`." 102 | ) 103 | 104 | return ( 105 | self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else self.pipeline(inputs, **parameters) # type: ignore 106 | ) 107 | 108 | 109 | class VertexAIHandler(HuggingFaceHandler): 110 | """ 111 | A Default Vertex AI Hugging Face Inference Handler which abstracts the 112 | Vertex AI specific logic for inference. 
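    Expects the Vertex AI request schema, i.e. {"instances": [...], "parameters": {...}}, and returns {"predictions": [...]}.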
113 | """ 114 | 115 | def __init__( 116 | self, model_dir: Union[str, Path], task: Union[str, None] = None, framework: Literal["pt"] = "pt" 117 | ) -> None: 118 | super().__init__(model_dir=model_dir, task=task, framework=framework) 119 | 120 | def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: 121 | """ 122 | Handles an inference request with input data and makes a prediction. 123 | Args: 124 | :data: (obj): the raw request body data. 125 | :return: prediction output 126 | """ 127 | if "instances" not in data: 128 | raise ValueError("The request body must contain a key 'instances' with a list of instances.") 129 | parameters = data.pop("parameters", {}) 130 | 131 | predictions = [] 132 | # iterate over all instances and make predictions 133 | for inputs in data["instances"]: 134 | payload = {"inputs": inputs, "parameters": parameters} 135 | predictions.append(super().__call__(payload)) 136 | 137 | # return predictions 138 | return {"predictions": predictions} 139 | 140 | 141 | def get_inference_handler_either_custom_or_default_handler(model_dir: Path, task: Optional[str] = None) -> Any: 142 | """ 143 | Returns the appropriate inference handler based on the given model directory and task. 144 | 145 | Args: 146 | model_dir (Path): The directory path where the model is stored. 147 | task (Optional[str]): The task for which the inference handler is required. Defaults to None. 148 | 149 | Returns: 150 | InferenceHandler: The appropriate inference handler based on the given model directory and task. 151 | """ 152 | custom_pipeline = check_and_register_custom_pipeline_from_directory(model_dir) 153 | if custom_pipeline is not None: 154 | return custom_pipeline 155 | 156 | if os.environ.get("AIP_MODE", None) == "PREDICTION": 157 | return VertexAIHandler(model_dir=model_dir, task=task) 158 | 159 | return HuggingFaceHandler(model_dir=model_dir, task=task) 160 | -------------------------------------------------------------------------------- /tests/unit/test_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import tempfile 4 | from pathlib import Path 5 | 6 | from transformers.file_utils import is_torch_available 7 | from transformers.testing_utils import require_tf, require_torch, slow 8 | 9 | from huggingface_inference_toolkit.handler import get_inference_handler_either_custom_or_default_handler 10 | from huggingface_inference_toolkit.utils import ( 11 | _get_framework, 12 | _is_gpu_available, 13 | _load_repository_from_hf, 14 | check_and_register_custom_pipeline_from_directory, 15 | get_pipeline, 16 | ) 17 | 18 | TASK_MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" 19 | 20 | 21 | def test_load_revision_repository_from_hf(): 22 | MODEL = "lysandre/tiny-bert-random" 23 | REVISION = "eb4c77816edd604d0318f8e748a1c606a2888493" 24 | with tempfile.TemporaryDirectory() as tmpdirname: 25 | storage_folder = _load_repository_from_hf(MODEL, tmpdirname, revision=REVISION) 26 | # folder contains all config files and pytorch_model.bin 27 | folder_contents = os.listdir(storage_folder) 28 | # revision doesn't have tokenizer 29 | assert "tokenizer_config.json" not in folder_contents 30 | 31 | 32 | @require_tf 33 | def test_load_tensorflow_repository_from_hf(): 34 | MODEL = "lysandre/tiny-bert-random" 35 | with tempfile.TemporaryDirectory() as tmpdirname: 36 | tf_tmp = Path(tmpdirname).joinpath("tf") 37 | tf_tmp.mkdir(parents=True, exist_ok=True) 38 | 39 | storage_folder = _load_repository_from_hf(MODEL, 
tf_tmp, framework="tensorflow") 40 | # folder contains all config files and pytorch_model.bin 41 | folder_contents = os.listdir(storage_folder) 42 | assert "pytorch_model.bin" not in folder_contents 43 | # filter framework 44 | assert "tf_model.h5" in folder_contents 45 | # revision doesn't have tokenizer 46 | assert "tokenizer_config.json" in folder_contents 47 | 48 | 49 | def test_load_onnx_repository_from_hf(): 50 | MODEL = "philschmid/distilbert-onnx-banking77" 51 | with tempfile.TemporaryDirectory() as tmpdirname: 52 | ox_tmp = Path(tmpdirname).joinpath("onnx") 53 | ox_tmp.mkdir(parents=True, exist_ok=True) 54 | 55 | storage_folder = _load_repository_from_hf(MODEL, ox_tmp, framework="onnx") 56 | # folder contains all config files and pytorch_model.bin 57 | folder_contents = os.listdir(storage_folder) 58 | assert "pytorch_model.bin" not in folder_contents 59 | # filter framework 60 | assert "tf_model.h5" not in folder_contents 61 | # onnx model 62 | assert "model.onnx" in folder_contents 63 | # custom pipeline 64 | assert "handler.py" in folder_contents 65 | # revision doesn't have tokenizer 66 | assert "tokenizer_config.json" in folder_contents 67 | 68 | 69 | @require_torch 70 | def test_load_pytorch_repository_from_hf(): 71 | MODEL = "lysandre/tiny-bert-random" 72 | with tempfile.TemporaryDirectory() as tmpdirname: 73 | pt_tmp = Path(tmpdirname).joinpath("pt") 74 | pt_tmp.mkdir(parents=True, exist_ok=True) 75 | 76 | storage_folder = _load_repository_from_hf(MODEL, pt_tmp, framework="pytorch") 77 | # folder contains all config files and pytorch_model.bin 78 | folder_contents = os.listdir(storage_folder) 79 | assert "pytorch_model.bin" in folder_contents 80 | # filter framework 81 | assert "tf_model.h5" not in folder_contents 82 | # revision doesn't have tokenizer 83 | assert "tokenizer_config.json" in folder_contents 84 | 85 | 86 | @slow 87 | def test_gpu_available(): 88 | device = _is_gpu_available() 89 | assert device is True 90 | 91 | 92 | @require_torch 93 | def test_get_framework_pytorch(): 94 | framework = _get_framework() 95 | assert framework == "pytorch" 96 | 97 | 98 | @require_tf 99 | def test_get_framework_tensorflow(): 100 | framework = _get_framework() 101 | if is_torch_available(): 102 | assert framework == "pytorch" 103 | else: 104 | assert framework == "tensorflow" 105 | 106 | 107 | @require_torch 108 | def test_get_pipeline(): 109 | MODEL = "hf-internal-testing/tiny-random-BertForSequenceClassification" 110 | TASK = "text-classification" 111 | with tempfile.TemporaryDirectory() as tmpdirname: 112 | storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") 113 | pipe = get_pipeline( 114 | task = TASK, 115 | model_dir = storage_dir.as_posix(), 116 | ) 117 | res = pipe("Life is good, Life is bad") 118 | assert "score" in res[0] 119 | 120 | 121 | @require_torch 122 | def test_whisper_long_audio(cache_test_dir): 123 | with tempfile.TemporaryDirectory() as tmpdirname: 124 | storage_dir = _load_repository_from_hf( 125 | repository_id = "openai/whisper-tiny", 126 | target_dir = tmpdirname, 127 | ) 128 | logging.info(f"Temp dir: {tmpdirname}") 129 | logging.info(f"POSIX Path: {storage_dir.as_posix()}") 130 | logging.info(f"Contents: {os.listdir(tmpdirname)}") 131 | pipe = get_pipeline( 132 | task = "automatic-speech-recognition", 133 | model_dir = storage_dir.as_posix(), 134 | ) 135 | res = pipe(f"{cache_test_dir}/resources/audio/long_sample.mp3") 136 | 137 | assert len(res["text"]) > 700 138 | 139 | @require_torch 140 | def test_wrapped_pipeline(): 141 | 
with tempfile.TemporaryDirectory() as tmpdirname: 142 | storage_dir = _load_repository_from_hf( 143 | repository_id = "microsoft/DialoGPT-small", 144 | target_dir = tmpdirname, 145 | framework="pytorch" 146 | ) 147 | conv_pipe = get_pipeline("conversational", storage_dir.as_posix()) 148 | data = [ 149 | { 150 | "role": "user", 151 | "content": "Which movie is the best ?" 152 | }, 153 | { 154 | "role": "assistant", 155 | "content": "It's Die Hard for sure." 156 | }, 157 | { 158 | "role": "user", 159 | "content": "Can you explain why?" 160 | } 161 | ] 162 | res = conv_pipe(data, max_new_tokens = 100) 163 | logging.info(f"Response: {res}") 164 | message = res[0]["generated_text"][-1] 165 | assert message["role"] == "assistant" 166 | 167 | 168 | def test_local_custom_pipeline(cache_test_dir): 169 | model_dir = f"{cache_test_dir}/resources/custom_handler" 170 | pipeline = check_and_register_custom_pipeline_from_directory(model_dir) 171 | payload = "test" 172 | assert pipeline.path == model_dir 173 | assert pipeline(payload) == payload[::-1] 174 | 175 | 176 | def test_remote_custom_pipeline(): 177 | with tempfile.TemporaryDirectory() as tmpdirname: 178 | storage_dir = _load_repository_from_hf( 179 | "philschmid/custom-pipeline-text-classification", 180 | tmpdirname, 181 | framework="pytorch" 182 | ) 183 | pipeline = check_and_register_custom_pipeline_from_directory(str(storage_dir)) 184 | payload = "test" 185 | assert pipeline.path == str(storage_dir) 186 | assert pipeline(payload) == payload 187 | 188 | 189 | def test_get_inference_handler_either_custom_or_default_pipeline(): 190 | with tempfile.TemporaryDirectory() as tmpdirname: 191 | storage_dir = _load_repository_from_hf( 192 | "philschmid/custom-pipeline-text-classification", 193 | tmpdirname, 194 | framework="pytorch" 195 | ) 196 | pipeline = get_inference_handler_either_custom_or_default_handler(str(storage_dir)) 197 | payload = "test" 198 | assert pipeline.path == str(storage_dir) 199 | assert pipeline(payload) == payload 200 | 201 | with tempfile.TemporaryDirectory() as tmpdirname: 202 | MODEL = "lysandre/tiny-bert-random" 203 | TASK = "text-classification" 204 | pipeline = get_inference_handler_either_custom_or_default_handler(MODEL, TASK) 205 | res = pipeline({"inputs": "Life is good, Life is bad"}) 206 | assert "score" in res[0] 207 | -------------------------------------------------------------------------------- /scripts/inf2_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | This script is here to specify all missing environment variables that would be required to run some encoder models on 5 | inferentia2. 
6 | """ 7 | 8 | import argparse 9 | import logging 10 | import os 11 | import sys 12 | from typing import Any, Dict, List, Optional 13 | 14 | from huggingface_hub import constants 15 | from transformers import AutoConfig 16 | 17 | from optimum.neuron.utils import get_hub_cached_entries 18 | from optimum.neuron.utils.version_utils import get_neuronxcc_version 19 | 20 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', force=True) 21 | logger = logging.getLogger(__name__) 22 | 23 | env_config_peering = [ 24 | ("HF_BATCH_SIZE", "static_batch_size"), 25 | ("HF_OPTIMUM_SEQUENCE_LENGTH", "static_sequence_length"), 26 | ] 27 | 28 | # By the end of this script all env vars should be specified properly 29 | env_vars = list(map(lambda x: x[0], env_config_peering)) 30 | 31 | # Currently not used for encoder models 32 | # available_cores = get_available_cores() 33 | 34 | neuronxcc_version = get_neuronxcc_version() 35 | 36 | 37 | def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace: 38 | parser = argparse.ArgumentParser() 39 | if not argv: 40 | argv = sys.argv 41 | # All these are params passed to tgi and intercepted here 42 | parser.add_argument( 43 | "--batch-size", 44 | type=int, 45 | default=os.getenv("HF_BATCH_SIZE", os.getenv("BATCH_SIZE", 0)), 46 | ) 47 | parser.add_argument( 48 | "--sequence-length", type=int, 49 | default=os.getenv("HF_OPTIMUM_SEQUENCE_LENGTH", 50 | os.getenv("SEQUENCE_LENGTH", 0)) 51 | ) 52 | 53 | parser.add_argument("--model-id", type=str, default=os.getenv("HF_MODEL_ID", os.getenv("HF_MODEL_DIR"))) 54 | parser.add_argument("--revision", type=str, default=os.getenv("REVISION")) 55 | 56 | args = parser.parse_known_args(argv)[0] 57 | 58 | if not args.model_id: 59 | raise Exception( 60 | "No model id provided ! 
Either specify it using --model-id cmdline or MODEL_ID env var" 61 | ) 62 | 63 | # Override env with cmdline params 64 | os.environ["MODEL_ID"] = args.model_id 65 | 66 | # Set all tgi router and tgi server values to consistent values as early as possible 67 | # from the order of the parser defaults, the tgi router value can override the tgi server ones 68 | if args.batch_size > 0: 69 | os.environ["HF_BATCH_SIZE"] = str(args.batch_size) 70 | 71 | if args.sequence_length > 0: 72 | os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = str(args.sequence_length) 73 | 74 | if args.revision: 75 | os.environ["REVISION"] = str(args.revision) 76 | 77 | return args 78 | 79 | 80 | def neuron_config_to_env(neuron_config): 81 | with open(os.environ["ENV_FILEPATH"], "w") as f: 82 | for env_var, config_key in env_config_peering: 83 | f.write("export {}={}\n".format(env_var, neuron_config[config_key])) 84 | 85 | 86 | def sort_neuron_configs(dictionary): 87 | return -dictionary["static_batch_size"] 88 | 89 | 90 | def lookup_compatible_cached_model( 91 | model_id: str, revision: Optional[str] 92 | ) -> Optional[Dict[str, Any]]: 93 | # Reuse the same mechanic as the one in use to configure the tgi server part 94 | # The only difference here is that we stay as flexible as possible on the compatibility part 95 | entries = get_hub_cached_entries(model_id, "inference") 96 | 97 | logger.debug( 98 | "Found %d cached entries for model %s, revision %s", 99 | len(entries), 100 | model_id, 101 | revision, 102 | ) 103 | 104 | all_compatible = [] 105 | for entry in entries: 106 | if check_env_and_neuron_config_compatibility( 107 | entry, check_compiler_version=True 108 | ): 109 | all_compatible.append(entry) 110 | 111 | if not all_compatible: 112 | logger.debug( 113 | "No compatible cached entry found for model %s, env %s, neuronxcc version %s", 114 | model_id, 115 | get_env_dict(), 116 | neuronxcc_version, 117 | ) 118 | return None 119 | 120 | logger.info("%d compatible neuron cached models found", len(all_compatible)) 121 | 122 | all_compatible = sorted(all_compatible, key=sort_neuron_configs) 123 | 124 | entry = all_compatible[0] 125 | 126 | logger.info("Selected entry %s", entry) 127 | 128 | return entry 129 | 130 | 131 | def check_env_and_neuron_config_compatibility( 132 | neuron_config: Dict[str, Any], check_compiler_version: bool 133 | ) -> bool: 134 | logger.debug( 135 | "Checking the provided neuron config %s is compatible with the local setup and provided environment", 136 | neuron_config, 137 | ) 138 | 139 | # Local setup compat checks 140 | # if neuron_config["num_cores"] > available_cores: 141 | # logger.debug( 142 | # "Not enough neuron cores available to run the provided neuron config" 143 | # ) 144 | # return False 145 | 146 | if ( 147 | check_compiler_version 148 | and neuron_config["compiler_version"] != neuronxcc_version 149 | ): 150 | logger.debug( 151 | "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)", 152 | neuronxcc_version, 153 | neuron_config["compiler_version"], 154 | ) 155 | return False 156 | 157 | for env_var, config_key in env_config_peering: 158 | try: 159 | neuron_config_value = str(neuron_config[config_key]) 160 | except KeyError: 161 | logger.debug("No key %s found in neuron config %s", config_key, neuron_config) 162 | return False 163 | env_value = os.getenv(env_var, str(neuron_config_value)) 164 | if env_value != neuron_config_value: 165 | logger.debug( 166 | "The provided env var '%s' and the neuron config '%s' param differ (%s != %s)", 167 | 
env_var, 168 | config_key, 169 | env_value, 170 | neuron_config_value, 171 | ) 172 | return False 173 | 174 | return True 175 | 176 | 177 | def get_env_dict() -> Dict[str, str]: 178 | d = {} 179 | for k in env_vars: 180 | d[k] = os.getenv(k) 181 | return d 182 | 183 | 184 | def main(): 185 | """ 186 | This script determines proper default TGI env variables for the neuron precompiled models to 187 | work properly 188 | :return: 189 | """ 190 | args = parse_cmdline_and_set_env() 191 | 192 | for env_var in env_vars: 193 | if not os.getenv(env_var): 194 | break 195 | else: 196 | logger.info( 197 | "All env vars %s already set, skipping, user know what they are doing", 198 | env_vars, 199 | ) 200 | sys.exit(0) 201 | 202 | cache_dir = constants.HF_HUB_CACHE 203 | 204 | logger.info("Cache dir %s, model %s", cache_dir, args.model_id) 205 | 206 | config = AutoConfig.from_pretrained(args.model_id, revision=args.revision) 207 | neuron_config = getattr(config, "neuron", None) 208 | if neuron_config is not None: 209 | compatible = check_env_and_neuron_config_compatibility( 210 | neuron_config, check_compiler_version=False 211 | ) 212 | if not compatible: 213 | env_dict = get_env_dict() 214 | msg = ( 215 | "Invalid neuron config and env. Config {}, env {}, neuronxcc version {}" 216 | ).format(neuron_config, env_dict, neuronxcc_version) 217 | logger.error(msg) 218 | raise Exception(msg) 219 | else: 220 | neuron_config = lookup_compatible_cached_model(args.model_id, args.revision) 221 | 222 | if not neuron_config: 223 | neuron_config = {'static_batch_size': 1, 'static_sequence_length': 128} 224 | msg = ( 225 | "No compatible neuron config found. Provided env {}, neuronxcc version {}. Falling back to default" 226 | ).format(get_env_dict(), neuronxcc_version, neuron_config) 227 | logger.info(msg) 228 | 229 | logger.info("Final neuron config %s", neuron_config) 230 | 231 | neuron_config_to_env(neuron_config) 232 | 233 | 234 | if __name__ == "__main__": 235 | main() 236 | -------------------------------------------------------------------------------- /src/huggingface_inference_toolkit/utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import sys 3 | from pathlib import Path 4 | from typing import Optional, Union 5 | 6 | from huggingface_hub import HfApi, login, snapshot_download 7 | from transformers import WhisperForConditionalGeneration, pipeline 8 | from transformers.file_utils import is_tf_available, is_torch_available 9 | from transformers.pipelines import Pipeline 10 | 11 | from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME 12 | from huggingface_inference_toolkit.diffusers_utils import ( 13 | get_diffusers_pipeline, 14 | is_diffusers_available, 15 | ) 16 | from huggingface_inference_toolkit.logging import logger 17 | from huggingface_inference_toolkit.optimum_utils import ( 18 | get_optimum_neuron_pipeline, 19 | is_optimum_neuron_available, 20 | ) 21 | from huggingface_inference_toolkit.sentence_transformers_utils import ( 22 | get_sentence_transformers_pipeline, 23 | is_sentence_transformers_available, 24 | ) 25 | 26 | if is_tf_available(): 27 | import tensorflow as tf 28 | 29 | if is_torch_available(): 30 | import torch 31 | 32 | _optimum_available = importlib.util.find_spec("optimum") is not None 33 | 34 | 35 | def is_optimum_available(): 36 | return False 37 | # TODO: change when supported 38 | # return _optimum_available 39 | 40 | 41 | framework2weight = { 42 | "pytorch": "pytorch*", 43 | 
"tensorflow": "tf*", 44 | "tf": "tf*", 45 | "pt": "pytorch*", 46 | "flax": "flax*", 47 | "rust": "rust*", 48 | "onnx": "*onnx*", 49 | "safetensors": "*safetensors", 50 | "coreml": "*mlmodel", 51 | "tflite": "*tflite", 52 | "savedmodel": "*tar.gz", 53 | "openvino": "*openvino*", 54 | "ckpt": "*ckpt", 55 | } 56 | 57 | 58 | def create_artifact_filter(framework): 59 | """ 60 | Returns a list of regex pattern based on the DL Framework. which will be to used to ignore files when downloading 61 | """ 62 | ignore_regex_list = list(set(framework2weight.values())) 63 | 64 | pattern = framework2weight.get(framework, None) 65 | if pattern in ignore_regex_list: 66 | ignore_regex_list.remove(pattern) 67 | return ignore_regex_list 68 | else: 69 | return [] 70 | 71 | 72 | def _is_gpu_available(): 73 | """ 74 | checks if a gpu is available. 75 | """ 76 | if is_tf_available(): 77 | return True if len(tf.config.list_physical_devices("GPU")) > 0 else False 78 | elif is_torch_available(): 79 | return torch.cuda.is_available() 80 | else: 81 | raise RuntimeError( 82 | "At least one of TensorFlow 2.0 or PyTorch should be installed. " 83 | "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " 84 | "To install PyTorch, read the instructions at https://pytorch.org/." 85 | ) 86 | 87 | 88 | def _get_framework(): 89 | """ 90 | extracts which DL framework is used for inference, if both are installed use pytorch 91 | """ 92 | 93 | if is_torch_available(): 94 | return "pytorch" 95 | elif is_tf_available(): 96 | return "tensorflow" 97 | else: 98 | raise RuntimeError( 99 | "At least one of TensorFlow 2.0 or PyTorch should be installed. " 100 | "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " 101 | "To install PyTorch, read the instructions at https://pytorch.org/." 102 | ) 103 | 104 | 105 | def _load_repository_from_hf( 106 | repository_id: Optional[str] = None, 107 | target_dir: Optional[Union[str, Path]] = None, 108 | framework: Optional[str] = None, 109 | revision: Optional[str] = None, 110 | hf_hub_token: Optional[str] = None, 111 | ): 112 | """ 113 | Load a model from huggingface hub. 114 | """ 115 | 116 | if hf_hub_token is not None: 117 | login(token=hf_hub_token) 118 | 119 | if framework is None: 120 | framework = _get_framework() 121 | 122 | if isinstance(target_dir, str): 123 | target_dir = Path(target_dir) 124 | 125 | # create workdir 126 | if not target_dir.exists(): 127 | target_dir.mkdir(parents=True) 128 | 129 | # check if safetensors weights are available 130 | if framework == "pytorch": 131 | files = HfApi().model_info(repository_id).siblings 132 | if any(f.rfilename.endswith("safetensors") for f in files): 133 | framework = "safetensors" 134 | 135 | # create regex to only include the framework specific weights 136 | ignore_regex = create_artifact_filter(framework) 137 | logger.info(f"Ignore regex pattern for files, which are not downloaded: {', '.join(ignore_regex)}") 138 | 139 | # Download the repository to the workdir and filter out non-framework 140 | # specific weights 141 | snapshot_download( 142 | repo_id=repository_id, 143 | revision=revision, 144 | local_dir=str(target_dir), 145 | local_dir_use_symlinks=False, 146 | ignore_patterns=ignore_regex, 147 | ) 148 | 149 | return target_dir 150 | 151 | 152 | def check_and_register_custom_pipeline_from_directory(model_dir): 153 | """ 154 | Checks if a custom pipeline is available and registers it if so. 
155 | """ 156 | # path to custom handler 157 | custom_module = Path(model_dir).joinpath(HF_DEFAULT_PIPELINE_NAME) 158 | legacy_module = Path(model_dir).joinpath("pipeline.py") 159 | if custom_module.is_file(): 160 | logger.info(f"Found custom pipeline at {custom_module}") 161 | spec = importlib.util.spec_from_file_location(HF_MODULE_NAME, custom_module) 162 | if spec: 163 | # add the whole directory to path for submodlues 164 | sys.path.insert(0, model_dir) 165 | # import custom handler 166 | handler = importlib.util.module_from_spec(spec) 167 | sys.modules[HF_MODULE_NAME] = handler 168 | spec.loader.exec_module(handler) 169 | # init custom handler with model_dir 170 | custom_pipeline = handler.EndpointHandler(model_dir) 171 | 172 | elif legacy_module.is_file(): 173 | logger.warning( 174 | """You are using a legacy custom pipeline. 175 | Please update to the new format. 176 | See documentation for more information.""" 177 | ) 178 | spec = importlib.util.spec_from_file_location("pipeline.PreTrainedPipeline", legacy_module) 179 | if spec: 180 | # add the whole directory to path for submodlues 181 | sys.path.insert(0, model_dir) 182 | # import custom handler 183 | pipeline = importlib.util.module_from_spec(spec) 184 | sys.modules["pipeline.PreTrainedPipeline"] = pipeline 185 | spec.loader.exec_module(pipeline) 186 | # init custom handler with model_dir 187 | custom_pipeline = pipeline.PreTrainedPipeline(model_dir) 188 | else: 189 | logger.info(f"No custom pipeline found at {custom_module}") 190 | custom_pipeline = None 191 | return custom_pipeline 192 | 193 | 194 | def get_device(): 195 | """ 196 | The get device function will return the device for the DL Framework. 197 | """ 198 | gpu = _is_gpu_available() 199 | 200 | if gpu: 201 | return 0 202 | else: 203 | return -1 204 | 205 | 206 | def get_pipeline( 207 | task: Union[str, None], 208 | model_dir: Path, 209 | **kwargs, 210 | ) -> Pipeline: 211 | """ 212 | create pipeline class for a specific task based on local saved model 213 | """ 214 | if task is None: 215 | raise EnvironmentError( 216 | "The task for this model is not set: Please set one: https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined" 217 | ) 218 | 219 | if task == "conversational": 220 | task = "text-generation" 221 | 222 | if is_optimum_neuron_available(): 223 | logger.info("Using device Neuron") 224 | return get_optimum_neuron_pipeline(task=task, model_dir=model_dir) 225 | 226 | device = get_device() 227 | logger.info(f"Using device {'GPU' if device == 0 else 'CPU'}") 228 | 229 | # define tokenizer or feature extractor as kwargs to load it the pipeline 230 | # correctly 231 | if task in { 232 | "automatic-speech-recognition", 233 | "image-segmentation", 234 | "image-classification", 235 | "audio-classification", 236 | "object-detection", 237 | "zero-shot-image-classification", 238 | }: 239 | kwargs["feature_extractor"] = model_dir 240 | elif task not in {"image-text-to-text", "image-to-text", "text-to-image"}: 241 | kwargs["tokenizer"] = model_dir 242 | 243 | if is_sentence_transformers_available() and task in [ 244 | "sentence-similarity", 245 | "sentence-embeddings", 246 | "sentence-ranking", 247 | "text-ranking", 248 | ]: 249 | if task == "text-ranking": 250 | task = "sentence-ranking" 251 | hf_pipeline = get_sentence_transformers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs) 252 | elif is_diffusers_available() and task == "text-to-image": 253 | hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, 
device=device, **kwargs) 254 | else: 255 | hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs) 256 | 257 | if task == "automatic-speech-recognition" and isinstance(hf_pipeline.model, WhisperForConditionalGeneration): 258 | # set chunk length to 30s for whisper to enable long audio files 259 | hf_pipeline._preprocess_params["chunk_length_s"] = 30 260 | hf_pipeline.model.config.forced_decoder_ids = hf_pipeline.tokenizer.get_decoder_prompt_ids( 261 | language="english", task="transcribe" 262 | ) 263 | return hf_pipeline # type: ignore 264 | 265 | 266 | def convert_params_to_int_or_bool(params): 267 | """Converts query params to int or bool if possible""" 268 | for k, v in params.items(): 269 | if v.isnumeric(): 270 | params[k] = int(v) 271 | if v == "false": 272 | params[k] = False 273 | if v == "true": 274 | params[k] = True 275 | return params 276 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Hugging Face Inference Toolkit 4 | 5 | Hugging Face Inference Toolkit is for serving 🤗 Transformers models in containers. This library provides default pre-processing, prediction, and postprocessing for Transformers, diffusers, and Sentence Transformers. It is also possible to define a custom `handler.py` for customization. 
The Toolkit is built to work with the [Hugging Face Hub](https://huggingface.co/models) and is used as the "default" option in [Inference Endpoints](https://ui.endpoints.huggingface.co/).
6 | 
7 | ## 💻 Getting Started with Hugging Face Inference Toolkit
8 | 
9 | - Clone the repository `git clone https://github.com/huggingface/huggingface-inference-toolkit`
10 | - Install the dependencies in dev mode `pip install -e ".[torch,st,diffusers,test,quality]"`
11 | - If you develop on AWS Inferentia2 install with `pip install -e ".[inf2,test,quality]" --upgrade`
12 | - If you develop on Google Cloud install with `pip install -e ".[torch,st,diffusers,google,test,quality]"`
13 | - Unit Testing: `make unit-test`
14 | - Integration testing: `make integ-test`
15 | 
16 | ### Local run
17 | 
18 | ```bash
19 | mkdir tmp2/
20 | HF_MODEL_ID=hf-internal-testing/tiny-random-distilbert HF_MODEL_DIR=tmp2 HF_TASK=text-classification uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000
21 | ```
22 | 
23 | ### Container
24 | 
25 | 1. Build the preferred PyTorch container for either CPU or GPU.
26 | 
27 | _CPU Images_
28 | 
29 | ```bash
30 | make inference-pytorch-cpu
31 | ```
32 | 
33 | _GPU Images_
34 | 
35 | ```bash
36 | make inference-pytorch-gpu
37 | ```
38 | 
39 | 2. Run the container and either set environment variables pointing to the Hub model you want to use or mount a volume with your model into the container.
40 | 
41 | ```bash
42 | docker run -ti -p 5000:5000 -e HF_MODEL_ID=distilbert-base-uncased-distilled-squad -e HF_TASK=question-answering integration-test-pytorch:cpu
43 | docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=nlpconnect/vit-gpt2-image-captioning -e HF_TASK=image-to-text integration-test-pytorch:gpu
44 | docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=echarlaix/tiny-random-stable-diffusion-xl -e HF_TASK=text-to-image integration-test-pytorch:gpu
45 | docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=stabilityai/stable-diffusion-xl-base-1.0 -e HF_TASK=text-to-image integration-test-pytorch:gpu
46 | docker run -ti -p 5000:5000 -e HF_MODEL_DIR=/repository -v $(pwd)/distilbert-base-uncased-emotion:/repository integration-test-pytorch:cpu
47 | ```
48 | 
49 | 3. Send a request. The API schema is the same as that of the [Inference API](https://huggingface.co/docs/api-inference/detailed_parameters).
50 | 
51 | ```bash
52 | curl --request POST \
53 | --url http://localhost:5000 \
54 | --header 'Content-Type: application/json' \
55 | --data '{
56 | "inputs": {
57 | "question": "What is used for inference?",
58 | "context": "My Name is Philipp and I live in Nuremberg. This model is used with sagemaker for inference."
59 | }
60 | }'
61 | ```
62 | 
63 | ### Custom Handler and dependency support
64 | 
65 | The Hugging Face Inference Toolkit allows users to provide custom inference logic through a `handler.py` file located in the model repository.
66 | 
67 | For an example check [philschmid/custom-pipeline-text-classification](https://huggingface.co/philschmid/custom-pipeline-text-classification):
68 | 
69 | ```bash
70 | model.tar.gz/
71 | |- pytorch_model.bin
72 | |- ....
73 | |- handler.py
74 | |- requirements.txt
75 | ```
76 | 
77 | In this example, `pytorch_model.bin` is the model file saved from training, `handler.py` is the custom inference handler, and `requirements.txt` is a requirements file to add additional dependencies.
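As a rough illustration only (the task, the pipeline call, and the payload handling below are assumptions for this sketch, not code taken from the example repository), a minimal `handler.py` could look like this:

```python
# handler.py -- hypothetical sketch of a custom handler, not shipped with this repository
from typing import Any, Dict, List

from transformers import pipeline


class EndpointHandler:
    def __init__(self, model_dir: str, **kwargs):
        # model_dir is the local directory the toolkit downloaded or mounted the model into
        self.pipeline = pipeline("text-classification", model=model_dir)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # data is the deserialized request payload, e.g. {"inputs": "...", "parameters": {...}}
        inputs = data["inputs"]
        parameters = data.get("parameters", {})
        return self.pipeline(inputs, **parameters)
```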
78 | The custom module must define an `EndpointHandler` class: the toolkit instantiates it with the model directory (`EndpointHandler(model_dir)`) and then calls the instance with the deserialized request payload, as in the sketch above.
79 | 
80 | ### Vertex AI Support
81 | 
82 | The Hugging Face Inference Toolkit is also supported on Vertex AI, based on [Custom container requirements for prediction](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements). [Environment variables set by Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables) are automatically detected and used by the toolkit.
83 | 
84 | #### Local run with HF_MODEL_ID and HF_TASK
85 | 
86 | Start the Hugging Face Inference Toolkit with the following environment variables.
87 | 
88 | ```bash
89 | mkdir tmp2/
90 | AIP_MODE=PREDICTION AIP_PORT=8080 AIP_PREDICT_ROUTE=/pred AIP_HEALTH_ROUTE=/h HF_MODEL_DIR=tmp2 HF_MODEL_ID=distilbert/distilbert-base-uncased-finetuned-sst-2-english HF_TASK=text-classification uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 8080
91 | ```
92 | 
93 | Send a request:
94 | 
95 | ```bash
96 | curl --request POST \
97 | --url http://localhost:8080/pred \
98 | --header 'Content-Type: application/json' \
99 | --data '{
100 | "instances": ["I love this product", "I hate this product"],
101 | "parameters": { "top_k": 2 }
102 | }'
103 | ```
104 | 
105 | #### Container run with HF_MODEL_ID and HF_TASK
106 | 
107 | 1. Build the Vertex AI PyTorch container.
108 | 
109 | ```bash
110 | docker build -t vertex -f dockerfiles/pytorch/Dockerfile -t vertex-test-pytorch:gpu .
111 | ```
112 | 
113 | 2. Run the container and either set environment variables pointing to the Hub model you want to use or mount a volume with your model into the container.
114 | 
115 | ```bash
116 | docker run -ti -p 8080:8080 -e AIP_MODE=PREDICTION -e AIP_HTTP_PORT=8080 -e AIP_PREDICT_ROUTE=/pred -e AIP_HEALTH_ROUTE=/h -e HF_MODEL_ID=distilbert/distilbert-base-uncased-finetuned-sst-2-english -e HF_TASK=text-classification vertex-test-pytorch:gpu
117 | ```
118 | 
119 | 3. Send a request:
120 | 
121 | ```bash
122 | curl --request POST \
123 | --url http://localhost:8080/pred \
124 | --header 'Content-Type: application/json' \
125 | --data '{
126 | "instances": ["I love this product", "I hate this product"],
127 | "parameters": { "top_k": 2 }
128 | }'
129 | ```
130 | 
131 | ### AWS Inferentia2 Support
132 | 
133 | The Hugging Face Inference Toolkit supports deploying Hugging Face models on AWS Inferentia2. To deploy a model on Inferentia2 you have three options:
134 | 
135 | - Provide `HF_MODEL_ID`, the repo id of a model on huggingface.co that already contains the compiled model in `.neuron` format, e.g. `optimum/bge-base-en-v1.5-neuronx`
136 | - Provide the `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` environment variables to compile the model on the fly, e.g. `HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128`
137 | - Include a `neuron` dictionary in the [config.json](https://huggingface.co/optimum/tiny_random_bert_neuron/blob/main/config.json) file of the model archive, e.g. `neuron: {"static_batch_size": 1, "static_sequence_length": 128}`
138 | 
139 | The currently supported tasks can be found [here](https://huggingface.co/docs/optimum-neuron/en/package_reference/supported_models). If you plan to deploy an LLM, we recommend taking a look at [Neuronx TGI](https://huggingface.co/blog/text-generation-inference-on-inferentia2), which is purpose-built for LLMs. A short helper for checking whether a repository already ships compiled `.neuron` artifacts is sketched below.
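To decide between the first option and on-the-fly compilation, it can help to check whether the target repository already contains compiled `.neuron` files. The following is only a sketch (the helper name and the printed repo id are illustrative), reusing the `HfApi().model_info(...).siblings` call the toolkit itself uses for weight filtering:

```python
# Illustrative helper, not part of the toolkit: checks whether a Hub repository
# already contains compiled *.neuron artifacts.
from huggingface_hub import HfApi


def has_neuron_artifacts(repo_id: str) -> bool:
    files = HfApi().model_info(repo_id).siblings
    return any(f.rfilename.endswith(".neuron") for f in files)


if __name__ == "__main__":
    # example repository from the list above
    print(has_neuron_artifacts("optimum/bge-base-en-v1.5-neuronx"))
```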
140 | 
141 | #### Local run with HF_MODEL_ID and HF_TASK
142 | 
143 | Start the Hugging Face Inference Toolkit with the following environment variables.
144 | 
145 | _Note: You need to run this on an Inferentia2 instance._
146 | 
147 | - transformers `text-classification` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
148 | 
149 | ```bash
150 | mkdir tmp2/
151 | HF_MODEL_ID="distilbert/distilbert-base-uncased-finetuned-sst-2-english" HF_TASK="text-classification" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 HF_MODEL_DIR=tmp2 uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000
152 | ```
153 | 
154 | - sentence transformers `feature-extraction` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
155 | 
156 | ```bash
157 | HF_MODEL_ID="sentence-transformers/all-MiniLM-L6-v2" HF_TASK="feature-extraction" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 HF_MODEL_DIR=tmp2 uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000
158 | ```
159 | 
160 | Send a request:
161 | 
162 | ```bash
163 | curl --request POST \
164 | --url http://localhost:5000 \
165 | --header 'Content-Type: application/json' \
166 | --data '{
167 | "inputs": "Wow, this is such a great product. I love it!"
168 | }'
169 | ```
170 | 
171 | #### Container run with HF_MODEL_ID and HF_TASK
172 | 
173 | 1. Build the PyTorch Inferentia2 container.
174 | 
175 | ```bash
176 | make inference-pytorch-inf2
177 | ```
178 | 
179 | 2. Run the container and either set environment variables pointing to the Hub model you want to use or mount a volume with your model into the container.
180 | 
181 | ```bash
182 | docker run -ti -p 5000:5000 -e HF_MODEL_ID="distilbert/distilbert-base-uncased-finetuned-sst-2-english" -e HF_TASK="text-classification" -e HF_OPTIMUM_BATCH_SIZE=1 -e HF_OPTIMUM_SEQUENCE_LENGTH=128 --device=/dev/neuron0 integration-test-pytorch:inf2
183 | ```
184 | 
185 | 3. Send a request:
186 | 
187 | ```bash
188 | curl --request POST \
189 | --url http://localhost:5000 \
190 | --header 'Content-Type: application/json' \
191 | --data '{
192 | "inputs": "Wow, this is such a great product. I love it!",
193 | "parameters": { "top_k": 2 }
194 | }'
195 | ```
196 | 
197 | ---
198 | 
199 | ## 🛠️ Environment variables
200 | 
201 | The Hugging Face Inference Toolkit implements various additional environment variables to simplify your deployment experience. A full list of environment variables is given below. All potential environment variables can be found in [const.py](src/huggingface_inference_toolkit/const.py).
202 | 
203 | ### `HF_MODEL_DIR`
204 | 
205 | The `HF_MODEL_DIR` environment variable defines the directory where your model is stored or will be stored.
206 | If `HF_MODEL_ID` is not set, the toolkit expects a model artifact in this directory; set it to the path where you mount your model artifacts.
207 | If `HF_MODEL_ID` is set and the directory `HF_MODEL_DIR` points to is empty, the toolkit will download the model from the Hub to this directory.
208 | 
209 | The default value is `/opt/huggingface/model`.
210 | 
211 | ```bash
212 | HF_MODEL_DIR="/opt/mymodel"
213 | ```
214 | 
215 | ### `HF_TASK`
216 | 
217 | The `HF_TASK` environment variable defines the task for the underlying Transformers pipeline or Sentence Transformers model. 
A full list of tasks can be found in [supported & tested task section](#supported--tested-tasks) 218 | 219 | ```bash 220 | HF_TASK="question-answering" 221 | ``` 222 | 223 | ### `HF_MODEL_ID` 224 | 225 | The `HF_MODEL_ID` environment variable defines the model id, which will be automatically loaded from [huggingface.co/models](https://huggingface.co/models) when starting the container. 226 | 227 | ```bash 228 | HF_MODEL_ID="distilbert-base-uncased-finetuned-sst-2-english" 229 | ``` 230 | 231 | ### `HF_REVISION` 232 | 233 | The `HF_REVISION` is an extension to `HF_MODEL_ID` and allows you to define/pin a revision of the model to make sure you always load the same model on your SageMaker Endpoint. 234 | 235 | ```bash 236 | HF_REVISION="03b4d196c19d0a73c7e0322684e97db1ec397613" 237 | ``` 238 | 239 | ### `HF_HUB_TOKEN` 240 | 241 | The `HF_HUB_TOKEN` environment variable defines your Hugging Face authorization token. The `HF_HUB_TOKEN` is used as a HTTP bearer authorization for remote files, like private models. You can find your token at your [settings page](https://huggingface.co/settings/token). 242 | 243 | ```bash 244 | HF_HUB_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX" 245 | ``` 246 | 247 | ### `HF_TRUST_REMOTE_CODE` 248 | 249 | The `HF_TRUST_REMOTE_CODE` environment variable defines whether to trust remote code. This flag is already used for community defined inference code, and is therefore quite representative of the level of confidence you are giving the model providers when loading models from the Hugging Face Hub. The default value is `"0"`; set it to `"1"` to trust remote code. 250 | 251 | ```bash 252 | HF_TRUST_REMOTE_CODE="0" 253 | ``` 254 | 255 | ### `HF_FRAMEWORK` 256 | 257 | The `HF_FRAMEWORK` environment variable defines the base deep learning framework used in the container. This is important when loading large models from the Hugging Face Hub to avoid extra file downloads. 258 | 259 | ```bash 260 | HF_FRAMEWORK="pytorch" 261 | ``` 262 | 263 | #### `HF_OPTIMUM_BATCH_SIZE` 264 | 265 | The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size, which is used when compiling the model to Neuron. The default value is `1`. Not required when model is already converted. 266 | 267 | ```bash 268 | HF_OPTIMUM_BATCH_SIZE="1" 269 | ``` 270 | 271 | #### `HF_OPTIMUM_SEQUENCE_LENGTH` 272 | 273 | The `HF_OPTIMUM_SEQUENCE_LENGTH` environment variable defines the sequence length, which is used when compiling the model to Neuron. There is no default value. Not required when model is already converted. 274 | 275 | ```bash 276 | HF_OPTIMUM_SEQUENCE_LENGTH="128" 277 | ``` 278 | 279 | --- 280 | 281 | ## ⚙ Supported Front-Ends 282 | 283 | - [x] Starlette (HF Endpoints) 284 | - [x] Starlette (Vertex AI) 285 | - [ ] Starlette (Azure ML) 286 | - [ ] Starlette (SageMaker) 287 | 288 | ## 📜 License 289 | 290 | This project is licensed under the Apache-2.0 License. 
291 | 292 | -------------------------------------------------------------------------------- /tests/integ/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tests.integ.utils import ( 4 | validate_automatic_speech_recognition, 5 | validate_classification, 6 | validate_conversational, 7 | validate_custom, 8 | validate_feature_extraction, 9 | validate_fill_mask, 10 | validate_image_text_to_text, 11 | validate_ner, 12 | validate_object_detection, 13 | validate_question_answering, 14 | validate_summarization, 15 | validate_text2text_generation, 16 | validate_text_generation, 17 | validate_text_to_image, 18 | validate_translation, 19 | validate_zero_shot_classification, 20 | ) 21 | 22 | task2model = { 23 | "text-classification": { 24 | "pytorch": "hf-internal-testing/tiny-random-distilbert", 25 | "tensorflow": "hf-internal-testing/tiny-random-distilbert", 26 | }, 27 | "zero-shot-classification": { 28 | "pytorch": "hf-internal-testing/tiny-random-bart", 29 | "tensorflow": "typeform/distilbert-base-uncased-mnli", 30 | }, 31 | "feature-extraction": { 32 | "pytorch": "hf-internal-testing/tiny-random-bert", 33 | "tensorflow": "hf-internal-testing/tiny-random-bert", 34 | }, 35 | "token-classification": { 36 | "pytorch": "hf-internal-testing/tiny-random-roberta", 37 | "tensorflow": "hf-internal-testing/tiny-random-roberta", 38 | }, 39 | "question-answering": { 40 | "pytorch": "hf-internal-testing/tiny-random-electra", 41 | "tensorflow": "hf-internal-testing/tiny-random-electra", 42 | }, 43 | "fill-mask": { 44 | "pytorch": "hf-internal-testing/tiny-random-bert", 45 | "tensorflow": "hf-internal-testing/tiny-random-bert", 46 | }, 47 | "summarization": { 48 | "pytorch": "hf-internal-testing/tiny-random-bart", 49 | "tensorflow": "hf-internal-testing/tiny-random-bart", 50 | }, 51 | "translation_xx_to_yy": { 52 | "pytorch": "hf-internal-testing/tiny-random-t5", 53 | "tensorflow": "hf-internal-testing/tiny-random-marian", 54 | }, 55 | "text2text-generation": { 56 | "pytorch": "hf-internal-testing/tiny-random-t5", 57 | "tensorflow": "hf-internal-testing/tiny-random-bart", 58 | }, 59 | "text-generation": { 60 | "pytorch": "hf-internal-testing/tiny-random-gpt2", 61 | "tensorflow": "hf-internal-testing/tiny-random-gpt2", 62 | }, 63 | "image-classification": { 64 | "pytorch": "hf-internal-testing/tiny-random-vit", 65 | "tensorflow": "hf-internal-testing/tiny-random-vit", 66 | }, 67 | "automatic-speech-recognition": { 68 | "pytorch": "hf-internal-testing/tiny-random-Wav2Vec2Model", 69 | "tensorflow": None, 70 | }, 71 | "audio-classification": { 72 | "pytorch": "hf-internal-testing/tiny-random-WavLMModel", 73 | "tensorflow": None, 74 | }, 75 | "object-detection": { 76 | "pytorch": "hustvl/yolos-tiny", 77 | "tensorflow": None, 78 | }, 79 | "zero-shot-image-classification": { 80 | "pytorch": "hf-internal-testing/tiny-random-clip-zero-shot-image-classification", 81 | "tensorflow": "hf-internal-testing/tiny-random-clip-zero-shot-image-classification", 82 | }, 83 | "conversational": { 84 | # "pytorch": "hf-internal-testing/tiny-random-blenderbot-small", 85 | "pytorch": "microsoft/DialoGPT-small", 86 | "tensorflow": None, 87 | }, 88 | "sentence-similarity": { 89 | "pytorch": "sentence-transformers/all-MiniLM-L6-v2", 90 | "tensorflow": None, 91 | }, 92 | "sentence-embeddings": { 93 | "pytorch": "sentence-transformers/all-MiniLM-L6-v2", 94 | "tensorflow": None, 95 | }, 96 | "sentence-ranking": { 97 | "pytorch": "cross-encoder/ms-marco-MiniLM-L-6-v2", 98 | 
"tensorflow": None, 99 | }, 100 | "text-to-image": { 101 | "pytorch": "hf-internal-testing/tiny-stable-diffusion-torch", 102 | "tensorflow": None, 103 | }, 104 | "table-question-answering": { 105 | "pytorch": "philschmid/tapex-tiny", 106 | "tensorflow": None, 107 | }, 108 | "image-segmentation": { 109 | "pytorch": "hf-internal-testing/tiny-random-beit-pipeline", 110 | "tensorflow": None, 111 | }, 112 | "image-text-to-text": { 113 | "pytorch": "Salesforce/blip-image-captioning-base", 114 | "tensorflow": None, 115 | }, 116 | } 117 | 118 | 119 | task2input = { 120 | "text-classification": {"inputs": "I love you. I like you"}, 121 | "zero-shot-classification": { 122 | "inputs": "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!", 123 | "parameters": {"candidate_labels": ["refund", "legal", "faq"]}, 124 | }, 125 | "feature-extraction": {"inputs": "What is the best book."}, 126 | "token-classification": {"inputs": "My name is Wolfgang and I live in Berlin"}, 127 | "question-answering": { 128 | "inputs": { 129 | "question": "What is used for inference?", 130 | "context": "My Name is Philipp and I live in Nuremberg. This model is used with sagemaker for inference.", 131 | } 132 | }, 133 | "fill-mask": {"inputs": "Paris is the [MASK] of France."}, 134 | "summarization": { 135 | "inputs": "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct." 136 | }, 137 | "translation_xx_to_yy": {"inputs": "My name is Sarah and I live in London"}, 138 | "text2text-generation": { 139 | "inputs": "question: What is 42 context: 42 is the answer to life, the universe and everything." 
140 | }, 141 | "text-generation": {"inputs": "My name is philipp and I am"}, 142 | "image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(), 143 | "zero-shot-image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(), 144 | "object-detection": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(), 145 | "image-segmentation": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(), 146 | "automatic-speech-recognition": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(), 147 | "audio-classification": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(), 148 | "table-question-answering": { 149 | "inputs": { 150 | "query": "How many stars does the transformers repository have?", 151 | "table": { 152 | "Repository": ["Transformers", "Datasets", "Tokenizers"], 153 | "Stars": ["36542", "4512", "3934"], 154 | "Contributors": ["651", "77", "34"], 155 | "Programming language": ["Python", "Python", "Rust, Python and NodeJS"], 156 | }, 157 | } 158 | }, 159 | "conversational": { 160 | "inputs": [ 161 | {"role": "user", "content": "Which movie is the best ?"}, 162 | ] 163 | }, 164 | "sentence-similarity": { 165 | "inputs": { 166 | "source_sentence": "Lets create an embedding", 167 | "sentences": ["Lets create an embedding"], 168 | } 169 | }, 170 | "sentence-embeddings": {"inputs": "Lets create an embedding"}, 171 | "sentence-ranking": {"inputs": ["Lets create an embedding", "Lets create an embedding"]}, 172 | "text-to-image": {"inputs": "a man on a horse jumps over a broken down airplane."}, 173 | "custom": {"inputs": "this is a test"}, 174 | "image-text-to-text": { 175 | "inputs": { 176 | "images": "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", 177 | "text": "A photo of", 178 | } 179 | }, 180 | } 181 | 182 | task2output = { 183 | "text-classification": [{"label": "POSITIVE", "score": 0.01}], 184 | "zero-shot-classification": { 185 | "sequence": "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!", 186 | "labels": ["refund", "faq", "legal"], 187 | "scores": [0.96, 0.027, 0.008], 188 | }, 189 | "token-classification": [ 190 | { 191 | "word": "Wolfgang", 192 | "score": 0.99, 193 | "entity": "I-PER", 194 | "index": 4, 195 | "start": 11, 196 | "end": 19, 197 | }, 198 | { 199 | "word": "Berlin", 200 | "score": 0.99, 201 | "entity": "I-LOC", 202 | "index": 9, 203 | "start": 34, 204 | "end": 40, 205 | }, 206 | ], 207 | "question-answering": { 208 | "score": 0.99, 209 | "start": 68, 210 | "end": 77, 211 | "answer": "sagemaker", 212 | }, 213 | "summarization": [{"summary_text": " The A The The ANew York City has been installed in the US."}], 214 | "translation_xx_to_yy": [{"translation_text": "Mein Name ist Sarah und ich lebe in London"}], 215 | "text2text-generation": [{"generated_text": "42 is the answer to life, the universe and everything"}], 216 | "feature-extraction": None, 217 | "fill-mask": None, 218 | "text-generation": None, 219 | "image-classification": [ 220 | {"score": 0.8858247399330139, "label": "tiger, Panthera tigris"}, 221 | {"score": 0.10940514504909515, "label": "tiger cat"}, 222 | { 223 | "score": 0.0006216464680619538, 224 | "label": "jaguar, panther, Panthera onca, Felis onca", 225 | }, 226 | {"score": 0.0004262699221726507, "label": "dhole, Cuon alpinus"}, 227 | { 228 | "score": 
0.00030842673731967807, 229 | "label": "lion, king of beasts, Panthera leo", 230 | }, 231 | ], 232 | "zero-shot-image-classification": [ 233 | {"score": 0.8858247399330139, "label": "tiger, Panthera tigris"}, 234 | {"score": 0.10940514504909515, "label": "tiger cat"}, 235 | { 236 | "score": 0.0006216464680619538, 237 | "label": "jaguar, panther, Panthera onca, Felis onca", 238 | }, 239 | {"score": 0.0004262699221726507, "label": "dhole, Cuon alpinus"}, 240 | { 241 | "score": 0.00030842673731967807, 242 | "label": "lion, king of beasts, Panthera leo", 243 | }, 244 | ], 245 | "automatic-speech-recognition": { 246 | "text": "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP OAUDIENCES IN DROFTY SCHOOL ROOMS DAY AFTER DAY FOR A FORT NIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS" 247 | }, 248 | "audio-classification": [ 249 | {"label": "no", "score": 0.5052680969238281}, 250 | {"label": "yes", "score": 0.49473199248313904}, 251 | ], 252 | "object-detection": [{"score": 0.9143241047859192, "label": "cat", "box": {}}], 253 | "image-segmentation": [{"score": 0.9143241047859192, "label": "cat", "mask": {}}], 254 | "table-question-answering": {"answer": "36542"}, 255 | "conversational": [ 256 | {"role": "user", "content": "Which movie is the best ?"}, 257 | {"role": "assistant", "content": "It's Die Hard for sure."}, 258 | ], 259 | "sentence-similarity": {"similarities": ""}, 260 | "sentence-embeddings": {"embeddings": ""}, 261 | "sentence-ranking": {"scores": ""}, 262 | "text-to-image": bytes, 263 | "image-text-to-text": [{"input_text": "A photo of", "generated_text": "..."}], 264 | "custom": {"inputs": "this is a test"}, 265 | } 266 | 267 | 268 | task2validation = { 269 | "text-classification": validate_classification, 270 | "zero-shot-classification": validate_zero_shot_classification, 271 | "zero-shot-image-classification": validate_zero_shot_classification, 272 | "feature-extraction": validate_feature_extraction, 273 | "token-classification": validate_ner, 274 | "question-answering": validate_question_answering, 275 | "fill-mask": validate_fill_mask, 276 | "summarization": validate_summarization, 277 | "translation_xx_to_yy": validate_translation, 278 | "text2text-generation": validate_text2text_generation, 279 | "text-generation": validate_text_generation, 280 | "image-classification": validate_classification, 281 | "automatic-speech-recognition": validate_automatic_speech_recognition, 282 | "audio-classification": validate_classification, 283 | "object-detection": validate_object_detection, 284 | "image-segmentation": validate_object_detection, 285 | "table-question-answering": validate_zero_shot_classification, 286 | "conversational": validate_conversational, 287 | "sentence-similarity": validate_zero_shot_classification, 288 | "sentence-embeddings": validate_zero_shot_classification, 289 | "sentence-ranking": validate_zero_shot_classification, 290 | "text-to-image": validate_text_to_image, 291 | "image-text-to-text": validate_image_text_to_text, 292 | "custom": validate_custom, 293 | } 294 | -------------------------------------------------------------------------------- /tests/integ/helpers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import tempfile 4 | import time 5 | import traceback 6 | 7 | import docker 8 | import pytest 9 | import requests 10 | from docker import DockerClient 11 | from transformers.testing_utils import 
12 | 
13 | from huggingface_inference_toolkit.utils import _load_repository_from_hf
14 | from tests.integ.config import task2input, task2model, task2output, task2validation
15 | 
16 | IS_GPU = _run_slow_tests
17 | DEVICE = "gpu" if IS_GPU else "cpu"
18 | 
19 | client = docker.DockerClient(base_url="unix://var/run/docker.sock")
20 | 
21 | 
22 | def make_sure_other_containers_are_stopped(client: DockerClient, container_name: str):
23 |     try:
24 |         previous = client.containers.get(container_name)
25 |         previous.stop()
26 |         previous.remove()
27 |     except Exception:
28 |         return None
29 | 
30 | 
31 | # @tenacity.retry(
32 | #     retry = tenacity.retry_if_exception(ValueError),
33 | #     stop = tenacity.stop_after_attempt(10),
34 | #     reraise = True
35 | # )
36 | def wait_for_container_to_be_ready(base_url, time_between_retries=3, max_retries=30):
37 |     retries = 0
38 |     error = None
39 | 
40 |     while retries < max_retries:
41 |         time.sleep(time_between_retries)
42 |         try:
43 |             response = requests.get(f"{base_url}/health")
44 |             if response.status_code == 200:
45 |                 logging.info("Container ready!")
46 |                 return True
47 |             else:
48 |                 raise ConnectionError(f"Couldn't start container, Error: {response.status_code}")
49 |         except Exception as exception:
50 |             error = exception
51 |             logging.warning(f"Container at {base_url} not ready, trying again...")
52 |         retries += 1
53 | 
54 |     logging.error(f"Unable to start container: {str(error)}")
55 |     raise error
56 | 
57 | 
58 | def verify_task(
59 |     # container: DockerClient,
60 |     task: str,
61 |     port: int = 5000,
62 | ):
63 |     BASE_URL = f"http://localhost:{port}"
64 |     logging.info(f"Base URL: {BASE_URL}")
65 |     logging.info(f"Port: {port}")
66 |     input = task2input[task]
67 | 
68 |     try:
69 |         # health check
70 |         wait_for_container_to_be_ready(BASE_URL)
71 |         if (
72 |             task == "image-classification"
73 |             or task == "object-detection"
74 |             or task == "image-segmentation"
75 |             or task == "zero-shot-image-classification"
76 |         ):
77 |             prediction = requests.post(
78 |                 f"{BASE_URL}",
79 |                 data=task2input[task],
80 |                 headers={"content-type": "image/x-image"},
81 |             ).json()
82 |         elif task == "automatic-speech-recognition" or task == "audio-classification":
83 |             prediction = requests.post(
84 |                 f"{BASE_URL}",
85 |                 data=task2input[task],
86 |                 headers={"content-type": "audio/x-audio"},
87 |             ).json()
88 |         elif task == "text-to-image":
89 |             prediction = requests.post(f"{BASE_URL}", json=input, headers={"accept": "image/png"}).content
90 |         else:
91 |             prediction = requests.post(f"{BASE_URL}", json=input).json()
92 | 
93 |         logging.info(f"Input: {input}")
94 |         logging.info(f"Prediction: {prediction}")
95 |         logging.info(f"Snapshot: {task2output[task]}")
96 | 
97 |         if task == "conversational":
98 |             for message in prediction:
99 |                 assert "error" not in message.keys()
100 |         else:
101 |             assert task2validation[task](result=prediction, snapshot=task2output[task])
102 |     except Exception as exception:
103 |         logging.error(f"Base URL: {BASE_URL}")
104 |         logging.error(f"Task: {task}")
105 |         logging.error(f"Input: {input}")
106 |         logging.error(f"Error: {str(exception)}")
107 |         logging.error(f"Stack: {traceback.format_exc()}")
108 |         raise exception
109 | 
110 | 
111 | @require_torch
112 | @pytest.mark.parametrize(
113 |     "task",
114 |     [
115 |         # transformers
116 |         # TODO: "visual-question-answering" and "zero-shot-image-classification" not supported yet due to multimodality input
117 |         "text-classification",
118 |         "zero-shot-classification",
119 |         "token-classification",
120 |         "question-answering",
"question-answering", 121 | "fill-mask", 122 | "summarization", 123 | "translation_xx_to_yy", 124 | "text2text-generation", 125 | "text-generation", 126 | "feature-extraction", 127 | "image-classification", 128 | "automatic-speech-recognition", 129 | "audio-classification", 130 | "object-detection", 131 | "image-segmentation", 132 | "table-question-answering", 133 | "conversational", 134 | "image-text-to-text", 135 | # sentence-transformers 136 | "sentence-similarity", 137 | "sentence-embeddings", 138 | "sentence-ranking", 139 | # diffusers 140 | "text-to-image", 141 | ], 142 | ) 143 | def test_pt_container_remote_model(task: str) -> None: 144 | container_name = f"integration-test-{task}" 145 | container_image = f"starlette-transformers:{DEVICE}" 146 | framework = "pytorch" 147 | model = task2model[task][framework] 148 | port = random.randint(5000, 6000) 149 | device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] 150 | 151 | make_sure_other_containers_are_stopped(client, container_name) 152 | container = client.containers.run( 153 | container_image, 154 | name=container_name, 155 | ports={"5000": port}, 156 | environment={"HF_MODEL_ID": model, "HF_TASK": task}, 157 | detach=True, 158 | # GPU 159 | device_requests=device_request, 160 | ) 161 | time.sleep(5) 162 | 163 | verify_task(task=task, port=port) 164 | container.stop() 165 | container.remove() 166 | 167 | 168 | @require_torch 169 | @pytest.mark.parametrize( 170 | "task", 171 | [ 172 | # transformers 173 | # TODO: "visual-question-answering" and "zero-shot-image-classification" not supported yet due to multimodality input 174 | "text-classification", 175 | "zero-shot-classification", 176 | "token-classification", 177 | "question-answering", 178 | "fill-mask", 179 | "summarization", 180 | "translation_xx_to_yy", 181 | "text2text-generation", 182 | "text-generation", 183 | "feature-extraction", 184 | "image-classification", 185 | "automatic-speech-recognition", 186 | "audio-classification", 187 | "object-detection", 188 | "image-segmentation", 189 | "table-question-answering", 190 | "conversational", 191 | "image-text-to-text", 192 | # sentence-transformers 193 | "sentence-similarity", 194 | "sentence-embeddings", 195 | "sentence-ranking", 196 | # diffusers 197 | "text-to-image", 198 | ], 199 | ) 200 | def test_pt_container_local_model(task: str) -> None: 201 | container_name = f"integration-test-{task}" 202 | container_image = f"starlette-transformers:{DEVICE}" 203 | framework = "pytorch" 204 | model = task2model[task][framework] 205 | port = random.randint(5000, 6000) 206 | device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] 207 | make_sure_other_containers_are_stopped(client, container_name) 208 | with tempfile.TemporaryDirectory() as tmpdirname: 209 | # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py 210 | _load_repository_from_hf(model, tmpdirname, framework="pytorch") 211 | container = client.containers.run( 212 | container_image, 213 | name=container_name, 214 | ports={"5000": port}, 215 | environment={"HF_MODEL_DIR": "/opt/huggingface/model", "HF_TASK": task}, 216 | volumes={tmpdirname: {"bind": "/opt/huggingface/model", "mode": "ro"}}, 217 | detach=True, 218 | # GPU 219 | device_requests=device_request, 220 | ) 221 | # time.sleep(5) 222 | verify_task(container, task, port) 223 | container.stop() 224 | container.remove() 225 | 226 | 227 | @require_torch 228 | @pytest.mark.parametrize( 229 | "repository_id", 230 | 
["philschmid/custom-handler-test", "philschmid/custom-handler-distilbert"], 231 | ) 232 | def test_pt_container_custom_handler(repository_id) -> None: 233 | container_name = "integration-test-custom" 234 | container_image = f"starlette-transformers:{DEVICE}" 235 | device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] 236 | port = random.randint(5000, 6000) 237 | 238 | make_sure_other_containers_are_stopped(client, container_name) 239 | with tempfile.TemporaryDirectory() as tmpdirname: 240 | # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py 241 | _storage_dir = _load_repository_from_hf(repository_id, tmpdirname) 242 | container = client.containers.run( 243 | container_image, 244 | name=container_name, 245 | ports={"5000": port}, 246 | environment={ 247 | "HF_MODEL_DIR": tmpdirname, 248 | }, 249 | volumes={tmpdirname: {"bind": tmpdirname, "mode": "ro"}}, 250 | detach=True, 251 | # GPU 252 | device_requests=device_request, 253 | ) 254 | BASE_URL = f"http://localhost:{port}" 255 | wait_for_container_to_be_ready(BASE_URL) 256 | payload = {"inputs": "this is a test"} 257 | prediction = requests.post(f"{BASE_URL}", json=payload).json() 258 | assert prediction == payload 259 | # time.sleep(5) 260 | container.stop() 261 | container.remove() 262 | 263 | 264 | @require_torch 265 | @pytest.mark.parametrize( 266 | "repository_id", 267 | ["philschmid/custom-pipeline-text-classification"], 268 | ) 269 | def test_pt_container_legacy_custom_pipeline(repository_id: str) -> None: 270 | container_name = "integration-test-custom" 271 | container_image = f"starlette-transformers:{DEVICE}" 272 | device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] 273 | port = random.randint(5000, 6000) 274 | 275 | make_sure_other_containers_are_stopped(client, container_name) 276 | with tempfile.TemporaryDirectory() as tmpdirname: 277 | # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py 278 | _storage_dir = _load_repository_from_hf(repository_id, tmpdirname) 279 | container = client.containers.run( 280 | container_image, 281 | name=container_name, 282 | ports={"5000": port}, 283 | environment={ 284 | "HF_MODEL_DIR": tmpdirname, 285 | }, 286 | volumes={tmpdirname: {"bind": tmpdirname, "mode": "ro"}}, 287 | detach=True, 288 | # GPU 289 | device_requests=device_request, 290 | ) 291 | BASE_URL = f"http://localhost:{port}" 292 | wait_for_container_to_be_ready(BASE_URL) 293 | payload = {"inputs": "this is a test"} 294 | prediction = requests.post(f"{BASE_URL}", json=payload).json() 295 | assert prediction == payload 296 | # time.sleep(5) 297 | container.stop() 298 | container.remove() 299 | 300 | 301 | @require_tf 302 | @pytest.mark.parametrize( 303 | "task", 304 | [ 305 | "text-classification", 306 | "zero-shot-classification", 307 | "token-classification", 308 | "question-answering", 309 | "fill-mask", 310 | "summarization", 311 | "translation_xx_to_yy", 312 | "text2text-generation", 313 | "text-generation", 314 | "feature-extraction", 315 | "image-classification", 316 | "automatic-speech-recognition", 317 | "audio-classification", 318 | "object-detection", 319 | "image-segmentation", 320 | "table-question-answering", 321 | "conversational", 322 | # TODO currently not supported due to multimodality input 323 | # "visual-question-answering", 324 | # "zero-shot-image-classification", 325 | "sentence-similarity", 326 | "sentence-embeddings", 327 | "sentence-ranking", 328 | ], 329 | ) 330 | def 
331 |     container_name = f"integration-test-{task}"
332 |     container_image = f"starlette-transformers:{DEVICE}"
333 |     framework = "tensorflow"
334 |     model = task2model[task][framework]
335 |     device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
336 |     if model is None:
337 |         pytest.skip("no supported TF model")
338 |     port = random.randint(5000, 6000)
339 |     make_sure_other_containers_are_stopped(client, container_name)
340 |     container = client.containers.run(
341 |         container_image,
342 |         name=container_name,
343 |         ports={"5000": port},
344 |         environment={"HF_MODEL_ID": model, "HF_TASK": task},
345 |         detach=True,
346 |         # GPU
347 |         device_requests=device_request,
348 |     )
349 |     # time.sleep(5)
350 |     verify_task(task=task, port=port)
351 |     container.stop()
352 |     container.remove()
353 | 
354 | 
355 | @require_tf
356 | @pytest.mark.parametrize(
357 |     "task",
358 |     [
359 |         "text-classification",
360 |         "zero-shot-classification",
361 |         "token-classification",
362 |         "question-answering",
363 |         "fill-mask",
364 |         "summarization",
365 |         "translation_xx_to_yy",
366 |         "text2text-generation",
367 |         "text-generation",
368 |         "feature-extraction",
369 |         "image-classification",
370 |         "automatic-speech-recognition",
371 |         "audio-classification",
372 |         "object-detection",
373 |         "image-segmentation",
374 |         "table-question-answering",
375 |         "conversational",
376 |         # TODO currently not supported due to multimodality input
377 |         # "visual-question-answering",
378 |         # "zero-shot-image-classification",
379 |         "sentence-similarity",
380 |         "sentence-embeddings",
381 |         "sentence-ranking",
382 |     ],
383 | )
384 | def test_tf_container_local_model(task) -> None:
385 |     container_name = f"integration-test-{task}"
386 |     container_image = f"starlette-transformers:{DEVICE}"
387 |     framework = "tensorflow"
388 |     model = task2model[task][framework]
389 |     device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else []
390 |     if model is None:
391 |         pytest.skip("no supported TF model")
392 |     port = random.randint(5000, 6000)
393 |     make_sure_other_containers_are_stopped(client, container_name)
394 |     with tempfile.TemporaryDirectory() as tmpdirname:
395 |         # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py
396 |         _storage_dir = _load_repository_from_hf(model, tmpdirname, framework=framework)
397 |         container = client.containers.run(
398 |             container_image,
399 |             name=container_name,
400 |             ports={"5000": port},
401 |             environment={"HF_MODEL_DIR": "/opt/huggingface/model", "HF_TASK": task},
402 |             volumes={tmpdirname: {"bind": "/opt/huggingface/model", "mode": "ro"}},
403 |             detach=True,
404 |             # GPU
405 |             device_requests=device_request,
406 |         )
407 |         # time.sleep(5)
408 |         verify_task(task=task, port=port)
409 |         container.stop()
410 |         container.remove()
411 | 
412 | 
413 | # @require_tf
414 | # @pytest.mark.parametrize(
415 | #     "repository_id",
416 | #     ["philschmid/custom-pipeline-text-classification"],
417 | # )
418 | # def test_tf_cpu_container_custom_pipeline(repository_id) -> None:
419 | #     container_name = "integration-test-custom"
420 | #     container_image = "starlette-transformers:cpu"
421 | #     make_sure_other_containers_are_stopped(client, container_name)
422 | #     with tempfile.TemporaryDirectory() as tmpdirname:
423 | #         # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py
424 | #         storage_dir = _load_repository_from_hf("philschmid/custom-pipeline-text-classification", tmpdirname)
425 | #         container = client.containers.run(
426 | #             container_image,
427 | #             name=container_name,
428 | #             ports={"5000": "5000"},
429 | #             environment={
430 | #                 "HF_MODEL_DIR": tmpdirname,
431 | #             },
432 | #             volumes={tmpdirname: {"bind": tmpdirname, "mode": "ro"}},
433 | #             detach=True,
434 | #             # GPU
435 | #             # device_requests=[docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])]
436 | #         )
437 | #         BASE_URL = "http://localhost:5000"
438 | #         wait_for_container_to_be_ready(BASE_URL)
439 | #         payload = {"inputs": "this is a test"}
440 | #         prediction = requests.post(f"{BASE_URL}", json=payload).json()
441 | #         assert prediction == payload
442 | #         # time.sleep(5)
443 | #         container.stop()
444 | #         container.remove()
445 | 
--------------------------------------------------------------------------------