├── text-generation-inference ├── integration-tests │ ├── pytest.ini │ ├── requirements.txt │ └── test_model.py ├── server │ ├── build-requirements.txt │ ├── text_generation_server │ │ ├── version.py │ │ ├── jetstream_pt_support │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── gemma_model_hf.py │ │ │ │ ├── mixtral_model_hf.py │ │ │ │ └── llama_model_exportable_hf.py │ │ │ ├── __init__.py │ │ │ ├── compatibility.py │ │ │ └── logits_process.py │ │ ├── interceptor.py │ │ ├── auto_generator.py │ │ ├── generator_base.py │ │ ├── server.py │ │ └── cli.py │ ├── pyproject.toml │ └── Makefile ├── tests │ ├── pytest.ini │ ├── helpers.py │ ├── test_prefill_truncate.py │ ├── conftest.py │ ├── test_decode.py │ ├── test_warmup.py │ ├── test_decode_jetstream_quant.py │ ├── decode_tests_utils.py │ ├── test_generator_slot.py │ └── test_decode_jetstream.py ├── Cargo.toml ├── docker │ ├── entrypoint.sh │ └── Dockerfile └── README.md ├── docs ├── scripts │ ├── examples_list.yml │ └── auto-generate-examples.py └── source │ ├── supported-architectures.mdx │ ├── howto │ ├── training.mdx │ ├── installation_inside_a_container.mdx │ ├── more_examples.mdx │ ├── gcloud_cli.mdx │ ├── serving.mdx │ ├── advanced-tgi-serving.mdx │ └── deploy_instance_on_ie.mdx │ ├── conceptual_guides │ ├── difference_between_jetstream_and_xla.mdx │ └── tpu_hardware_support.mdx │ ├── optimum_container.mdx │ ├── installation.mdx │ ├── _toctree.yml │ ├── contributing.mdx │ ├── reference │ ├── fsdp_v2.mdx │ └── tgi_advanced_options.mdx │ ├── tutorials │ ├── tpu_setup.mdx │ ├── training_on_tpu.mdx │ └── inference_on_tpu.mdx │ └── index.mdx ├── examples ├── README.md └── text-generation │ └── generation.py ├── requirements.txt ├── .github ├── workflows │ ├── secrets-leak.yml │ ├── upload_pr_documentation.yml │ ├── pypi-release.yaml │ ├── test-pytorch-xla-tpu.yml │ ├── test-pytorch-xla-tpu-tgi.yml │ ├── test-pytorch-xla-tpu-tgi-jetstream.yml │ ├── test-pytorch-xla-tpu-tgi-integration.yml │ ├── check_code_quality.yml │ ├── test-pytorch-xla-tpu-tgi-nightly.yml │ ├── doc-pr-build.yml │ ├── doc-build.yml │ ├── test-pytorch-xla-tpu-tgi-nightly-jetstream.yml │ └── tpu-tgi-release.yml └── pull_request_template.md ├── setup.cfg ├── optimum └── tpu │ ├── jetstream_pt_support.py │ ├── xla_logger.py │ ├── version.py │ ├── generation │ ├── __init__.py │ └── logits_process.py │ ├── __init__.py │ ├── model.py │ ├── xla_mp_comm.py │ ├── static_cache_xla.py │ ├── modeling.py │ ├── fsdp_v2.py │ ├── cli.py │ └── distributed_model.py ├── MANIFEST.in ├── tests ├── conftest.py └── test_distributed_model.py ├── .gitignore ├── pyproject.toml ├── README.md └── Makefile /text-generation-inference/integration-tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | asyncio_mode = auto 3 | -------------------------------------------------------------------------------- /text-generation-inference/server/build-requirements.txt: -------------------------------------------------------------------------------- 1 | build 2 | grpcio-tools==1.53.0 3 | mypy-protobuf -------------------------------------------------------------------------------- /docs/scripts/examples_list.yml: -------------------------------------------------------------------------------- 1 | - local: howto/gemma_tuning 2 | title: Gemma Fine-Tuning Example 3 | - local: howto/llama_tuning 4 | title: Llama Fine-Tuning Example -------------------------------------------------------------------------------- /examples/README.md: 
-------------------------------------------------------------------------------- 1 | # 🤗 Transformers example scripts on Google TPU with Optimum TPU 2 | 3 | These examples show how to run an inference on Google TPU using Optimum TPU. 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This is not a complete list of dependencies, but it allows to install torch without CUDA support 2 | --index-url https://download.pytorch.org/whl/cpu 3 | torch==2.5.1 4 | -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/version.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | 3 | 4 | __version__ = "0.2.3.dev0" 5 | VERSION = parse_version(__version__) 6 | -------------------------------------------------------------------------------- /text-generation-inference/tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | jetstream: mark a test as a test that uses jetstream backend 4 | torch_xla: mark a test as a test that uses torch_xla backend 5 | -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gemma_model_hf import GemmaModelHf as GemmaModel 2 | from .llama_model_exportable_hf import TransformerHf as LlamaModel 3 | from .mixtral_model_hf import MixtralModelHf as MixtralModel 4 | -------------------------------------------------------------------------------- /.github/workflows/secrets-leak.yml: -------------------------------------------------------------------------------- 1 | name: Secret Leaks 2 | 3 | on: push 4 | 5 | permissions: 6 | contents: read 7 | 8 | jobs: 9 | trufflehog: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v4 14 | with: 15 | fetch-depth: 0 16 | - name: Secret Scanning 17 | uses: trufflesecurity/trufflehog@main -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = optimum.tpu 7 | line_length = 119 8 | lines_after_imports = 2 9 | multi_line_output = 3 10 | use_parentheses = True 11 | 12 | [flake8] 13 | ignore = E203, E501, E741, W503, W605 14 | max-line-length = 119 15 | 16 | [tool:pytest] 17 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS 18 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: optimum-tpu 14 | secrets: 15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} 17 | 
-------------------------------------------------------------------------------- /optimum/tpu/jetstream_pt_support.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def jetstream_pt_available() -> bool: 5 | """Check if the necessary imports to use jetstream_pt are available. 6 | """ 7 | try: 8 | # Jetstream Pytorch is enabled by default; it can be disabled with an ENV variable. 9 | jetstream_pt_disabled = os.environ.get("JETSTREAM_PT_DISABLE", False) == "1" 10 | if jetstream_pt_disabled: 11 | return False 12 | # Import torch_xla2 first! 13 | import torch_xla2 # noqa: F401, isort:skip 14 | 15 | import jetstream_pt # noqa: F401 16 | 17 | return True 18 | except ImportError: 19 | return False 20 | -------------------------------------------------------------------------------- /docs/source/supported-architectures.mdx: -------------------------------------------------------------------------------- 1 | # Supported Models 2 | 3 | ## Inference 4 | The following LLMs have been tested and validated for inference on TPU v5e and v6e for text generation: 5 | 6 | - 🦙 LLaMA Family 7 | - LLaMA-2 7B 8 | - LLaMA-3 8B, 70B 9 | - LLaMA-3.1 8B, 70B 10 | - LLaMA-3.2 1B, 3B (text-only models) 11 | - LLaMA-3.3 70B 12 | - 💎 Gemma Family 13 | - Gemma 2B, 7B 14 | - 💨 Mistral Family 15 | - Mistral 7B 16 | - Mixtral 8x7B 17 | 18 | ## Fine-tuning 19 | The following models have been tested and validated for fine-tuning on TPU v5e and v6e: 20 | 21 | - 🦙 LLaMA Family 22 | - LLaMA-2 7B 23 | - LLaMA-3 8B 24 | - LLaMA-3.2 1B 25 | - 💎 Gemma Family 26 | - Gemma 2B 27 | - Gemma 7B -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | include README.md 16 | include LICENSE 17 | 18 | -------------------------------------------------------------------------------- /optimum/tpu/xla_logger.py: -------------------------------------------------------------------------------- 1 | import torch_xla.core.xla_model as xm 2 | from loguru import logger 3 | 4 | 5 | """ 6 | This is just a shallow wrapper around loguru's logger, to only log messages on the master ordinal and avoid repeating 7 | messages on all the other ordinals' threads.
8 | """ 9 | 10 | def warning(message: str): 11 | if xm.get_ordinal() == 0: 12 | logger.opt(depth=1).warning(message) 13 | 14 | def info(message: str): 15 | if xm.get_ordinal() == 0: 16 | logger.opt(depth=1).info(message) 17 | 18 | def debug(message: str): 19 | if xm.get_ordinal() == 0: 20 | logger.opt(depth=1).debug(message) 21 | 22 | def error(message: str): 23 | if xm.get_ordinal() == 0: 24 | logger.opt(depth=1).error(message) 25 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | # See https://stackoverflow.com/a/61193490/217945 for run_slow 5 | def pytest_addoption(parser): 6 | parser.addoption( 7 | "--runslow", action="store_true", default=False, help="run slow tests" 8 | ) 9 | 10 | 11 | def pytest_configure(config): 12 | config.addinivalue_line("markers", "slow: mark test as slow to run") 13 | 14 | 15 | def pytest_collection_modifyitems(config, items): 16 | if config.getoption("--runslow"): 17 | # --runslow given in cli: do not skip slow tests 18 | return 19 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 20 | for item in items: 21 | if "slow" in item.keywords: 22 | item.add_marker(skip_slow) 23 | -------------------------------------------------------------------------------- /.github/workflows/pypi-release.yaml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | on: 3 | release: 4 | types: [published] 5 | workflow_dispatch: 6 | 7 | env: 8 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN_DIST }} 9 | 10 | jobs: 11 | upload_package: 12 | name: Upload package to PyPI 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | 18 | - name: Install dependencies 19 | run: pip install -U build twine 20 | 21 | - name: Clean 22 | run: | 23 | rm -rf build/ 24 | rm -rf dist/ 25 | 26 | - name: Build the wheels 27 | run: python -m build . 28 | 29 | - name: Upload to PyPI 30 | run: | 31 | pip install twine 32 | twine upload dist/* -u __token__ -p "$PYPI_TOKEN" 33 | -------------------------------------------------------------------------------- /optimum/tpu/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from packaging.version import parse 16 | 17 | 18 | __version__ = "0.2.3.dev0" 19 | VERSION = parse(__version__) 20 | -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/jetstream_pt_support/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .compatibility import create_engine, model_can_use_jetstream_pt 16 | -------------------------------------------------------------------------------- /optimum/tpu/generation/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .logits_process import FusedLogitsWarper # noqa: F401 17 | from .token_selector import TokenSelector # noqa: F401 18 | -------------------------------------------------------------------------------- /text-generation-inference/integration-tests/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | text-generation >= 0.6.0 15 | pytest >= 7.4.0 16 | pytest-asyncio >= 0.21.1 17 | docker >= 6.1.3 18 | Levenshtein 19 | loguru 20 | -------------------------------------------------------------------------------- /optimum/tpu/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .jetstream_pt_support import jetstream_pt_available # isort:skip 16 | from .fsdp_v2 import get_fsdp_config, use_fsdp_v2 17 | from .version import VERSION, __version__ 18 | -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/interceptor.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import grpc 4 | from google.rpc import code_pb2, status_pb2 5 | from grpc_interceptor.server import AsyncServerInterceptor 6 | from grpc_status import rpc_status 7 | from loguru import logger 8 | 9 | 10 | class ExceptionInterceptor(AsyncServerInterceptor): 11 | async def intercept( 12 | self, 13 | method: Callable, 14 | request_or_iterator: Any, 15 | context: grpc.ServicerContext, 16 | method_name: str, 17 | ) -> Any: 18 | try: 19 | response = method(request_or_iterator, context) 20 | return await response 21 | except Exception as err: 22 | method_name = method_name.split("/")[-1] 23 | logger.exception(f"Method {method_name} encountered an error.") 24 | 25 | await context.abort_with_status( 26 | rpc_status.to_status(status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))) 27 | ) 28 | -------------------------------------------------------------------------------- /.github/workflows/test-pytorch-xla-tpu.yml: -------------------------------------------------------------------------------- 1 | name: Optimum TPU tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | paths: 7 | - "tests/**" 8 | pull_request: 9 | branches: [ main ] 10 | paths: 11 | - "tests/**" 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | do-the-job: 19 | name: Run optimum tpu tests 20 | runs-on: 21 | group: gcp-ct5lp-hightpu-8t 22 | container: 23 | image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm 24 | options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache 25 | env: 26 | PJRT_DEVICE: TPU 27 | HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface 28 | JETSTREAM_PT_DISABLE: 1 # Disable PyTorch to avoid conflicts with PyTorch XLA 29 | steps: 30 | - name: Checkout 31 | uses: actions/checkout@v4 32 | 33 | - name: Build and test optimum tpu 34 | run: | 35 | HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tests 36 | -------------------------------------------------------------------------------- /text-generation-inference/server/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "text-generation-server" 7 | dynamic = ["version"] 8 | authors = [{name="Alvaro Moran", email="alvaro.moran@huggingface.co" }] 9 | description = "TGI compatible inference server for Google TPU platforms" 10 | dependencies = [ 11 | 'protobuf', 12 | 'grpcio == 1.62.1', 13 | 'grpcio-status == 1.62.1', 14 | 'grpcio-reflection == 1.62.1', 15 | 'grpc-interceptor == 0.15.2', 16 | 'typer == 0.6.1', 17 | 'safetensors == 0.4.5', 18 | 'transformers == 4.46.3', 19 | 'loguru == 0.6.0', 20 | "sentencepiece == 0.2.0", 21 | "numpy<2.0", 22 | ] 23 | 24 | [tool.setuptools] 25 | packages = [ 26 | "text_generation_server", 27 | 
"text_generation_server.pb", 28 | "text_generation_server.jetstream_pt_support", 29 | "text_generation_server.jetstream_pt_support.models", 30 | ] 31 | 32 | [tool.setuptools.dynamic] 33 | version = {attr = "text_generation_server.version.__version__"} 34 | 35 | [project.scripts] 36 | text-generation-server = 'text_generation_server.cli:app' 37 | -------------------------------------------------------------------------------- /.github/workflows/test-pytorch-xla-tpu-tgi.yml: -------------------------------------------------------------------------------- 1 | name: Optimum TPU / Test TGI on TPU 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | paths: 7 | - "text-generation-inference/**" 8 | pull_request: 9 | branches: [ main ] 10 | paths: 11 | - "text-generation-inference/**" 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | do-the-job: 19 | name: Run TGI tests 20 | runs-on: 21 | group: gcp-ct5lp-hightpu-8t 22 | container: 23 | image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm 24 | options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache 25 | env: 26 | PJRT_DEVICE: TPU 27 | HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface 28 | JETSTREAM_PT_DISABLE: 1 # Disable PyTorch to avoid conflicts with PyTorch XLA 29 | steps: 30 | - name: Checkout 31 | uses: actions/checkout@v4 32 | 33 | - name: Build and test TGI server 34 | run: | 35 | HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test 36 | -------------------------------------------------------------------------------- /docs/source/howto/training.mdx: -------------------------------------------------------------------------------- 1 | # Training on a Google Cloud TPU instance 2 | 3 | Welcome to the 🤗 Optimum-TPU training guide! This section covers how to fine-tune models using Google Cloud TPUs. 4 | 5 | ## Supported Models 6 | 7 | See [Supported Models](../supported-architectures). 8 | 9 | ## Getting Started 10 | 11 | ### Prerequisites 12 | 13 | Before starting the training process, ensure you have: 14 | 15 | 1. A configured Google Cloud TPU instance (see [Deployment Guide](../tutorials/tpu_setup)) 16 | 2. Optimum-TPU installed with PyTorch/XLA support: 17 | ```bash 18 | pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html 19 | ``` 20 | 21 | ## Example Training Scripts 22 | 23 | You can now follow one of our several example scripts to get started: 24 | 1. Gemma Fine-tuning: 25 | - See our [Gemma fine-tuning notebook](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/gemma_tuning.ipynb) for a step-by-step guide 26 | 27 | 2. LLaMA Fine-tuning: 28 | - Check our [LLaMA fine-tuning notebook](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/llama_tuning.ipynb) for detailed instructions 29 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 12 | 13 | 14 | 15 | Fixes # (issue) 16 | 17 | 18 | ## Before submitting 19 | - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). 20 | - [ ] Did you make sure to update the documentation with your changes? 21 | - [ ] Did you write any new necessary tests? 
22 | -------------------------------------------------------------------------------- /.github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml: -------------------------------------------------------------------------------- 1 | name: Optimum TPU / Test TGI on TPU / Jetstream Pytorch 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | paths: 7 | - "text-generation-inference/**" 8 | pull_request: 9 | branches: [ main ] 10 | paths: 11 | - "text-generation-inference/**" 12 | # This can be used to trigger workflow from the web interface 13 | workflow_dispatch: 14 | 15 | concurrency: 16 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 17 | cancel-in-progress: true 18 | 19 | jobs: 20 | do-the-job: 21 | name: Run TGI tests - Jetstream Pytorch 22 | runs-on: 23 | group: gcp-ct5lp-hightpu-8t 24 | container: 25 | image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm 26 | options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache 27 | env: 28 | PJRT_DEVICE: TPU 29 | HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface 30 | steps: 31 | - name: Checkout 32 | uses: actions/checkout@v4 33 | 34 | - name: Build and test TGI server 35 | run: | 36 | HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tgi_test_jetstream 37 | -------------------------------------------------------------------------------- /text-generation-inference/Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "backends/v2", 4 | "backends/grpc-metadata", 5 | "launcher", 6 | "router" 7 | ] 8 | default-members = [ 9 | "backends/v2", 10 | "backends/grpc-metadata", 11 | "launcher", 12 | "router" 13 | ] 14 | resolver = "2" 15 | 16 | [workspace.package] 17 | version = "3.0.0" 18 | edition = "2021" 19 | authors = ["Olivier Dehaene"] 20 | homepage = "https://github.com/huggingface/text-generation-inference" 21 | 22 | [workspace.dependencies] 23 | base64 = "0.22.0" 24 | tokenizers = { version = "0.20.0", features = ["http"] } 25 | hf-hub = { version = "0.3.1", features = ["tokio"] } 26 | metrics = { version = "0.23.0" } 27 | metrics-exporter-prometheus = { version = "0.15.1", features = [] } 28 | minijinja = { version = "2.2.0", features = ["json"] } 29 | minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } 30 | pyo3 = { version = "0.22.2", features = ["auto-initialize"] } 31 | 32 | [profile.release] 33 | incremental = true 34 | 35 | [profile.release-binary] 36 | inherits = "release" 37 | debug = 1 38 | incremental = true 39 | panic = "abort" 40 | 41 | [profile.release-opt] 42 | inherits = "release" 43 | debug = 0 44 | incremental = false 45 | lto = "fat" 46 | opt-level = 3 47 | codegen-units = 1 -------------------------------------------------------------------------------- /.github/workflows/test-pytorch-xla-tpu-tgi-integration.yml: -------------------------------------------------------------------------------- 1 | name: Optimum TPU / Test TGI on TPU / Integration Tests 2 | 3 | on: 4 | # schedule: 5 | # - cron: '0 4 * * *' # run at 4 AM UTC 6 | # # This can be used to allow manually triggering nightlies from the web interface 7 | workflow_dispatch: 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | integration-tests: 15 | name: Run TGI Integration Tests 16 | runs-on: 17 | group: gcp-ct5lp-hightpu-8t 18 | 19 | env: 20 | PJRT_DEVICE: TPU 21 | HF_HUB_CACHE: 
/mnt/hf_cache/cache_huggingface 22 | HF_TOKEN: ${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} 23 | TPU_ENV: ${{ vars.V5_LITEPOD_8_ENV}} 24 | 25 | steps: 26 | - name: Checkout code 27 | uses: actions/checkout@v4 28 | 29 | - name: Install Python 30 | run: | 31 | sudo apt-get update -y 32 | sudo apt-get install -y python3 python3-pip 33 | sudo ln -s /usr/bin/python3 /usr/bin/python 34 | 35 | # To build the docker image in the ci, we need to use the network host option 36 | - name: Build TGI Docker Image 37 | run: | 38 | make tpu-tgi NETWORK=host 39 | 40 | - name: Run integration tests 41 | run: | 42 | make tgi_docker_test 43 | -------------------------------------------------------------------------------- /docs/source/conceptual_guides/difference_between_jetstream_and_xla.mdx: -------------------------------------------------------------------------------- 1 | # Differences between Jetstream Pytorch and PyTorch XLA 2 | 3 | This guide explains to optimum-tpu users the difference between Jetstream Pytorch and PyTorch XLA as those are two available backends in TGI. 4 | 5 | JetStream PyTorch is a high-performance inference engine built on top of PyTorch XLA. It is optimized for throughput and memory efficiency when running Large Language Models (LLMs) on TPUs. 6 | 7 | | Feature | Jetstream Pytorch | PyTorch XLA | 8 | |---------|-----------|-------------| 9 | | Training | ❌ | ✅ | 10 | | Serving | ✅ | ✅ | 11 | | Performance | Higher serving performance | Standard performance | 12 | | Flexibility | Limited to serving | Full PyTorch ecosystem | 13 | | Use Case | Production inference | Development and training | 14 | | Integration | Optimized for deployment | Standard PyTorch workflow | 15 | 16 | **Notes:** 17 | By default, optimum-tpu is using PyTorch XLA for training and Jetstream Pytorch for serving. 18 | 19 | You can configure optimum-tpu to use either version for serving with TGI. You can use the Pytorch XLA backend in TGI by setting up `-e JETSTREAM_PT_DISABLE=1` in your docker run arguments. 20 | 21 | You can find more information about: 22 | - PyTorch XLA: https://pytorch.org/xla/ and https://github.com/pytorch/xla 23 | - Jetstream Pytorch: https://github.com/AI-Hypercomputer/jetstream-pytorch -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py: -------------------------------------------------------------------------------- 1 | 2 | from jetstream_pt.third_party.gemma import config as gemma_config 3 | from jetstream_pt.third_party.gemma.model import GemmaModel 4 | from transformers import GemmaConfig, GenerationConfig, GenerationMixin 5 | 6 | 7 | class GemmaConfigHf(GemmaConfig, gemma_config.GemmaConfig): 8 | """This class is used to support both the HF GemmaConfig and the Jetstream Pytorch GemmaConfig at the same time. 9 | """ 10 | 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.tokenizer = None 14 | 15 | 16 | class GemmaModelHf(GemmaModel, GenerationMixin): 17 | """Transformer module that uses HF GemmaConfig instead of Jetstream Pytorch GemmaConfig + device. 18 | 19 | Note that this class also derives from GenerationMixin, so that we can use its methods. 
20 | """ 21 | 22 | def __init__( 23 | self, 24 | config: GemmaConfig, 25 | device, 26 | env, 27 | ): 28 | self.generation_config = GenerationConfig.from_model_config(config) 29 | args = GemmaConfigHf(**config.to_dict()) 30 | args.device = device 31 | super().__init__(args, env) 32 | 33 | 34 | @classmethod 35 | def from_config(cls, config, env): 36 | device = "meta" 37 | model = cls(config, device, env) 38 | return model 39 | -------------------------------------------------------------------------------- /.github/workflows/check_code_quality.yml: -------------------------------------------------------------------------------- 1 | name: check_code_quality 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | paths: 7 | - "setup.py" 8 | - "optimum/tpu/**.py" 9 | - "tests/**.py" 10 | - "examples/**.py" 11 | 12 | pull_request: 13 | branches: [ main ] 14 | paths: 15 | - "setup.py" 16 | - "optimum/tpu/**.py" 17 | - "tests/**.py" 18 | - "examples/**.py" 19 | 20 | concurrency: 21 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 22 | cancel-in-progress: true 23 | 24 | jobs: 25 | build: 26 | strategy: 27 | fail-fast: false 28 | matrix: 29 | python-version: [3.10.12] 30 | os: [ubuntu-22.04] 31 | 32 | runs-on: ${{ matrix.os }} 33 | steps: 34 | - uses: actions/checkout@v4 35 | - name: Setup Python ${{ matrix.python-version }} 36 | uses: actions/setup-python@v2 37 | with: 38 | python-version: ${{ matrix.python-version }} 39 | - name: Create and start a virtual environment 40 | run: | 41 | python -m venv venv 42 | source venv/bin/activate 43 | - name: Install dependencies 44 | run: | 45 | source venv/bin/activate 46 | pip install --upgrade pip 47 | pip install .[quality] -f https://storage.googleapis.com/libtpu-releases/index.html 48 | - name: Check style with ruff 49 | run: | 50 | source venv/bin/activate 51 | ruff check . 
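The ruff invocation above is easy to reproduce locally before opening a PR. A minimal sketch, assuming the `quality` extra installs ruff as the workflow implies:

```bash
# Local approximation of the check_code_quality workflow above
python -m venv venv && source venv/bin/activate
pip install --upgrade pip
pip install .[quality] -f https://storage.googleapis.com/libtpu-releases/index.html
ruff check .
```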
52 | -------------------------------------------------------------------------------- /text-generation-inference/tests/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from text_generation_server.pb.generate_pb2 import ( 4 | NextTokenChooserParameters, 5 | Request, 6 | StoppingCriteriaParameters, 7 | ) 8 | 9 | from optimum.tpu.model import fetch_model 10 | 11 | 12 | def prepare_model(model_id, sequence_length): 13 | # Add variables to environment so they can be used in AutoModelForCausalLM 14 | os.environ["HF_SEQUENCE_LENGTH"] = str(sequence_length) 15 | path = fetch_model(model_id) 16 | return path 17 | 18 | 19 | def create_request( 20 | id: int, 21 | inputs: str, 22 | max_new_tokens=20, 23 | do_sample: bool = False, 24 | top_k: int = 50, 25 | top_p: float = 0.9, 26 | temperature: float = 1.0, 27 | seed: int = 0, 28 | repetition_penalty: float = 1.0, 29 | ): 30 | # For these tests we can safely set typical_p to 1.0 (default) 31 | typical_p = 1.0 32 | if not do_sample: 33 | # Drop top_p parameter to avoid warnings 34 | top_p = 1.0 35 | parameters = NextTokenChooserParameters( 36 | temperature=temperature, 37 | top_k=top_k, 38 | top_p=top_p, 39 | do_sample=do_sample, 40 | seed=seed, 41 | repetition_penalty=repetition_penalty, 42 | typical_p=typical_p, 43 | ) 44 | stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens) 45 | return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters) 46 | -------------------------------------------------------------------------------- /docs/source/conceptual_guides/tpu_hardware_support.mdx: -------------------------------------------------------------------------------- 1 | # TPU hardware support 2 | Optimum-TPU supports and is optimized for v5e and v6e TPUs. 3 | 4 | ## TPU naming convention 5 | The TPU naming follows this format: `<tpu_version>-<number_of_chips>` 6 | 7 | TPU version: 8 | - v5litepod (v5e) 9 | - v6e 10 | 11 | For example, a v5litepod-8 is a v5e TPU with 8 TPU chips. 12 | 13 | ## Memory on TPU 14 | The HBM (High Bandwidth Memory) capacity per chip is 16GB for v5e, 95GB for v5p and 32GB for v6e. So a v5e-8 (v5litepod-8) has 16GB*8=128GB of HBM memory. 15 | 16 | ## Recommended Runtime for TPU 17 | 18 | When creating the TPU VM, use one of the following TPU VM base images for optimum-tpu: 19 | - v2-alpha-tpuv6e (TPU v6e) (recommended) 20 | - v2-alpha-tpuv5 (TPU v5p) (recommended) 21 | - v2-alpha-tpuv5-lite (TPU v5e) (recommended) 22 | - tpu-ubuntu2204-base (default) 23 | 24 | For installation instructions, refer to our [TPU setup tutorial](../tutorials/tpu_setup). We recommend you use the *alpha* versions with optimum-tpu, as optimum-tpu is tested and optimized for those.
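For illustration, a TPU VM creation command that picks one of the recommended alpha runtimes could look like the sketch below; the instance name, zone and accelerator type are placeholders to adapt to your project and quota:

```bash
# Hypothetical example: create a v5e TPU VM with the recommended alpha runtime
gcloud compute tpus tpu-vm create my-optimum-tpu-vm \
  --zone=us-west4-a \
  --accelerator-type=v5litepod-8 \
  --version=v2-alpha-tpuv5-lite
```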
25 | 26 | More information is available at https://cloud.google.com/tpu/docs/runtimes#pytorch_and_jax 27 | 28 | ## Next steps 29 | For more information on the different TPU hardware, you can look at: 30 | https://cloud.google.com/tpu/docs/v6e 31 | https://cloud.google.com/tpu/docs/v5p 32 | https://cloud.google.com/tpu/docs/v5e 33 | 34 | Pricing information can be found at https://cloud.google.com/tpu/pricing 35 | 36 | TPU availability can be found at https://cloud.google.com/tpu/docs/regions-zones -------------------------------------------------------------------------------- /docs/source/optimum_container.mdx: -------------------------------------------------------------------------------- 1 | # Optimum TPU Containers 2 | 3 | ## Text Generation Inference (TGI) Containers 4 | | Container | Description | Optimum TPU | Image URL | 5 | |-----------|-------------|-------------|-----------| 6 | | TGI Base | TPU-optimized TGI without GCP dependencies | 0.2.3 | `ghcr.io/huggingface/optimum-tpu:v0.2.3-tgi` | 7 | | TGI GCP | TPU-optimized TGI with GCP dependencies | 0.2.3 | `us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-text-generation-inference-tpu.0.2.3.py310` | 8 | 9 | ## Training Containers 10 | | Container | Description | PyTorch | Transformers | Image URL | 11 | |-----------|-------------|----------|--------------|-----------| 12 | | Training GCP | PyTorch training with GCP dependencies | 2.5.1 | 4.46.3 | `us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-training-tpu.2.5.1.transformers.4.46.3.py310` | 13 | 14 | Each container is optimized for specific use cases: 15 | - TGI Base is a barebone TGI server optimized for TPU 16 | - TGI GCP contains some extra GCP dependencies and is hosted on GCP. This is the recommended way to deploy TGI on GCP 17 | - Training GCP is a container for training models on TPU VMs 18 | 19 | Each version on GCP is pinned to specific versions of optimum-tpu, PyTorch, and/or transformers.
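For illustration, launching the TGI Base image from the table above on a TPU VM typically looks something like the sketch below; the model ID and token are placeholders, and the privileged/host-network flags mirror the container guide elsewhere in this repository:

```bash
# Hypothetical launch of the TGI Base container (the entrypoint serves on port 8080)
docker run --rm \
  --shm-size 16GB --privileged --net=host \
  -e MODEL_ID=google/gemma-2b \
  -e HF_TOKEN=<your_hf_token> \
  -e MAX_BATCH_SIZE=4 \
  ghcr.io/huggingface/optimum-tpu:v0.2.3-tgi
```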
To check the latest available images: 21 | 22 | - [latest TGI GCP images](https://github.com/huggingface/Google-Cloud-Containers/tree/main/containers/tgi/tpu) 23 | - [latest Training GCP images](https://github.com/huggingface/Google-Cloud-Containers/tree/main/containers/pytorch/training/tpu) -------------------------------------------------------------------------------- /text-generation-inference/tests/test_prefill_truncate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from helpers import create_request, prepare_model 3 | from text_generation_server.auto_generator import AutoGenerator 4 | from text_generation_server.pb.generate_pb2 import Batch 5 | 6 | 7 | @pytest.mark.jetstream 8 | @pytest.mark.torch_xla 9 | def test_prefill_truncate(): 10 | model_id = "Maykeye/TinyLLama-v0" 11 | sequence_length = 1024 12 | 13 | model_path = prepare_model(model_id, sequence_length) 14 | max_new_tokens = 20 15 | 16 | generator = AutoGenerator.from_pretrained( 17 | model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length 18 | ) 19 | input_text = "And to finish the story, I will say that" 20 | 21 | request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False) 22 | batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length) 23 | generations, _ = generator.prefill(batch) 24 | assert len(generations) == 1 25 | assert generations[0].tokens.ids == [357] 26 | assert generations[0].tokens.texts == [" it"] 27 | # Now re-test but with truncate 28 | generator.clear() 29 | 30 | request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False) 31 | # This will only leave last tokens 32 | request.truncate = 3 33 | batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length) 34 | generations, _ = generator.prefill(batch) 35 | assert len(generations) == 1 36 | assert generations[0].tokens.ids == [266] 37 | assert generations[0].tokens.texts == [" the"] 38 | -------------------------------------------------------------------------------- /text-generation-inference/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from optimum.tpu import jetstream_pt_available 6 | 7 | 8 | # See https://stackoverflow.com/a/61193490/217945 for run_slow 9 | def pytest_addoption(parser): 10 | parser.addoption( 11 | "--runslow", action="store_true", default=False, help="run slow tests" 12 | ) 13 | 14 | 15 | def pytest_configure(config): 16 | config.addinivalue_line("markers", "slow: mark test as slow to run") 17 | 18 | 19 | def pytest_collection_modifyitems(config, items): 20 | if config.getoption("--runslow"): 21 | # --runslow given in cli: do not skip slow tests 22 | return 23 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 24 | for item in items: 25 | if "slow" in item.keywords: 26 | item.add_marker(skip_slow) 27 | 28 | 29 | @pytest.fixture(scope="function") 30 | def quantization_jetstream_int8(): 31 | # Setup 32 | old_environ = dict(os.environ) 33 | os.environ["QUANTIZATION"] = "jetstream_int8" 34 | yield 35 | # Clean up 36 | os.environ.clear() 37 | os.environ.update(old_environ) 38 | 39 | 40 | def pytest_runtest_setup(item): 41 | marker_names = [marker.name for marker in item.iter_markers()] 42 | jetstream_pt_enabled = jetstream_pt_available() 43 | # Skip tests that require torch xla but not jetstream 44 | if "torch_xla" in marker_names and 
"jetstream" not in marker_names: 45 | if jetstream_pt_enabled: 46 | pytest.skip("Jetstream is enabled: xla test will be skipped") 47 | elif "jetstream" in marker_names and not jetstream_pt_enabled: 48 | pytest.skip("Test requires Jetstream PyTorch to be enabled") 49 | -------------------------------------------------------------------------------- /docs/source/installation.mdx: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | This assumes you already have a TPU instance running. If not, please look at [TPU setup tutorial](./tutorials/tpu_setup) 4 | 5 | If it is your first time using TPU, look at our tutorial that explains [how to setup a TPU for the first time](./tutorials/tpu_setup) 6 | 7 | This walkthrough will explain how to install the [optimum-tpu package](https://pypi.org/project/optimum-tpu/) to leverage HuggingFace's solution to run AI workloads as fast as possible on Google TPUs 🚀 8 | 9 | ## Optimum-TPU 10 | 11 | Installing the optimum-tpu python package is mainly useful for training. If you wish to do serving the recommended way to inferface with that is through [our TGI containers](./optimum_container). You can also look at our [tutorial on serving](./tutorials/inference_on_tpu) for more information. 12 | 13 | To install Optimum-TPU, it should be as simple as 14 | 15 | ```bash 16 | $ python3 -m pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html 17 | $ export PJRT_DEVICE=TPU 18 | ``` 19 | 20 | You can now leverage PyTorch/XLA through Optimum-TPU. You can validate the installation with the following command which should print `xla:0` as we do have a single 21 | TPU device bound to this instance. 22 | 23 | ```bash 24 | $ python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())" 25 | xla:0 26 | ``` 27 | 28 | You can also look at the rest at our [fine-tuning examples](./howto/more_examples) for more information on how to use the optimum-tpu package 29 | 30 | Remarks: you can also use [optimum-tpu training container](./tutorials/training_on_tpu) for a pre-setup container with optimum-tpu installed and all HuggingFace libraries pre-configured -------------------------------------------------------------------------------- /.github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml: -------------------------------------------------------------------------------- 1 | name: Optimum TPU / Test TGI on TPU (slow tests) 2 | 3 | on: 4 | # This can be used to automatically publish nightlies at UTC nighttime 5 | # schedule: 6 | # - cron: '0 2 * * *' # run at 2 AM UTC 7 | # This can be used to allow manually triggering nightlies from the web interface 8 | workflow_dispatch: 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | do-the-job: 16 | name: Build and Run slow tests 17 | runs-on: 18 | group: gcp-ct5lp-hightpu-8t 19 | container: 20 | image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm 21 | options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache 22 | env: 23 | PJRT_DEVICE: TPU 24 | HF_TOKEN: ${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} 25 | HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface 26 | JETSTREAM_PT_DISABLE: 1 # Disable PyTorch to avoid conflicts with PyTorch XLA 27 | steps: 28 | - name: Checkout 29 | uses: actions/checkout@v4 30 | 31 | - name: Build and test Optimum TPU (also slow tests) 32 | run: | 33 | python 
-m pip install build 34 | make build_dist test_installs 35 | python -m pytest --runslow -sv tests 36 | 37 | - name: Build and test TGI (also slow tests) 38 | run: | 39 | make tgi_server test_installs 40 | find text-generation-inference/ -name "text_generation_server-*whl" -exec python -m pip install {} \; 41 | python -m pytest --runslow -sv text-generation-inference/tests -m torch_xla 42 | -------------------------------------------------------------------------------- /text-generation-inference/docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This is required by GKE, see 4 | # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus#privileged-mode 5 | ulimit -l 68719476736 6 | 7 | # Hugging Face Hub related 8 | if [[ -z "${MAX_BATCH_SIZE}" ]]; then 9 | MAX_BATCH_SIZE=4 10 | fi 11 | export MAX_BATCH_SIZE="${MAX_BATCH_SIZE}" 12 | 13 | # At some point we used to have MAX_INPUT_LENGTH, now we should use MAX_INPUT_TOKENS 14 | # (This would be done automatically by the launcher, but we need to calculate the 15 | # MAX_BATCH_PREFILL_TOKENS if not set) 16 | if [[ -z "${MAX_INPUT_TOKENS}" && -n ${MAX_INPUT_LENGTH} ]]; then 17 | MAX_INPUT_TOKENS=${MAX_INPUT_LENGTH} 18 | fi 19 | if [[ -n "${MAX_INPUT_LENGTH}" ]]; then 20 | echo "MAX_INPUT_LENGTH is deprecated, please use MAX_INPUT_TOKENS instead. Variable will be unset." 21 | unset MAX_INPUT_LENGTH 22 | fi 23 | 24 | if [[ -z "${MAX_BATCH_PREFILL_TOKENS}" ]]; then 25 | MAX_BATCH_PREFILL_TOKENS=$(( ${MAX_BATCH_SIZE} * ${MAX_INPUT_TOKENS} )) 26 | fi 27 | export MAX_BATCH_PREFILL_TOKENS="${MAX_BATCH_PREFILL_TOKENS}" 28 | 29 | if [[ -z "${JSON_OUTPUT_DISABLE}" ]]; then 30 | JSON_OUTPUT_DISABLE=--json-output 31 | else 32 | JSON_OUTPUT_DISABLE="" 33 | fi 34 | export JSON_OUTPUT_DISABLE="${JSON_OUTPUT_DISABLE}" 35 | 36 | if [[ -z "${MODEL_ID}" ]]; then 37 | echo "MODEL_ID must be set" 38 | exit 1 39 | fi 40 | export MODEL_ID="${MODEL_ID}" 41 | 42 | if [[ -z "${QUANTIZATION}" ]]; then 43 | QUANTIZATION="" 44 | else 45 | QUANTIZATION="jetstream_int8" 46 | fi 47 | export QUANTIZATION="${QUANTIZATION}" 48 | 49 | 50 | 51 | exec text-generation-launcher --port 8080 \ 52 | --max-batch-size ${MAX_BATCH_SIZE} \ 53 | ${JSON_OUTPUT_DISABLE} \ 54 | --model-id ${MODEL_ID} 55 | -------------------------------------------------------------------------------- /docs/scripts/auto-generate-examples.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import yaml 4 | 5 | # Check that both files exist 6 | examples_file = 'docs/scripts/examples_list.yml' 7 | toctree_file = 'docs/source/_toctree.yml' 8 | 9 | if not os.path.exists(examples_file): 10 | print(f"Error: {examples_file} does not exist") 11 | sys.exit(1) 12 | 13 | if not os.path.exists(toctree_file): 14 | print(f"Error: {toctree_file} does not exist") 15 | sys.exit(1) 16 | 17 | # Read the examples list 18 | with open(examples_file, 'r') as f: 19 | examples = yaml.safe_load(f) 20 | 21 | # Read the main toctree 22 | with open(toctree_file, 'r') as f: 23 | toc = yaml.safe_load(f) 24 | 25 | # Find the howto section and insert before more_examples 26 | # Iterate through the list to find the sections with howto 27 | for item in toc: 28 | if isinstance(item, dict) and 'sections' in item: 29 | for section in item['sections']: 30 | if isinstance(section, dict) and 'sections' in section: 31 | howto_items = section['sections'] 32 | for i, subitem in enumerate(howto_items): 33 | if 
subitem.get('local') == 'howto/more_examples': 34 | # Insert the new examples before this position 35 | for example in reversed(examples): 36 | howto_items.insert(i, example) 37 | break 38 | 39 | # Write back the modified toctree 40 | with open(toctree_file, 'w') as f: 41 | yaml.dump(toc, f, sort_keys=False, allow_unicode=True, default_flow_style=False) 42 | 43 | print("Added examples to the howto section of the toctree") 44 | 45 | # Print the updated toctree contents 46 | with open(toctree_file, 'r') as f: 47 | print("\nUpdated _toctree.yml contents:") 48 | print(f.read()) 49 | -------------------------------------------------------------------------------- /text-generation-inference/tests/test_decode.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | from decode_tests_utils import DecodeTestParams, decode_single_test 4 | 5 | 6 | # All tests in this file are for torch xla 7 | pytestmark = pytest.mark.torch_xla 8 | 9 | @pytest.mark.parametrize("params", 10 | [ 11 | DecodeTestParams( 12 | model_id="google/gemma-2b", 13 | sequence_length=1024, 14 | expected_text="\n\nThe first thing I noticed was the smell of the rain. It was a smell I had never", 15 | ), 16 | DecodeTestParams( 17 | model_id="Maykeye/TinyLLama-v0", 18 | sequence_length=1024, 19 | expected_text=" It was a very special day, and it was a very special day.\nThe mommy said", 20 | ), 21 | ], 22 | ids=["gemma-2b", "TinyLLama-v0"], 23 | ) 24 | def test_decode_single(params): 25 | decode_single_test(params) 26 | 27 | 28 | @pytest.mark.slow 29 | @pytest.mark.parametrize("params", 30 | [ 31 | DecodeTestParams( 32 | model_id="meta-llama/Meta-Llama-3-8B", 33 | sequence_length=256, 34 | expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,", 35 | ), 36 | DecodeTestParams( 37 | model_id="google/gemma-7b", 38 | sequence_length=128, 39 | expected_text="\n\nThe year was 1984.\n\nThe place was Oceania.\n\nThe time was", 40 | ), 41 | DecodeTestParams( 42 | model_id="mistralai/Mistral-7B-v0.3", 43 | sequence_length=128, 44 | expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the v", 45 | ), 46 | ], 47 | ids=["Meta-Llama-3-8B", "gemma-7b", "Mistral-7B-v0.3"], 48 | ) 49 | def test_decode_single_slow(params): 50 | decode_single_test(params) 51 | -------------------------------------------------------------------------------- /text-generation-inference/tests/test_warmup.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import pytest 4 | from helpers import create_request, prepare_model 5 | from text_generation_server.auto_generator import AutoGenerator 6 | from text_generation_server.pb.generate_pb2 import Batch 7 | 8 | 9 | @pytest.mark.jetstream 10 | def test_warmup_jetstream_pytorch(): 11 | model_id = "Maykeye/TinyLLama-v0" 12 | sequence_length = 256 13 | 14 | model_path = prepare_model(model_id, sequence_length) 15 | input_text = "It was a bright cold day in April, and the clocks were striking thirteen." 16 | max_new_tokens = 20 17 | 18 | generator = AutoGenerator.from_pretrained( 19 | model_path, revision="", max_batch_size=2, max_sequence_length=sequence_length 20 | ) 21 | request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False) 22 | # The maximum tokens length of the model is intentionally not a power of two, to verify that prefill bucketization 23 | # works as expected (250 -> 256). 
24 | max_tokens = 250 25 | batch = Batch(id=0, requests=[request], size=1, max_tokens=max_tokens) 26 | generator.warmup(batch) 27 | 28 | # Prepare a new request with different settings. Warmup should have triggered compilation so this can be run 29 | # quickly. 30 | input_text = "What is Deep Learning?" 31 | max_new_tokens = 3 32 | max_tokens = 13 33 | request1 = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False) 34 | batch = Batch(id=1, requests=[request1], size=1, max_tokens=max_tokens) 35 | 36 | start = time() 37 | _generations, new_batch = generator.prefill(batch) 38 | _generations, new_batch = generator.decode([new_batch]) 39 | end = time() 40 | 41 | # Prefill and decode time should be less than 1 second (rather fast) 42 | assert end - start < 1.0 43 | -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py: -------------------------------------------------------------------------------- 1 | from jetstream_pt.third_party.mixtral import config as mixtral_config 2 | from jetstream_pt.third_party.mixtral.model import Transformer 3 | from transformers import GenerationConfig, GenerationMixin, MixtralConfig 4 | 5 | 6 | class MixtralConfigHf(MixtralConfig, mixtral_config.ModelArgs): 7 | """This class is used to support both the HF MixtralConfig and the Jetstream Pytorch ModelArgs at the same time.""" 8 | 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | self.__post_init__() 12 | 13 | @property 14 | def block_size(self): 15 | return self.max_position_embeddings 16 | 17 | @property 18 | def n_layer(self): 19 | return self.num_hidden_layers 20 | 21 | @property 22 | def n_head(self): 23 | return self.num_attention_heads 24 | 25 | @property 26 | def dim(self): 27 | return self.hidden_size 28 | 29 | @property 30 | def n_local_heads(self): 31 | return self.num_local_experts or self.num_attention_heads 32 | 33 | @property 34 | def num_activated_experts(self): 35 | return self.num_experts_per_tok 36 | 37 | 38 | class MixtralModelHf(Transformer, GenerationMixin): 39 | """Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device.""" 40 | 41 | def __init__( 42 | self, 43 | config: MixtralConfig, 44 | device, 45 | env, 46 | ): 47 | self.generation_config = GenerationConfig.from_model_config(config) 48 | args = MixtralConfigHf(**config.to_dict()) 49 | args.device = device 50 | super().__init__(args, env) 51 | 52 | 53 | @classmethod 54 | def from_config(cls, config, env): 55 | device = "meta" 56 | model = cls(config, device, env) 57 | return model 58 | -------------------------------------------------------------------------------- /docs/source/howto/installation_inside_a_container.mdx: -------------------------------------------------------------------------------- 1 | # Installing Optimum-TPU inside a Docker Container 2 | 3 | This guide explains how to run Optimum-TPU within a Docker container using the official PyTorch/XLA image. 4 | 5 | ## Prerequisites 6 | 7 | Before starting, ensure you have: 8 | - Docker installed on your system 9 | - Access to a TPU instance 10 | - Sufficient permissions to run privileged containers 11 | 12 | ## Using the PyTorch/XLA Base Image 13 | 14 | ### 1. 
Pull the Docker Image 15 | 16 | First, set the environment variables for the image URL and version: 17 | 18 | ```bash 19 | export TPUVM_IMAGE_URL=us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla 20 | export TPUVM_IMAGE_VERSION=r2.5.0_3.10_tpuvm 21 | 22 | # Pull the image 23 | docker pull ${TPUVM_IMAGE_URL}:${TPUVM_IMAGE_VERSION} 24 | ``` 25 | 26 | ### 2. Run the Container 27 | 28 | Launch the container with the necessary flags for TPU access: 29 | 30 | ```bash 31 | docker run -ti \ 32 | --rm \ 33 | --shm-size 16GB \ 34 | --privileged \ 35 | --net=host \ 36 | ${TPUVM_IMAGE_URL}:${TPUVM_IMAGE_VERSION} \ 37 | bash 38 | ``` 39 | The `--shm-size 16GB --privileged --net=host` flags are required for Docker to access the TPU 40 | 41 | ### 3. Install Optimum-TPU 42 | 43 | Once inside the container, install Optimum-TPU: 44 | 45 | ```bash 46 | pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html 47 | ``` 48 | 49 | ## Verification 50 | 51 | To verify your setup, run this simple test: 52 | 53 | ```bash 54 | python3 -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())" 55 | ``` 56 | 57 | You should see output indicating the XLA device is available (e.g., `xla:0`). 58 | 59 | ## Next Steps 60 | After setting up your container, you can: 61 | - Start training models using Optimum-TPU. Refer to our [training example section](../howto/more_examples). 62 | - Run inference workloads. Check out our [serving guide](../howto/serving). 63 | -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any 16 | 17 | from transformers import AutoConfig 18 | 19 | from optimum.tpu import jetstream_pt_available 20 | 21 | 22 | def model_can_use_jetstream_pt(model_path: str) -> bool: 23 | """Checks if the model is supported by Jetstream Pytorch on Optimum TPU and if the required dependencies to provide 24 | the engine are installed.
25 | """ 26 | config = AutoConfig.from_pretrained(model_path) 27 | # For now few models are supported 28 | supported_models = ["llama", "gemma", "mixtral"] 29 | if config.model_type not in supported_models: 30 | return False 31 | if jetstream_pt_available(): 32 | return True 33 | return False 34 | 35 | 36 | def create_engine( 37 | model_path: str, 38 | batch_size: int, 39 | sequence_length: int, 40 | max_input_tokens: int, 41 | max_output_tokens: int, 42 | ) -> Any: 43 | if not model_can_use_jetstream_pt(model_path): 44 | # The model is not compatible with Jetstream PyTorch, just exit 45 | return None 46 | 47 | # Now import engine_loader to prevent importing it at the top when not supported 48 | from .engine_loader import create_engine 49 | return create_engine( 50 | model_path, batch_size, sequence_length, max_input_tokens, max_output_tokens 51 | ) 52 | -------------------------------------------------------------------------------- /docs/source/_toctree.yml: -------------------------------------------------------------------------------- 1 | - sections: 2 | - local: index 3 | title: 🤗 Optimum-TPU 4 | - local: supported-architectures 5 | title: Supported Models 6 | - local: installation 7 | title: Installation 8 | - local: optimum_container 9 | title: Optimum TPU Containers 10 | - sections: 11 | - local: tutorials/tpu_setup 12 | title: First TPU Setup on Google Cloud 13 | - local: tutorials/inference_on_tpu 14 | title: First TPU Inference on Google Cloud 15 | - local: tutorials/training_on_tpu 16 | title: First TPU Training on Google Cloud 17 | title: Tutorials 18 | - sections: 19 | - local: howto/gcloud_cli 20 | title: Deploying and Connecting to Google TPU Instances via GCloud CLI 21 | - local: howto/serving 22 | title: Deploying a TGI server on a Google Cloud TPU instance 23 | - local: howto/training 24 | title: Training on a Google Cloud TPU instance 25 | - local: howto/deploy_instance_on_ie 26 | title: How to Deploy a Model on Inference Endpoint for Serving using TPUs 27 | - local: howto/advanced-tgi-serving 28 | title: Advanced TGI Server Configuration 29 | - local: howto/installation_inside_a_container 30 | title: Installing Optimum-TPU inside a Docker Container 31 | - local: howto/more_examples 32 | title: Find More Examples on the Optimum-TPU GitHub Repository 33 | title: How-To Guides 34 | - sections: 35 | - local: conceptual_guides/tpu_hardware_support 36 | title: TPU Hardware Support 37 | - local: conceptual_guides/difference_between_jetstream_and_xla 38 | title: Difference between Jetstream Pytorch and Pytorch XLA 39 | title: Conceptual Guides 40 | - sections: 41 | - local: reference/fsdp_v2 42 | title: FSDPv2 43 | - local: reference/tgi_advanced_options 44 | title: TGI Configuration Reference Guide 45 | title: Reference 46 | - sections: 47 | - local: contributing 48 | title: Contributing to Optimum TPU 49 | title: Contributing 50 | title: Optimum-TPU 51 | isExpanded: true 52 | -------------------------------------------------------------------------------- /.github/workflows/doc-pr-build.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | branches: [ main ] 6 | paths: 7 | - 'docs/source/**' 8 | - 'docs/assets/**' 9 | - 'optimum/**' 10 | - '.github/workflows/doc-pr-build.yml' 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | build_documentation: 18 | runs-on: ubuntu-22.04 19 | 
env: 20 | COMMIT_SHA: ${{ github.event.pull_request.head.sha }} 21 | PR_NUMBER: ${{ github.event.number }} 22 | EVENT_CONTEXT: ${{ toJSON(github.event) }} 23 | PR_CLONE_URL: ${{ github.event.pull_request.head.repo.clone_url }} 24 | 25 | steps: 26 | - uses: actions/checkout@v3 27 | - uses: actions/setup-node@v4 28 | with: 29 | node-version: '20' 30 | cache-dependency-path: "kit/package-lock.json" 31 | 32 | - name: Setup environment 33 | run: | 34 | pip install -U pip 35 | pip install git+https://github.com/huggingface/doc-builder.git 36 | pip install ".[quality]" -f https://storage.googleapis.com/libtpu-releases/index.html 37 | 38 | - name: Make documentation 39 | shell: bash 40 | run: | 41 | doc-builder notebook-to-mdx examples/ --output_dir docs/source/howto/ --open_notebook_prefix https://colab.research.google.com/github/huggingface/optimum-tpu/blob/main 42 | python docs/scripts/auto-generate-examples.py 43 | doc-builder build optimum.tpu docs/source/ --repo_name optimum-tpu --build_dir tpu-doc-build/ --version pr_${{ env.PR_NUMBER }} --version_tag_suffix "" --html --clean 44 | 45 | - name: Save commit_sha & pr_number 46 | run: | 47 | cd tpu-doc-build/ 48 | mv optimum.tpu optimum-tpu 49 | echo ${{ env.COMMIT_SHA }} > ./commit_sha 50 | echo ${{ env.PR_NUMBER }} > ./pr_number 51 | 52 | - uses: actions/upload-artifact@v4 53 | with: 54 | name: doc-build-artifact 55 | path: tpu-doc-build/ 56 | -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/auto_generator.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | 3 | from .generator_base import Generator 4 | from .jetstream_pt_support import model_can_use_jetstream_pt 5 | 6 | 7 | class AutoGenerator: 8 | 9 | @staticmethod 10 | def from_pretrained( 11 | model_path: str, revision: str, max_batch_size: int, max_sequence_length: int, max_input_tokens: int = None 12 | ) -> Generator: 13 | """Instantiate a Generator for TPU using Jetstream Pytorch or Pytorch/XLA. 14 | 15 | Args: 16 | model_path (`str`): 17 | The path to a local model. This path must also contain a Tokenizer. 18 | revision (`str`): 19 | The revision of the model. 20 | max_batch_size (`int`): 21 | The maximum batch size. 22 | max_sequence_length (`int`): 23 | The maximum sequence length. 24 | max_input_tokens (`int`): 25 | The maximum number of tokens allowed in the input. When set to None, it will be set to 80% of the 26 | `max_sequence_length`. 27 | 28 | Returns: 29 | A TpuGenerator. 
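Example (mirroring usage in the repository's tests): generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1024)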
30 | """ 31 | if max_input_tokens is None: 32 | max_input_tokens = int(0.8 * max_sequence_length) 33 | if model_can_use_jetstream_pt(model_path): 34 | logger.debug("Using Jetstream PyTorch generator.") 35 | from .jetstream_pt_support.generator import TpuGeneratorJetStream 36 | return TpuGeneratorJetStream.from_pretrained( 37 | model_path, 38 | revision=revision, 39 | max_batch_size=max_batch_size, 40 | max_sequence_length=max_sequence_length, 41 | max_input_tokens=max_input_tokens, 42 | ) 43 | else: 44 | logger.debug("Using PyTorch/XLA generator.") 45 | from .generator import TpuGenerator 46 | return TpuGenerator.from_pretrained( 47 | model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length 48 | ) 49 | -------------------------------------------------------------------------------- /tests/test_distributed_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | from transformers import AutoTokenizer 6 | 7 | from optimum.tpu.distributed_model import DistributedModel 8 | 9 | 10 | def sample_greedy(logits): 11 | next_logits = logits[:, -1] 12 | next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int() 13 | return next_token_id 14 | 15 | 16 | def _test_distributed_model_prefill(model_id): 17 | # This test ensures model can be loaded in a parallel way and 18 | # that the "proxy" distributed model can be used to prefill the model. 19 | # Disable tokenizers parallelism to avoid deadlocks 20 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 21 | tokenizer = AutoTokenizer.from_pretrained(model_id) 22 | text = ["Running something in parallel means"] 23 | inputs = tokenizer(text, return_tensors="pt") 24 | input_ids = inputs["input_ids"] 25 | attention_mask = inputs["attention_mask"] 26 | pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0) 27 | tokens = input_ids.clone() 28 | 29 | model = DistributedModel(model_id, sample_greedy) 30 | next_tokens = model.prefill(**inputs, position_ids=pos_ids) 31 | tokens = torch.cat([tokens, next_tokens], dim=-1) 32 | 33 | # Data can be decoded even before leaving 34 | decoded_texts = tokenizer.batch_decode(tokens, skip_special_tokens=True) 35 | print() 36 | print("------------------------------------------") 37 | print("Decoded texts:") 38 | print(decoded_texts[0]) 39 | print("------------------------------------------") 40 | # Even if models are different, for this simple test results are the same. 
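# (With greedy sampling, the next token predicted for this prompt decodes to " that", which yields the expected text asserted below.)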
41 | expected_text = "Running something in parallel means that" 42 | assert expected_text == decoded_texts[0] 43 | 44 | 45 | def test_distributed_model_prefill_gpt2(): 46 | _test_distributed_model_prefill("openai-community/gpt2") 47 | 48 | 49 | @pytest.mark.slow 50 | def test_distributed_model_prefill_gemma7b(): 51 | _test_distributed_model_prefill("google/gemma-7b") 52 | 53 | @pytest.mark.slow 54 | def test_distributed_model_prefill_llama3_8b(): 55 | _test_distributed_model_prefill("meta-llama/Meta-Llama-3-8B") 56 | -------------------------------------------------------------------------------- /text-generation-inference/tests/test_decode_jetstream_quant.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | from decode_tests_utils import DecodeTestParams, decode_single_test 4 | 5 | 6 | # All tests in this file are for jetstream 7 | pytestmark = pytest.mark.jetstream 8 | 9 | @pytest.mark.parametrize("params", 10 | [ 11 | DecodeTestParams( 12 | model_id="google/gemma-2b", 13 | sequence_length=1024, 14 | expected_text="\n\nThe first thing I noticed was the smell of the rain. It was a very heavy rain,", 15 | ), 16 | DecodeTestParams( 17 | model_id="Maykeye/TinyLLama-v0", 18 | sequence_length=256, 19 | expected_text=" It was a very special day, and it was a very special day.\nThe mommy said to her, \"Let", 20 | max_new_tokens=25, 21 | ), 22 | ], 23 | ids=["gemma-2b", "TinyLLama-v0"], 24 | ) 25 | def test_decode_jetstream_quantization(quantization_jetstream_int8, params): 26 | decode_single_test(params) 27 | 28 | 29 | @pytest.mark.slow 30 | @pytest.mark.parametrize("params", 31 | [ 32 | DecodeTestParams( 33 | model_id="mistralai/Mixtral-8x7B-v0.1", 34 | sequence_length=1024, 35 | expected_text="\n\nGeorge Orwell, 1984\n\nThe clocks are striking thirteen", 36 | ), 37 | DecodeTestParams( 38 | model_id="meta-llama/Meta-Llama-3-8B", 39 | sequence_length=256, 40 | expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,", 41 | ), 42 | DecodeTestParams( 43 | model_id="meta-llama/Meta-Llama-3-70B", 44 | sequence_length=512, 45 | expected_text=" Winston Smith,s,s,s,s,s,s,s,s,s,s", 46 | ), 47 | DecodeTestParams( 48 | model_id="meta-llama/Llama-3.3-70B-Instruct", 49 | sequence_length=1024, 50 | expected_text=" Winston Smith, the protagonist of the story, was slowly getting up from bed. 
He stretched his arms", 51 | ), 52 | ], 53 | ids=["Mixtral-8x7B", "Meta-Llama-3-8B" ,"Meta-Llama-3-70B", "Llama-3.3-70B-Instruct"], 54 | ) 55 | def test_decode_jetstream_quantization_slow(quantization_jetstream_int8, params): 56 | decode_single_test(params) 57 | -------------------------------------------------------------------------------- /docs/source/howto/more_examples.mdx: -------------------------------------------------------------------------------- 1 | # Find More Examples on the Optimum-TPU GitHub Repository 2 | 3 | To find the latest examples, visit the [examples folder in the optimum-tpu repo on github](https://github.com/huggingface/optimum-tpu/tree/main/examples) 4 | 5 | ## Text Generation 6 | Learn how to perform efficient inference for text generation tasks: 7 | 8 | - **Basic Generation Script** ([examples/text-generation/generation.py](https://github.com/huggingface/optimum-tpu/blob/main/examples/text-generation/generation.py)) 9 | - Demonstrates text generation using models like Gemma and Mistral 10 | - Features greedy sampling implementation 11 | - Shows how to use static caching for improved performance 12 | - Includes performance measurement and timing analysis 13 | - Supports custom model loading and configuration 14 | 15 | ## Language Model Fine-tuning 16 | Explore how to fine-tune language models on TPU infrastructure: 17 | 18 | 1. **Interactive Gemma Tutorial** ([view in the docs](../howto/gemma_tuning)) 19 | - Complete notebook showing Gemma fine-tuning process 20 | - Covers environment setup and TPU configuration 21 | - Demonstrates FSDPv2 integration for efficient model sharding 22 | - Includes dataset preparation and PEFT/LoRA implementation 23 | - Provides step-by-step training workflow 24 | 25 | The full notebook is available at [examples/language-modeling/gemma_tuning.ipynb](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/gemma_tuning.ipynb) 26 | 27 | 28 | 2. **LLaMA Fine-tuning Guide** ([view in the docs](../howto/llama_tuning)) 29 | - Detailed guide for fine-tuning LLaMA-2 and LLaMA-3 models 30 | - Explains SPMD and FSDP concepts 31 | - Shows how to implement efficient data parallel training 32 | - Includes practical code examples and prerequisites 33 | 34 | The full notebook is available at [examples/language-modeling/llama_tuning.ipynb](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/llama_tuning.ipynb) 35 | 36 | # Additional Resources 37 | 38 | - Visit the [Optimum-TPU GitHub repository](https://github.com/huggingface/optimum-tpu) for more details 39 | - Explore the [Google Cloud TPU documentation](https://cloud.google.com/tpu/docs) for deeper understanding of TPU architecture 40 | 41 | To contribute to these examples, visit our [GitHub repository](https://github.com/huggingface/optimum-tpu). 
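For quick reference, the flow implemented by the basic generation script described above can be sketched in a few lines. This is an illustrative sketch only; the model id, prompt and exact `generate` arguments are assumptions rather than the actual contents of `examples/text-generation/generation.py`, which remains the reference implementation:

```python
# Hypothetical minimal sketch: greedy decoding with a static KV cache and simple timing.
# Assumptions: transformers >= 4.38 and access to a supported causal LM (gemma-2b here).
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # assumed model; any supported architecture should work
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Deep learning is", return_tensors="pt")

start = time.time()
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    do_sample=False,                # greedy decoding
    cache_implementation="static",  # static cache keeps shapes fixed, which XLA compiles well
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(f"Generation took {time.time() - start:.2f} s")
```

On a TPU VM the same flow would additionally move the model and inputs to the XLA device before generating; the full script linked above shows the complete, tested version.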
-------------------------------------------------------------------------------- /text-generation-inference/tests/decode_tests_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from helpers import create_request, prepare_model 4 | from text_generation_server.auto_generator import AutoGenerator 5 | from text_generation_server.pb.generate_pb2 import Batch 6 | from tqdm import tqdm 7 | 8 | 9 | @dataclass 10 | class DecodeTestParams: 11 | model_id: str 12 | sequence_length: int 13 | expected_text: str 14 | do_sample: bool = False 15 | max_new_tokens: int = 20 16 | top_k: int = 50 17 | repetition_penalty: float = 1.0 18 | 19 | def decode_single_test(params): 20 | model_path = prepare_model(params.model_id, params.sequence_length) 21 | input_text = "It was a bright cold day in April, and the clocks were striking thirteen." 22 | max_new_tokens = params.max_new_tokens 23 | 24 | generator = AutoGenerator.from_pretrained( 25 | model_path, revision="", max_batch_size=1, max_sequence_length=params.sequence_length 26 | ) 27 | request = create_request( 28 | id=0, 29 | inputs=input_text, 30 | max_new_tokens=max_new_tokens, 31 | do_sample=params.do_sample, 32 | top_k=params.top_k, 33 | seed=1234, 34 | repetition_penalty=params.repetition_penalty, 35 | ) 36 | batch = Batch(id=0, requests=[request], size=1, max_tokens=params.sequence_length) 37 | generations, next_batch = generator.prefill(batch) 38 | # We already generated one token: call decode max_new_tokens - 1 times 39 | for _ in tqdm(range(max_new_tokens - 1)): 40 | assert next_batch.size == 1 41 | assert next_batch.max_tokens == params.sequence_length 42 | assert len(generations) == 1 43 | assert len(generations[0].tokens.ids) == 1 44 | generations, next_batch = generator.decode([next_batch]) 45 | # Destroy generator: this will properly stop threads and prevent them from getting stuck if one of the following 46 | # assertions fails. 47 | del generator 48 | assert next_batch is None 49 | assert len(generations) == 1 50 | output = generations[0].generated_text 51 | assert output.generated_tokens == max_new_tokens 52 | assert output.finish_reason == 0 53 | print(f"Generated text: {output.text}") 54 | if params.do_sample: 55 | assert output.text != params.expected_text 56 | else: 57 | assert output.text == params.expected_text 58 | -------------------------------------------------------------------------------- /optimum/tpu/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | from huggingface_hub import snapshot_download 7 | from loguru import logger 8 | from transformers import AutoConfig 9 | from transformers.utils import SAFE_WEIGHTS_INDEX_NAME 10 | 11 | 12 | def get_export_kwargs_from_env(): 13 | batch_size = os.environ.get("MAX_BATCH_SIZE", None) 14 | if batch_size is not None: 15 | batch_size = int(batch_size) 16 | sequence_length = os.environ.get("HF_SEQUENCE_LENGTH", None) 17 | if sequence_length is not None: 18 | sequence_length = int(sequence_length) 19 | return { 20 | "task": "text-generation", 21 | "batch_size": batch_size, 22 | "sequence_length": sequence_length, 23 | } 24 | 25 | 26 | def fetch_model( 27 | model_id: str, 28 | revision: Optional[str] = None, 29 | ) -> str: 30 | """Fetch a model to local cache. 
31 | 32 | Args: 33 | model_id (`str`): 34 | The *model_id* of a model on the HuggingFace hub or the path to a local model. 35 | revision (`Optional[str]`, defaults to `None`): 36 | The revision of the model on the HuggingFace hub. 37 | 38 | Returns: 39 | Model ID or path of the model available in cache. 40 | """ 41 | if os.path.isdir(model_id): 42 | if revision is not None: 43 | logger.warning("Revision {} ignored for local model at {}".format(revision, model_id)) 44 | return model_id 45 | 46 | # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model) 47 | # Note that the model may already be present in the cache. 48 | start = time.time() 49 | local_path = snapshot_download( 50 | repo_id=model_id, 51 | revision=revision, 52 | allow_patterns=["*.json", "model*.safetensors", SAFE_WEIGHTS_INDEX_NAME, "tokenizer.*"], 53 | ) 54 | end = time.time() 55 | logger.info(f"Model successfully fetched in {end - start:.2f} s.") 56 | 57 | # This will allow to set config to update specific config such as 58 | # batch_size and sequence_length. 59 | export_kwargs = get_export_kwargs_from_env() 60 | config = AutoConfig.from_pretrained(local_path) 61 | config.update(export_kwargs) 62 | config.save_pretrained(local_path) 63 | end = time.time() 64 | logger.info(f"Model config updated in {end - start:.2f} s.") 65 | 66 | return Path(local_path) 67 | -------------------------------------------------------------------------------- /.github/workflows/doc-build.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - 'v[0-9]+.[0-9]+.[0-9]+' 9 | 10 | paths: 11 | - 'docs/source/**' 12 | - 'docs/assets/**' 13 | - 'optimum/**' 14 | - '.github/workflows/doc-build.yml' 15 | workflow_dispatch: 16 | 17 | jobs: 18 | build_documentation: 19 | runs-on: ubuntu-22.04 20 | env: 21 | COMMIT_SHA: ${{ github.event.pull_request.head.sha }} 22 | PR_NUMBER: ${{ github.event.number }} 23 | EVENT_CONTEXT: ${{ toJSON(github.event) }} 24 | PR_CLONE_URL: ${{ github.event.pull_request.head.repo.clone_url }} 25 | 26 | steps: 27 | - uses: actions/checkout@v4 28 | - uses: actions/setup-node@v4 29 | with: 30 | node-version: '20' 31 | cache-dependency-path: "kit/package-lock.json" 32 | 33 | - name: Set environment variables 34 | run: | 35 | cd optimum 36 | version=`echo "$(grep '^__version__ =' tpu/version.py | cut -d '=' -f 2- | xargs)"` 37 | 38 | if [[ $version == *.dev0 ]] 39 | then 40 | echo "VERSION=main" >> $GITHUB_ENV 41 | else 42 | echo "VERSION=v$version" >> $GITHUB_ENV 43 | fi 44 | 45 | cd .. 
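# VERSION is now either "main" (for .dev0 pre-release versions) or "v<release>" (for tagged releases); it is consumed by the doc-builder build step below.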
46 | 47 | - name: Setup environment 48 | run: | 49 | pip install -U pip 50 | pip install git+https://github.com/huggingface/doc-builder.git 51 | pip install ".[quality]" -f https://storage.googleapis.com/libtpu-releases/index.html 52 | 53 | - name: Make documentation 54 | shell: bash 55 | env: 56 | HF_DOC_BUILD_PUSH: ${{ secrets.HF_DOC_BUILD_PUSH }} 57 | run: | 58 | doc-builder notebook-to-mdx examples/ --output_dir docs/source/howto/ --open_notebook_prefix https://colab.research.google.com/github/huggingface/optimum-tpu/blob/main 59 | python docs/scripts/auto-generate-examples.py 60 | doc-builder build optimum.tpu docs/source/ --repo_name optimum-tpu --build_dir tpu-doc-build/ --version ${{ env.VERSION }} --version_tag_suffix "" --html --clean 61 | cd tpu-doc-build/ 62 | mv optimum.tpu optimum-tpu 63 | doc-builder push optimum-tpu --doc_build_repo_id "hf-doc-build/doc-build" --token "$HF_DOC_BUILD_PUSH" --commit_msg "Updated with commit $COMMIT_SHA See: https://github.com/huggingface/optimum-tpu/commit/$COMMIT_SHA" --n_retries 5 64 | -------------------------------------------------------------------------------- /text-generation-inference/server/Makefile: -------------------------------------------------------------------------------- 1 | # Initialize base variables 2 | pkg_name := text_generation_server 3 | BUILDDIR ?= $(CURDIR)/build 4 | VERSION ?= 0.0.1 5 | TGI_VERSION ?= "v3.0.0" 6 | mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST))) 7 | mkfile_dir := $(dir $(mkfile_path)) 8 | pkg_dir := $(BUILDDIR)/$(pkg_name) 9 | 10 | clean: 11 | rm -rf $(BUILDDIR)/* 12 | 13 | # List static sources to be deployed in the package 14 | src_dir := $(mkfile_dir)/$(pkg_name) 15 | rwildcard_py = $(wildcard $(1)/*.py) $(foreach d,$(wildcard $(1)/*),$(call rwildcard_py,$d)) 16 | sources := $(call rwildcard_py,$(src_dir)) 17 | deployed_sources := $(subst $(src_dir), $(pkg_dir), $(sources)) 18 | 19 | # Static files are just copied 20 | 21 | define COPY 22 | mkdir -p $(dir $@) 23 | cp -f $< $@ 24 | endef 25 | 26 | $(BUILDDIR)/pyproject.toml: $(mkfile_dir)/pyproject.toml 27 | mkdir -p $(BUILDDIR) 28 | $(COPY) 29 | sed -i -e 's/version = "VERSION"/version = \"${VERSION}\"/' $@ 30 | 31 | $(pkg_dir)/%.py: $(src_dir)/%.py 32 | mkdir -p $(pkg_dir) 33 | $(COPY) 34 | 35 | # Generated files are produced by grpcio tools 36 | 37 | # If not provided, fetch proto files from TGI 38 | ifndef PROTODIR 39 | PROTODIR := $(BUILDDIR)/tgi/proto 40 | endif 41 | 42 | $(BUILDDIR)/tgi/proto/%.proto: 43 | install -d $(BUILDDIR)/tgi 44 | curl -L https://github.com/huggingface/text-generation-inference/archive/${TGI_VERSION}.tar.gz --output $(BUILDDIR)/tgi/sources.tar.gz 45 | tar -C $(BUILDDIR)/tgi -xf $(BUILDDIR)/tgi/sources.tar.gz --strip-components=1 46 | 47 | # Three python files are generated for each protobuf 48 | protobufs := $(PROTODIR)/generate.proto 49 | pkg_pb_dir := $(pkg_dir)/pb 50 | generated_sources_base := $(foreach proto, $(protobufs), $(proto:.proto=_pb2.py)) 51 | generated_sources := $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base)) 52 | generated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=.pyi)) 53 | generated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=_grpc.py)) 54 | 55 | $(pkg_pb_dir)/%_pb2.py $(pkg_pb_dir)/%_pb2.pyi $(pkg_pb_dir)/%_pb2_grpc.py: $(PROTODIR)/%.proto 56 | mkdir -p $(pkg_pb_dir) 57 | python -m grpc_tools.protoc -I$(PROTODIR) --python_out=$(pkg_pb_dir) \ 58 | --grpc_python_out=$(pkg_pb_dir) --mypy_out=$(pkg_pb_dir) $^ 59 
| sed -i -e 's/^\(import.*pb2\)/from . \1/g' $(pkg_pb_dir)/$*_pb2_grpc.py 60 | 61 | gen-server: $(BUILDDIR)/pyproject.toml $(deployed_sources) $(generated_sources) 62 | python -m build $(BUILDDIR) -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/generator_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import List, Optional, Tuple 3 | 4 | from .pb.generate_pb2 import ( 5 | Batch, 6 | CachedBatch, 7 | Generation, 8 | InfoResponse, 9 | ) 10 | 11 | 12 | class Generator(ABC): 13 | """An abstract class to represent the workhorse behind TextGenerationService. 14 | 15 | Ideally, it should not rely on protobuf constructs, but in a first step it does. 16 | Implementations would typically need a model and a tokenizer to implement the Generator methods. 17 | """ 18 | 19 | @property 20 | def info(self) -> InfoResponse: 21 | """This should simply return the expected InfoResponse""" 22 | raise NotImplementedError 23 | 24 | def warmup(self, batch: Batch) -> int: 25 | """Verify if the hardware can support the target load. 26 | 27 | Args: 28 | batch (`Batch`): 29 | A batch corresponding to the maximum number of concurrent requests. 30 | 31 | Return: 32 | The maximum number of tokens the model supports. 33 | """ 34 | raise NotImplementedError 35 | 36 | def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: 37 | """Prefill is called whenever new requests need to be added. 38 | 39 | When this method returns successfully, a decode method will follow 40 | with both the current and newly prefilled batch(es). 41 | 42 | Args: 43 | batch (`Batch`): 44 | A batch containing the new requests. 45 | 46 | Return: 47 | A list of `Generation` for each request and a `CachedBatch` containing all pending requests. 48 | """ 49 | raise NotImplementedError 50 | 51 | def decode(self, batches: List[Batch]) -> Tuple[List[Generation], CachedBatch]: 52 | """Decode after a prefill or another decode.""" 53 | raise NotImplementedError 54 | 55 | def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch: 56 | """Remove requests that are not listed from the specified batch""" 57 | raise NotImplementedError 58 | 59 | def clear(self, batch_id: Optional[int] = None): 60 | """Remove all requests from the generator""" 61 | raise NotImplementedError 62 | 63 | @classmethod 64 | def from_pretrained(cls, model_id: str, revision: Optional[str]): 65 | """Factory method "a la transformers" """ 66 | raise NotImplementedError 67 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.DS_Store 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # Models 133 | *.pt 134 | 135 | .vscode 136 | .idea/ 137 | 138 | jetstream-pt-deps 139 | 140 | # Optimum TPU artifacts 141 | tpu-doc-build/ -------------------------------------------------------------------------------- /docs/source/contributing.mdx: -------------------------------------------------------------------------------- 1 | # Contributing to Optimum TPU 2 | 3 | We're excited that you're interested in contributing to Optimum TPU! Whether you're fixing bugs, adding new features, improving documentation, or sharing your experiences, your contributions are highly valued 😄 4 | 5 | ## Getting Started 6 | 7 | 1. [Fork](https://github.com/huggingface/optimum-tpu/fork) and clone the repository: 8 | ```bash 9 | git clone https://github.com/YOUR_USERNAME/optimum-tpu.git 10 | cd optimum-tpu 11 | ``` 12 | 13 | 2. Install the package locally: 14 | ```bash 15 | python -m venv .venv 16 | source .venv/bin/activate 17 | python -m pip install . 
-f https://storage.googleapis.com/libtpu-releases/index.html 18 | ``` 19 | 20 | ## Development Tools 21 | 22 | The project includes a comprehensive Makefile with commands for various development tasks: 23 | 24 | ### Testing 25 | ```bash 26 | make tests # Run all the non-TGI-related tests 27 | make tgi_test # Run TGI tests with PyTorch/XLA 28 | make tgi_test_jetstream # Run TGI tests with Jetstream backend 29 | make tgi_docker_test # Run TGI integration tests in Docker 30 | ``` 31 | 32 | ### Code Quality 33 | ```bash 34 | make style # Auto-fix code style issues 35 | make style_check # Check code style without fixing 36 | ``` 37 | 38 | ### Documentation 39 | ```bash 40 | make preview_doc # Preview documentation locally 41 | ``` 42 | 43 | ### Docker Images 44 | ```bash 45 | make tpu-tgi # Build TGI Docker image 46 | make tpu-tgi-ie # Build TGI inference endpoint image 47 | make tpu-tgi-gcp # Build TGI Google Cloud image 48 | ``` 49 | 50 | ### TGI Development 51 | When working on Text Generation Inference (`/text-generation-inference` folder), you might also want to build a TGI image from scratch. To do this, refer to the manual image building section of the [serving how to guide](./howto/serving) 52 | 53 | 1. Build the standalone server: 54 | ```bash 55 | make tgi_server 56 | ``` 57 | 58 | ## Pull Request Process 59 | 60 | 1. Create a new branch: 61 | ```bash 62 | git checkout -b your-feature-name 63 | ``` 64 | 65 | 2. Make your changes 66 | 67 | 3. Run tests: 68 | ```bash 69 | make tests 70 | # Run more specialized test if needed such as make tgi_test, make tgi_test_jetstream, make tgi_docker_test 71 | make style_check 72 | ``` 73 | 74 | 4. Submit your PR with: 75 | - Clear description of changes 76 | - Test results 77 | - Documentation updates if needed 78 | 79 | ## Need Help? 80 | 81 | - Check the [documentation](https://huggingface.co/docs/optimum/tpu/overview) 82 | - Open an issue for bugs or feature requests 83 | 84 | ## License 85 | 86 | By contributing to Optimum TPU, you agree that your contributions will be licensed under the Apache License, Version 2.0. -------------------------------------------------------------------------------- /optimum/tpu/xla_mp_comm.py: -------------------------------------------------------------------------------- 1 | from multiprocessing.managers import ListProxy 2 | from typing import List 3 | 4 | import torch.multiprocessing as mp 5 | 6 | 7 | class RootMailbox: 8 | """A simple multiprocessing mailbox to communicate between the root process and the agents.""" 9 | def __init__(self, manager: mp.Manager): 10 | self.root_bell = manager.Event() 11 | self.root_command = manager.list() 12 | self.agent_ready = manager.Event() 13 | self.output_data = manager.list() 14 | self.agent_error = manager.Event() 15 | self.agent_error.clear() 16 | 17 | def send(self, command: int, *args) -> ListProxy: 18 | """Send a command and arguments to the agents and wait for the response. 19 | 20 | Args: 21 | command (int): Command to send to the agents. 22 | *args: Arguments to send to the agents. 23 | 24 | Returns: 25 | A list containing the response from the agents. 
26 | """ 27 | # First wait until agent is ready to receive commands 28 | self.agent_ready.wait() 29 | self.agent_ready.clear() 30 | 31 | self.root_command[:] = [command, *args] 32 | self.root_bell.set() 33 | # wait again until agent is ready, meaning command has been processed 34 | self.agent_ready.wait() 35 | if self.agent_error.is_set(): 36 | raise RuntimeError("Error on one of threads, stopping.") 37 | ret = self.output_data 38 | return ret 39 | 40 | 41 | class AgentMailbox: 42 | """The agent mailbox to communicate with the root process.""" 43 | def __init__(self, root_mailbox: RootMailbox): 44 | self.root_bell = root_mailbox.root_bell 45 | self.root_command = root_mailbox.root_command 46 | self.agent_ready = root_mailbox.agent_ready 47 | self.output_data = root_mailbox.output_data 48 | self.agent_error = root_mailbox.agent_error 49 | 50 | def receive(self) -> ListProxy: 51 | """Wait for a command from the root process and return it. 52 | 53 | Returns: 54 | A list containing the command and arguments from the root process. 55 | """ 56 | self.root_bell.wait() 57 | self.root_bell.clear() 58 | return self.root_command 59 | 60 | def send(self, *data): 61 | """Send the response to the root process. 62 | 63 | Args: 64 | *data: Data to send to the root process. 65 | """ 66 | self.output_data[:] = [*data] 67 | 68 | @property 69 | def command_data(self) -> tuple[int, List]: 70 | """Property helper to split command and arguments sent by the root process. 71 | 72 | Returns: 73 | A tuple containing the command and arguments. 74 | """ 75 | command = self.root_command[0] 76 | data = self.root_command[1:] 77 | return command, data 78 | -------------------------------------------------------------------------------- /docs/source/howto/gcloud_cli.mdx: -------------------------------------------------------------------------------- 1 | # Deploying and Connecting to Google TPU Instances via GCloud CLI 2 | 3 | ## Context 4 | 5 | We assume the reader has already created a Google Cloud Platform (GCP) user or organization account and an 6 | associated project. 7 | 8 | We also assume the reader to have the Google Cloud CLI installed. If not, please follow the links right after to 9 | [install](https://cloud.google.com/sdk/docs/install) and [setup](https://cloud.google.com/sdk/docs/initializing). 10 | 11 | ## Creating the initial TPU VM on GCP 12 | 13 | In order to create your initial TPU instance, you will need to provide some information: 14 | 15 | - The [GCP zone](https://cloud.google.com/tpu/docs/regions-zones) you would like to see the instance being deployed (close to the reader for development purposes, close to the end user for production, for instance) 16 | - Which kind of [TPU](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions) you would like to target 17 | - Which version of the [TPU runtime](https://cloud.google.com/tpu/docs/runtimes) you want to leverage on the instance 18 | - Custom instance name to quickly skim and refer back to the instance 19 | 20 | Overall, the end command looks like this: 21 | 22 | ```bash 23 | gcloud compute tpus tpu-vm create \ 24 | --zone= \ 25 | --accelerator-type= \ 26 | --version= 27 | ``` 28 | 29 | ### Deploying a TPU v5litepod-8 instance 30 | 31 | In our case, we will be deploying a `v5litepod-8` instance name `optimum-tpu-get-started` 32 | in the GCP region `us-west4-a` using the latest `v2-alpha-tpuv5-lite` runtime version. 33 | 34 | Of course, feel free to adjust all these parameters to the one that match with your usage and quotas. 
35 | 36 | Before creating the instance, please make sure to install `gcloud alpha component` as it is required to be able to 37 | target TPUv5 VMs: `gcloud components install alpha` 38 | 39 | ```bash 40 | gcloud alpha compute tpus tpu-vm create optimum-tpu-get-started \ 41 | --zone=us-west4-a \ 42 | --accelerator-type=v5litepod-8 \ 43 | --version=v2-alpha-tpuv5-lite 44 | ``` 45 | 46 | ## Connecting to the instance via ssh 47 | 48 | ```bash 49 | gcloud compute tpus tpu-vm ssh --zone= 50 | $ > 51 | ``` 52 | 53 | In the example above deploying v5litepod-8 it would be something like: 54 | 55 | ```bash 56 | gcloud compute tpus tpu-vm ssh optimum-tpu-get-started --zone=us-west4-a 57 | $ > 58 | ``` 59 | 60 | ## Other useful commands 61 | 62 | This is used to get information about the tpu-vm for example its external IP: 63 | ```bash 64 | gcloud compute tpus tpu-vm describe --zone= 65 | ``` 66 | 67 | ## Next steps 68 | - If you wish to train your own model, you can now [install optimum-tpu](../installation) 69 | - If you wish do to serving, you can look at our [serving tutorial](../tutorials/inference_on_tpu) -------------------------------------------------------------------------------- /optimum/tpu/static_cache_xla.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple 2 | 3 | import torch 4 | from transformers import StaticCache 5 | 6 | 7 | class StaticCacheXla(StaticCache): 8 | def update( 9 | self, 10 | key_states: torch.Tensor, 11 | value_states: torch.Tensor, 12 | layer_idx: int, 13 | cache_kwargs: Optional[Dict[str, Any]] = None, 14 | ) -> Tuple[torch.Tensor, torch.Tensor]: 15 | """ 16 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. 17 | It is VERY important to index using a tensor, otherwise you introduce a copy to the device. 18 | 19 | Parameters: 20 | key_states (`torch.Tensor`): 21 | The new key states to cache. 22 | value_states (`torch.Tensor`): 23 | The new value states to cache. 24 | layer_idx (`int`): 25 | The index of the layer to cache the states for. 26 | cache_kwargs (`Dict[str, Any]`, `optional`): 27 | Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input 28 | to know how where to write in the cache. 29 | 30 | Return: 31 | A tuple containing the updated key and value states. 32 | """ 33 | cache_position = cache_kwargs.get("cache_position") 34 | k_out = self.key_cache[layer_idx] 35 | v_out = self.value_cache[layer_idx] 36 | 37 | # `index_copy_(dim, index, source)` functions similarly to `tensor[index] = source`, 38 | # but it is used for better generality and it uses less memory on XLA. 39 | # For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html 40 | k_out.index_copy_(2, cache_position, key_states) 41 | v_out.index_copy_(2, cache_position, value_states) 42 | 43 | return k_out, v_out 44 | 45 | 46 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 47 | """Returns the sequence length of the cached states that were seen by the model.""" 48 | # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's 49 | # limit the check to the first batch member and head dimension. 50 | # TODO: deprecate this function in favor of `cache_position` 51 | key_cache = self.key_cache[layer_idx] 52 | device = key_cache.device 53 | 54 | # index_select(dim, index) performs the same operation as item = tensor[..., index, ...] 
55 | # but it is used for better generality and it uses less memory on XLA. 56 | # For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html 57 | item = key_cache.index_select(0, torch.tensor(0, device=device)) 58 | head = item.index_select(1, torch.tensor(0, device=device)) 59 | 60 | return head.any(dim=-1).sum() 61 | -------------------------------------------------------------------------------- /docs/source/reference/fsdp_v2.mdx: -------------------------------------------------------------------------------- 1 | # Fully Sharded Data Parallel (FSDP) v2 2 | 3 | ## Overview 4 | 5 | When fine-tuning Large Language Models (LLMs) on TPUs, model sharding across devices becomes essential for memory efficiency and improved training performance. The `optimum.tpu.fsdp_v2` module provides utilities for implementing Fully Sharded Data Parallel training using SPMD (Single Program Multiple Data) specifically optimized for TPU devices. 6 | 7 | ## FSDP_v2 Features 8 | 9 | - Model weight sharding across TPU devices 10 | - Gradient checkpointing support 11 | - Automatic configuration for common model architectures 12 | - Integration with PyTorch/XLA's SPMD implementation 13 | 14 | ## Basic Usage 15 | 16 | Here's how to enable and configure FSDP_v2 for your training: 17 | 18 | ```python 19 | from optimum.tpu import fsdp_v2 20 | from transformers import AutoModelForCausalLM, AutoTokenizer 21 | import torch 22 | 23 | # Enable FSDP_v2 24 | fsdp_v2.use_fsdp_v2() 25 | 26 | # Load model and tokenizer 27 | model_id = "meta-llama/Llama-2-7b" 28 | tokenizer = AutoTokenizer.from_pretrained(model_id) 29 | model = AutoModelForCausalLM.from_pretrained( 30 | model_id, 31 | torch_dtype=torch.bfloat16 32 | ) 33 | 34 | # Get FSDP training configuration 35 | fsdp_args = fsdp_v2.get_fsdp_training_args(model) 36 | ``` 37 | 38 | ## Configuration Options 39 | 40 | The `get_fsdp_training_args()` function returns a dictionary with a model-specific configuration such as: 41 | 42 | ```python 43 | { 44 | 'fsdp': 'full_shard', 45 | 'fsdp_config': { 46 | 'transformer_layer_cls_to_wrap': ['LlamaDecoderLayer'], # Model-specific 47 | 'xla': True, 48 | 'xla_fsdp_v2': True, 49 | 'xla_fsdp_grad_ckpt': True 50 | } 51 | } 52 | ``` 53 | 54 | ### Key Parameters 55 | 56 | - `transformer_layer_cls_to_wrap`: Specifies which model layers to wrap with FSDP 57 | - `xla`: Enables XLA optimization 58 | - `xla_fsdp_v2`: Activates FSDP_v2 implementation 59 | - `xla_fsdp_grad_ckpt`: Enables gradient checkpointing for memory efficiency 60 | 61 | ## Advanced Usage 62 | 63 | ### Custom Layer Wrapping 64 | 65 | You can customize which layers get wrapped with FSDP: 66 | 67 | ```python 68 | custom_fsdp_args = fsdp_v2.get_fsdp_training_args( 69 | model, 70 | layer_cls_to_wrap=['CustomTransformerLayer'] 71 | ) 72 | ``` 73 | 74 | ### Integration with Transformers Trainer 75 | 76 | FSDP_v2 configuration can be directly used with the Transformers Trainer: 77 | 78 | ```python 79 | from transformers import Trainer, TrainingArguments 80 | # Or for instruction fine-tuning: 81 | # from trl import SFTTrainer 82 | 83 | trainer = Trainer( # or SFTTrainer 84 | model=model, 85 | args=TrainingArguments(**fsdp_args), # Unpack FSDP configuration 86 | train_dataset=dataset, 87 | ... 
88 | ) 89 | ``` 90 | 91 | ## Next steps 92 | - You can look our [example notebooks](../howto/more_examples) for best practice on training with optimum-tpu 93 | - For more details on PyTorch/XLA's FSDP implementation, refer to the [official documentation](https://pytorch.org/xla/master/#fully-sharded-data-parallel-via-spmd). -------------------------------------------------------------------------------- /optimum/tpu/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """TpuModelForXXX classes for inference on TPU devices using the same API as 16 | Transformers.""" 17 | 18 | from os import PathLike, environ 19 | from typing import Any 20 | 21 | from loguru import logger 22 | from transformers import AutoConfig 23 | from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM 24 | 25 | 26 | def config_name_to_class(pretrained_model_name_or_path: str): 27 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path) 28 | if config.model_type == "gemma": 29 | from .modeling_gemma import GemmaForCausalLM 30 | 31 | return GemmaForCausalLM 32 | if config.model_type == "llama": 33 | from .modeling_llama import LlamaForCausalLM 34 | 35 | return LlamaForCausalLM 36 | if config.model_type == "mistral": 37 | from .modeling_mistral import MistralForCausalLM 38 | 39 | return MistralForCausalLM 40 | return BaseAutoModelForCausalLM 41 | 42 | 43 | class AutoModelForCausalLM(BaseAutoModelForCausalLM): 44 | 45 | @classmethod 46 | def from_pretrained( 47 | cls, 48 | pretrained_model_name_or_path: str | PathLike[str], 49 | task: str = None, 50 | batch_size: int = None, 51 | sequence_length: int = None, 52 | *model_args: Any, 53 | **kwargs: Any, 54 | ): 55 | if "PJRT_DEVICE" not in environ: 56 | logger.info("PJRT_DEVICE environment variable not found. 
Setting it to 'TPU'.") 57 | environ["PJRT_DEVICE"] = "TPU" 58 | if "DBG_DEVICE" in environ: 59 | device = environ["DBG_DEVICE"] 60 | logger.debug(f"Device set to: {device}") 61 | else: 62 | device = "xla" 63 | cls = config_name_to_class(pretrained_model_name_or_path) 64 | model = cls.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) 65 | model.to(device) 66 | 67 | # Update config with specific data) 68 | if task is not None or getattr(model.config, "task", None) is None: 69 | model.config.task = task 70 | if batch_size is not None or getattr(model.config, "batch_size", None) is None: 71 | model.config.batch_size = batch_size 72 | if sequence_length is not None or getattr(model.config, "sequence_length", None) is None: 73 | model.config.sequence_length = sequence_length 74 | # Do eval 75 | model.eval() 76 | 77 | return model 78 | -------------------------------------------------------------------------------- /.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml: -------------------------------------------------------------------------------- 1 | name: Optimum TPU / Test TGI on TPU (slow tests) / Jetstream Pytorch 2 | 3 | on: 4 | # schedule: 5 | # - cron: '0 3 * * *' # run at 3 AM UTC 6 | # This can be used to allow manually triggering nightlies from the web interface 7 | workflow_dispatch: 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | do-the-job: 15 | name: Build and Run slow tests 16 | runs-on: 17 | group: gcp-ct5lp-hightpu-8t 18 | container: 19 | image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm 20 | options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache 21 | env: 22 | PJRT_DEVICE: TPU 23 | HF_TOKEN: ${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} 24 | HF_HUB_CACHE: /mnt/hf_cache/cache_huggingface 25 | JETSTREAM_PT: 1 26 | steps: 27 | - name: Checkout 28 | uses: actions/checkout@v4 29 | 30 | - name: Build and install Jetstream Pytorch TGI 31 | run: | 32 | make jetstream_requirements tgi_server test_installs 33 | find text-generation-inference/ -name "text_generation_server-*whl" -exec python -m pip install {} \; 34 | - name: Run TGI Jetstream Pytorch - Llama 35 | run: | 36 | python -m \ 37 | pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Llama" 38 | - name: Run TGI Jetstream Pytorch - Gemma 39 | run: | 40 | python -m \ 41 | pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and gemma" 42 | - name: Run TGI Jetstream Pytorch - Mixtral greedy 43 | run: | 44 | python -m \ 45 | pytest -sv text-generation-inference/tests/test_decode_jetstream.py --runslow -m jetstream -k "slow and Mixtral and greedy" 46 | - name: Run TGI Jetstream Pytorch - Quantization Mixtral 47 | run: | 48 | python -m \ 49 | pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Mixtral" 50 | - name: Run TGI Jetstream Pytorch - Quantization Llama-3 8B 51 | run: | 52 | python -m \ 53 | pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3-8B" 54 | - name: Run TGI Jetstream Pytorch - Quantization Llama 3 70B 55 | run: | 56 | python -m \ 57 | pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3-70B" 58 | - name: Run TGI Jetstream 
Pytorch - Quantization Llama 3.3 70B 59 | run: | 60 | python -m \ 61 | pytest -sv text-generation-inference/tests/test_decode_jetstream_quant.py --runslow -m jetstream -k "slow and Llama-3.3-70B" 62 | - name: Run TGI Jetstream Pytorch - Other tests 63 | run: | 64 | python -m \ 65 | pytest -sv text-generation-inference/tests --runslow -m jetstream -k "not decode" 66 | -------------------------------------------------------------------------------- /text-generation-inference/tests/test_generator_slot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from text_generation_server.pb.generate_pb2 import Request 4 | from transformers import AutoTokenizer, GenerationConfig 5 | 6 | 7 | TOKENIZERS = ["NousResearch/Llama-2-7b-hf", "openai-community/gpt2"] 8 | 9 | # Defining this global variable will parametrize all tests in this file 10 | pytestmark = pytest.mark.parametrize( 11 | "input_text, generated_text", 12 | [ 13 | [ 14 | "It was a bright cold day in April, and the clocks were striking thirteen.", 15 | " Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind," 16 | " slipped quickly through the glass doors of Victory Mansions, though not quickly enough" 17 | " to prevent a swirl of gritty dust from entering along with him.", 18 | ], 19 | ["This sentence is written in chinese:", "我很感谢你的热情"], 20 | ["Some text might contain a lot of emojis like 😃", "😍💪 👉 👀"], 21 | ], 22 | ids=["spaces", "chinese-utf8", "emojis"], 23 | ) 24 | 25 | 26 | @pytest.fixture(params=TOKENIZERS) 27 | def tokenizer(request): 28 | t = AutoTokenizer.from_pretrained(request.param) 29 | t.padding_side = "left" 30 | t.pad_token_id = t.eos_token_id 31 | return t 32 | 33 | 34 | def _test_decode_streaming(slot, return_tensors, tokenizer, input_text, generated_text): 35 | request = Request(id=0, inputs=input_text) 36 | slot.assign(0, request, GenerationConfig()) 37 | 38 | inputs = tokenizer(input_text, padding="max_length", max_length=len(input_text) + 1, return_tensors=return_tensors) 39 | input_ids = inputs["input_ids"][0] 40 | generated_tokens = tokenizer(generated_text, add_special_tokens=False)["input_ids"] 41 | 42 | # We need to regenerate the full text as the tokenizer might change it (extra spaces might be added) 43 | all_input_ids = np.concatenate([input_ids, generated_tokens]) 44 | full_text = tokenizer.decode(all_input_ids, skip_special_tokens=True) 45 | regenerated_text = full_text[len(input_text) :] 46 | 47 | # Initialize the slot with the inputs 48 | slot.reset(input_ids, selector=None) 49 | 50 | assert slot.generated_tokens == 0 51 | 52 | # Simulate an iterative generation (i.e. 
don't call select and use known tokens instead) 53 | decoded_text = "" 54 | for i in range(len(generated_tokens)): 55 | text = slot.append(generated_tokens[i]) 56 | assert slot.generated_tokens == i + 1 57 | decoded_text += text 58 | 59 | assert decoded_text == regenerated_text 60 | 61 | 62 | @pytest.mark.jetstream 63 | def test_decode_streaming_jetstream(tokenizer, input_text, generated_text): 64 | from text_generation_server.jetstream_pt_support.generator import Slot 65 | 66 | slot = Slot(0, tokenizer) 67 | _test_decode_streaming(slot, "pt", tokenizer, input_text, generated_text) 68 | 69 | @pytest.mark.torch_xla 70 | def test_decode_streaming(tokenizer, input_text, generated_text): 71 | from text_generation_server.generator import Slot 72 | 73 | # Note: device used is cpu to make it faster 74 | slot = Slot(0, tokenizer, "cpu") 75 | _test_decode_streaming(slot, "pt", tokenizer, input_text, generated_text) 76 | -------------------------------------------------------------------------------- /docs/source/tutorials/tpu_setup.mdx: -------------------------------------------------------------------------------- 1 | # First TPU Setup on Google Cloud 2 | 3 | This guide walks you through setting up and accessing your first TPU instance on Google Cloud Platform (GCP). 4 | 5 | ## Prerequisites 6 | 7 | Before you begin, ensure you have: 8 | - A Google Cloud account 9 | - Billing enabled on your account 10 | - Basic familiarity with cloud consoles 11 | 12 | ## Step 1: Enable TPU Access 13 | 14 | 1. Navigate to the TPU dashboard at: https://console.cloud.google.com/compute/tpus 15 | - Note: You will need to enable the TPU API if you haven't already 16 | - A valid billing account must be linked to your project 17 | 18 | 2. If prompted, enable the TPU API for your project 19 | 20 | ![TPU Dashboard](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/tpu/gcp_tpu_dashboard.png) 21 | 22 | ## Step 2: Create your TPU Instance 23 | 24 | Click the "Create" button to setup your TPU instance. 25 | 26 | ![TPU Setup](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/tpu/gcp_create_tpu_config.png) 27 | 28 | ### Region Selection 29 | 30 | 1. Review available regions and zones for TPUs: https://cloud.google.com/tpu/docs/regions-zones 31 | 2. For this example, we will use `us-west-4a` zone 32 | - Important: TPU availability may vary by region 33 | - Tips: Choose a region close to your primary usage location 34 | 35 | ### TPU Configuration 36 | 37 | 1. Select TPU type: 38 | - We will use a TPU `v5e-8` (corresponds to a v5litepod8). This is a TPU node containing 8 v5e TPU chips 39 | - For detailed specifications about TPU types, refer to our [TPU hardware types documentation](../conceptual_guides/tpu_hardware_support) 40 | 41 | 2. Choose a runtime: 42 | - Select `v2-alpha-tpuv5-lite` runtime 43 | - This runtime is optimized for TPU v5e 44 | - More runtime information on runtime can be found in at recommended runtime for TPU section in our [TPU hardware page](../conceptual_guides/tpu_hardware_support) 45 | 46 | ## Step 3: Access Your TPU 47 | 48 | After creation, your TPU instance should be accessible by ssh 49 | 50 | 1. Access your TPU: 51 | - Click the SSH button in the console for immediate terminal access 52 | 53 | ![TPU SSH](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/tpu/gcp_ssh_tpu.png) 54 | 55 | 2. 
For permanent SSH access: 56 | - Add your SSH keys following the guide at: https://cloud.google.com/compute/docs/connect/add-ssh-keys 57 | - This enables more convenient access for future sessions 58 | - You can also look at the ssh section in our [guide about the gcloud cli](../howto/gcloud_cli) 59 | 60 | ## Next Steps 61 | 62 | Now that you have a working TPU environment, you can start using it for AI workloads. We offer two main paths depending on your use case: 63 | 64 | ### AI Inference and Training Tutorials 65 | 66 | 1. **Model Serving on TPU** 67 | - Follow our serving tutorial: [First Model Serving on TPU](../tutorials/inference_on_tpu) 68 | - Learn how to deploy and serve ML models efficiently on TPU 69 | 70 | 2. **Model Training on TPU** 71 | - Start with our training guide: [First Model Training on TPU](../tutorials/training_on_tpu) 72 | - Learn how to start training ML models on TPU 73 | 74 | Choose the tutorial that best matches your immediate needs: 75 | - For deploying existing models, start with our [model serving tutorial](../tutorials/inference_on_tpu) 76 | - For training new models, begin with our [model training tutorial](../tutorials/training_on_tpu) 77 | -------------------------------------------------------------------------------- /docs/source/howto/serving.mdx: -------------------------------------------------------------------------------- 1 | # Deploying a Text-Generation Inference server (TGI) on a Google Cloud TPU instance 2 | 3 | Text-Generation-Inference (TGI) enables serving Large Language Models (LLMs) on TPUs, with Optimum TPU delivering a specialized TGI runtime that's fully optimized for TPU hardware. 4 | 5 | TGI also offers an openAI-compatible API, making it easy to integrate with numerous tools. 6 | 7 | For a list of supported models, check the [Supported Models page](../supported-architectures). 8 | 9 | ## Deploy TGI on a Cloud TPU Instance 10 | 11 | This guide assumes you have a Cloud TPU instance running. If not, please refer to our [deployment guide](../tutorials/tpu_setup). 12 | 13 | You have two options for deploying TGI: 14 | 1. Use our pre-built TGI image (recommended) 15 | 2. Build the image manually for the latest features 16 | 17 | ### Option 1: Using the Pre-built Image 18 | 19 | The optimum-tpu image is available at `ghcr.io/huggingface/optimum-tpu:v0.2.3-tgi`. Please look at [optimum-tpu container documentation](../optimum_container) for the latest TGI image. The [tutorial on serving](../tutorials/inference_on_tpu) also walks you through how to start the TGI container from a pre-built image. 
Here's how to deploy it: 20 | 21 | ```bash 22 | docker run -p 8080:80 \ 23 | --shm-size 16GB \ 24 | --privileged \ 25 | --net host \ 26 | -e LOG_LEVEL=text_generation_router=debug \ 27 | -v ~/hf_data:/data \ 28 | -e HF_TOKEN= \ 29 | ghcr.io/huggingface/optimum-tpu:v0.2.3-tgi \ 30 | --model-id google/gemma-2b-it \ 31 | --max-input-length 512 \ 32 | --max-total-tokens 1024 \ 33 | --max-batch-prefill-tokens 512 \ 34 | --max-batch-total-tokens 1024 35 | ``` 36 | 37 | 38 | You need to replace with a HuggingFace access token that you can get [here](https://huggingface.co/settings/tokens) 39 | 40 | 41 | 42 | If you already logged in via `huggingface-cli login` then you can set HF_TOKEN=$(cat ~/.cache/huggingface/token) for more convinence 43 | 44 | 45 | You can also use the GCP-provided image as referenced in the [optimum-tpu container page](../optimum_container) 46 | 47 | ### Option 2: Manual Image Building 48 | 49 | For the latest features (main branch of optimum-tpu) or custom modifications, build the image yourself: 50 | 51 | 1. Clone the repository: 52 | ```bash 53 | git clone https://github.com/huggingface/optimum-tpu.git 54 | ``` 55 | 56 | 2. Build the image: 57 | ```bash 58 | make tpu-tgi 59 | ``` 60 | 61 | 3. Run the container: 62 | ```bash 63 | HF_TOKEN= 64 | MODEL_ID=google/gemma-2b-it 65 | 66 | sudo docker run --net=host \ 67 | --privileged \ 68 | -v $(pwd)/data:/data \ 69 | -e HF_TOKEN=${HF_TOKEN} \ 70 | huggingface/optimum-tpu:latest \ 71 | --model-id ${MODEL_ID} \ 72 | --max-concurrent-requests 4 \ 73 | --max-input-length 32 \ 74 | --max-total-tokens 64 \ 75 | --max-batch-size 1 76 | ``` 77 | 78 | ## Executing requests against the service 79 | 80 | You can query the model using either the `/generate` or `/generate_stream` routes: 81 | 82 | 83 | ```bash 84 | curl localhost/generate \ 85 | -X POST \ 86 | -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ 87 | -H 'Content-Type: application/json' 88 | ``` 89 | 90 | ```bash 91 | curl localhost/generate_stream \ 92 | -X POST \ 93 | -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ 94 | -H 'Content-Type: application/json' 95 | ``` -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/jetstream_pt_support/logits_process.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from transformers import GenerationConfig 7 | 8 | 9 | @dataclass 10 | class FusedLogitsWarper: 11 | """ 12 | A class that performs top-k then top-p filtering, optionally applying a temperature. 13 | 14 | Top-k filtering only keeps the `k` tokens with the best scores. 15 | 16 | Top-p filtering only keeps the top tokens whose cumulated probability is above `p`. 17 | 18 | The filtered tokens are returned as a list of indices, along with the corresponding subset of 19 | the original logits. 20 | 21 | If only top-k filtering is active, the filtered tokens are sorted by descending order. 22 | 23 | If top-p filtering is active, the filtered tokens are sorted by ascending order. 24 | 25 | Args: 26 | temperature (`float`): 27 | Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases 28 | randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely 29 | token. 
30 | top_k (`int`): 31 | The number of highest probability vocabulary tokens to keep for top-k-filtering. 32 | top_p (`float`): 33 | If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or 34 | higher are kept for generation. 35 | """ 36 | 37 | temperature: float = 1.0 38 | top_k: int = 0 39 | top_p: float = 1.0 40 | 41 | @classmethod 42 | def from_config(cls, generation_config: GenerationConfig) -> "FusedLogitsWarper": 43 | """Instantiate a fused warper from a generation configuration. 44 | 45 | Args: 46 | generation_config (`~transformers.generation.GenerationConfig`): 47 | The generation configuration to be used as base parametrization for the fused warper. 48 | 49 | Returns: 50 | a `FusedLogitsWarper` or None if neither top-k nor top-p are configured. 51 | """ 52 | return cls(generation_config.temperature, generation_config.top_k, generation_config.top_p) 53 | 54 | def __call__(self, logits: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: 55 | if self.temperature != 1.0: 56 | logits = logits / self.temperature 57 | 58 | do_top_k = self.top_k > 0 and self.top_k < logits.shape[-1] 59 | do_top_p = self.top_p < 1.0 and self.top_p > 0.0 60 | 61 | if do_top_k: 62 | sorted_indices = jnp.argsort(logits, axis=-1)[..., ::-1][:, : self.top_k] 63 | sorted_logits = jnp.take_along_axis(logits, sorted_indices, axis=-1) 64 | else: 65 | sorted_indices = jnp.argsort(logits, axis=-1) 66 | sorted_logits = jnp.take_along_axis(logits, sorted_indices, axis=-1) 67 | 68 | if do_top_p: 69 | if do_top_k: 70 | # logits have been sorted in descending order, so we need to flip them 71 | sorted_logits = jnp.flip(sorted_logits, axis=-1) 72 | sorted_indices = jnp.flip(sorted_indices, axis=-1) 73 | # We always keep the best logits and those whose cumulative probability is strictly higher than top_p 74 | cum_probs = jax.nn.softmax(sorted_logits, axis=-1).cumsum(axis=-1) 75 | keep_mask = cum_probs > (1 - self.top_p) 76 | keep_mask = keep_mask.at[:, -1].set(True) 77 | # Set rejected logits to -inf so that they are ignored in downstream comparisons 78 | sorted_logits = jnp.where(keep_mask, sorted_logits, float("-Inf")) 79 | 80 | return sorted_logits, sorted_indices 81 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | [project] 15 | name = "optimum-tpu" 16 | dynamic = ["version"] 17 | authors=[ 18 | { name = "HuggingFace Inc. Machine Learning Optimization Team", email = "hardware@huggingface.co"} 19 | ] 20 | description = "Optimum TPU is the interface between the Hugging Face Transformers library and Google Cloud TPU devices." 
21 | readme = "README.md" 22 | license = {file = "LICENSE"} 23 | classifiers = [ 24 | "Development Status :: 2 - Pre-Alpha", 25 | "License :: OSI Approved :: Apache Software License", 26 | "Intended Audience :: Developers", 27 | "Intended Audience :: Education", 28 | "Intended Audience :: Science/Research", 29 | "Operating System :: OS Independent", 30 | "Programming Language :: Python :: 3.10", 31 | "Programming Language :: Python :: 3.11", 32 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 33 | ] 34 | keywords = [ 35 | "transformers", 36 | "fine-tuning", 37 | "inference", 38 | "tpu", 39 | "cloud-tpu", 40 | "gcp", 41 | "google-cloud" 42 | ] 43 | 44 | dependencies = [ 45 | "transformers == 4.46.3", 46 | "torch == 2.5.1", 47 | "torch-xla[tpu] == 2.5.1", 48 | 'typer == 0.6.1', 49 | "loguru == 0.6.0", 50 | "sentencepiece == 0.2.0", 51 | ] 52 | 53 | [tool.setuptools_scm] 54 | 55 | [build-system] 56 | requires = ["setuptools>=64", "setuptools_scm>=8"] 57 | build-backend = "setuptools.build_meta" 58 | 59 | [project.optional-dependencies] 60 | tests = ["pytest", "safetensors"] 61 | quality = ["black", "ruff", "isort"] 62 | # Jetstream/Pytorch support is experimental for now, it needs to be installed manually. 63 | # Pallas is pulled because it will install a compatible version of jax[tpu]. 64 | jetstream-pt = [ 65 | "jetstream-pt", 66 | "torch-xla[pallas] == 2.5.1" 67 | ] 68 | 69 | [project.urls] 70 | Homepage = "https://hf.co/hardware" 71 | Documentation = "https://hf.co/docs/optimum/tpu" 72 | Repository = "https://github.com/huggingface/optimum-tpu" 73 | Issues = "https://github.com/huggingface/optimum-tpu/issues" 74 | 75 | [tool.setuptools.packages.find] 76 | include = ["optimum.tpu*"] 77 | 78 | [tool.black] 79 | line-length = 119 80 | target-version = ['py38'] 81 | extend-exclude = '.ipynb' 82 | 83 | [tool.ruff] 84 | # Never enforce `E501` (line length violations). 85 | lint.ignore = ["C901", "E501", "E741", "W605"] 86 | lint.select = ["C", "E", "F", "I", "W"] 87 | line-length = 119 88 | 89 | # Ignore import violations in all `__init__.py` files. 
90 | [tool.ruff.lint.per-file-ignores] 91 | "__init__.py" = ["E402", "F401", "F403", "F811"] 92 | 93 | [tool.ruff.lint.isort] 94 | lines-after-imports = 2 95 | known-first-party = ["optimum.tpu"] 96 | 97 | [tool.pytest.ini_options] 98 | markers = [ 99 | "is_staging_test", 100 | ] 101 | filterwarnings = [ 102 | "ignore:Some donated", 103 | "ignore:The given NumPy array is not writable", 104 | "ignore:`do_sample` is set", 105 | "ignore:Device capability of jax", 106 | "ignore:`tensorflow` can conflict", 107 | ] 108 | 109 | [project.scripts] 110 | optimum-tpu = "optimum.tpu.cli:app" 111 | -------------------------------------------------------------------------------- /docs/source/reference/tgi_advanced_options.mdx: -------------------------------------------------------------------------------- 1 | # TGI Configuration Reference Guide 2 | 3 | ## Required Configuration 4 | 5 | ### Required Environment Variables 6 | - `HF_TOKEN`: HuggingFace authentication token 7 | 8 | ### Required Command Line Arguments 9 | **docker specific parameters** 10 | - `--shm-size 16GB`: Shared memory allocation 11 | - `--privileged`: Enable privileged container mode 12 | - `--net host`: Uses host network mode 13 | 14 | Those are needed to run a TPU container so that the docker container can properly access the TPU hardware 15 | 16 | **TGI specific parameters** 17 | - `--model-id`: Model identifier to load from the HuggingFace hub 18 | 19 | Those are parameters used by TGI and optimum-TPU to configure the server behavior. 20 | 21 | ## Optional Configuration 22 | 23 | ### Optional Environment Variables 24 | - `JETSTREAM_PT_DISABLE`: Disable Jetstream PyTorch backend 25 | - `QUANTIZATION`: Enable int8 quantization 26 | - `MAX_BATCH_SIZE`: Set batch processing size, that is **static** on TPUs 27 | - `LOG_LEVEL`: Set logging verbosity (useful for debugging). It can be set to info, debug or a comma separated list of attribute such text_generation_launcher,text_generation_router=debug 28 | - `SKIP_WARMUP`: Skip model warmup phase 29 | 30 | **Note on warmup:** 31 | - TGI performs warmup to compile TPU operations for optimal performance 32 | - For production use, never use `SKIP_WARMUP=1`; you can however use the parameters for debugging purposes to speed up model loading at the cost of slow model inference 33 | 34 | You can view more options in the [TGI documentation](https://huggingface.co/docs/text-generation-inference/reference/launcher). Not all parameters might be compatible with TPUs (for example, all the CUDA-specific parameters) 35 | 36 | 37 | TIP for TGI: you can pass most parameters to TGI as docker environment variables or docker arguments. So you can pass `--model-id google/gemma-2b-it` or `-e MODEL_ID=google/gemma-2b-it` to the `docker run` command 38 | 39 | 40 | ### Optional Command Line Arguments 41 | - `--max-input-length`: Maximum input sequence length 42 | - `--max-total-tokens`: Maximum combined input/output tokens 43 | - `--max-batch-prefill-tokens`: Maximum tokens for batch processing 44 | - `--max-batch-total-tokens`: Maximum total tokens in batch 45 | 46 | You can view more options in the [TGI documentation](https://huggingface.co/docs/text-generation-inference/reference/launcher). 
Not all parameters might be compatible with TPUs (for example, all the CUDA-specific parameters) 47 | 48 | ### Docker Requirements 49 | When running TGI inside a container (recommended), the container should be started with: 50 | - Privileged mode for TPU access 51 | - Shared memory allocation (16GB recommended) 52 | - Host IPC settings 53 | 54 | ## Example Command 55 | Here's a complete example showing all major configuration options: 56 | 57 | ```bash 58 | docker run -p 8080:80 \ 59 | --shm-size 16GB \ 60 | --privileged \ 61 | --net host \ 62 | -e QUANTIZATION=1 \ 63 | -e MAX_BATCH_SIZE=2 \ 64 | -e LOG_LEVEL=text_generation_router=debug \ 65 | -v ~/hf_data:/data \ 66 | -e HF_TOKEN= \ 67 | ghcr.io/huggingface/optimum-tpu:v0.2.3-tgi \ 68 | --model-id google/gemma-2b-it \ 69 | --max-input-length 512 \ 70 | --max-total-tokens 1024 \ 71 | --max-batch-prefill-tokens 512 \ 72 | --max-batch-total-tokens 1024 73 | ``` 74 | 75 | 76 | You need to replace with a HuggingFace access token that you can get [here](https://huggingface.co/settings/tokens) 77 | 78 | 79 | 80 | If you already logged in via `huggingface-cli login`, then you can set HF_TOKEN=$(cat ~/.cache/huggingface/token) for more convenience 81 | 82 | 83 | ## Additional Resources 84 | - [TGI Documentation](https://huggingface.co/docs/text-generation-inference) -------------------------------------------------------------------------------- /optimum/tpu/generation/logits_process.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | 4 | import torch 5 | from transformers import GenerationConfig 6 | 7 | 8 | @dataclass 9 | class FusedLogitsWarper: 10 | """ 11 | A class that performs top-k then top-p filtering, optionally applying a temperature. 12 | 13 | Top-k filtering only keeps the `k` tokens with the best scores. 14 | 15 | Top-p filtering only keeps the top tokens whose cumulated probability is above `p`. 16 | 17 | The filtered tokens are returned as a list of indices, along with the corresponding subset of 18 | the original logits. 19 | 20 | If only top-k filtering is active, the filtered tokens are sorted by descending order. 21 | 22 | If top-p filtering is active, the filtered tokens are sorted by ascending order. 23 | 24 | Args: 25 | temperature (`float`): 26 | Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases 27 | randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely 28 | token. 29 | top_k (`int`): 30 | The number of highest probability vocabulary tokens to keep for top-k-filtering. 31 | top_p (`float`): 32 | If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or 33 | higher are kept for generation. 34 | """ 35 | 36 | temperature: float = 1.0 37 | top_k: int = 0 38 | top_p: float = 1.0 39 | 40 | @classmethod 41 | def from_config(cls, generation_config: GenerationConfig) -> "FusedLogitsWarper": 42 | """Instantiate a fused warper from a generation configuration. 43 | 44 | Args: 45 | generation_config (`~transformers.generation.GenerationConfig`): 46 | The generation configuration to be used as base parametrization for the fused warper. 47 | 48 | Returns: 49 | a `FusedLogitsWarper` or None if neither top-k nor top-p are configured. 
50 | """ 51 | return cls(generation_config.temperature, generation_config.top_k, generation_config.top_p) 52 | 53 | def __call__(self, logits: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]: 54 | if self.temperature != 1.0: 55 | logits = logits / self.temperature 56 | 57 | do_top_k = self.top_k > 0 and self.top_k < logits.shape[-1] 58 | do_top_p = self.top_p < 1.0 and self.top_p > 0.0 59 | 60 | if do_top_k: 61 | sorted_logits, sorted_indices = torch.topk(logits, self.top_k) 62 | else: 63 | # Warning: not applying top-k filtering leads to this very slow sort operation 64 | sorted_logits, sorted_indices = torch.sort(logits) 65 | 66 | if do_top_p: 67 | if do_top_k: 68 | # logits have been sorted in descending order, so we need to flip them 69 | sorted_logits = torch.flip(sorted_logits, [-1]) 70 | sorted_indices = torch.flip(sorted_indices, [-1]) 71 | # We always keep the best logits and those whose cumulative probability is strictly higher than top_p 72 | cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 73 | keep_mask = cum_probs > (1 - self.top_p) 74 | keep_mask[:, -1] = True 75 | # Set rejected logits to -inf so that they are ignored in downstream comparisons 76 | sorted_logits[~keep_mask] = float("-Inf") 77 | # Clip the [batch_size, vocab_size] logits tensor to speed-up downstream ops 78 | keep_by_batch = torch.sum(keep_mask, dim=-1) 79 | keep = torch.amax(keep_by_batch) 80 | sorted_logits = sorted_logits[:, -keep:] 81 | sorted_indices = sorted_indices[:, -keep:] 82 | 83 | return sorted_logits, sorted_indices 84 | -------------------------------------------------------------------------------- /.github/workflows/tpu-tgi-release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | # This can be used to allow manually triggering this workflow from the web interface 7 | workflow_dispatch: 8 | 9 | jobs: 10 | docker: 11 | name: Push TGI Docker container to Docker Hub and Github Registry 12 | runs-on: ubuntu-latest 13 | permissions: 14 | packages: write 15 | contents: read 16 | attestations: write 17 | id-token: write 18 | env: 19 | REGISTRY: ghcr.io 20 | IMAGE_NAME: ${{ github.repository }} 21 | steps: 22 | - name: Check out the repo 23 | uses: actions/checkout@v4 24 | 25 | - name: Log in to Docker Hub 26 | uses: docker/login-action@v3 27 | with: 28 | username: ${{ secrets.DOCKERHUB_USERNAME }} 29 | password: ${{ secrets.DOCKERHUB_PASSWORD }} 30 | 31 | - name: Log in to the Container registry 32 | uses: docker/login-action@v3 33 | with: 34 | registry: ${{ env.REGISTRY }} 35 | username: ${{ github.actor }} 36 | password: ${{ secrets.GITHUB_TOKEN }} 37 | 38 | - name: Extract metadata (tags, labels) for Docker TGI 39 | id: meta 40 | uses: docker/metadata-action@v5 41 | with: 42 | images: | 43 | ghcr.io/${{ env.IMAGE_NAME }} 44 | ${{ env.IMAGE_NAME}} 45 | flavor: | 46 | latest=auto 47 | prefix= 48 | suffix=-tgi 49 | 50 | - name: Extract metadata (tags, labels) for Docker TGI-IE 51 | id: meta-ie 52 | uses: docker/metadata-action@v5 53 | with: 54 | images: | 55 | ghcr.io/${{ env.IMAGE_NAME }} 56 | ${{ env.IMAGE_NAME}} 57 | flavor: | 58 | latest=auto 59 | prefix= 60 | suffix=-tgi-ie 61 | 62 | - name: Get the version 63 | id: version 64 | run: | 65 | VERSION=$(awk '/__version__ = "(.*)"/{print $3}' optimum/tpu/version.py | sed 's/"//g') 66 | echo "version=$VERSION" >> $GITHUB_OUTPUT 67 | 68 | - name: Build and push TGI Docker image 69 | id: push 70 | uses: 
docker/build-push-action@v6 71 | with: 72 | context: . 73 | file: text-generation-inference/docker/Dockerfile 74 | push: true 75 | tags: ${{ steps.meta.outputs.tags }} 76 | labels: ${{ steps.meta.outputs.labels }} 77 | build-args: | 78 | VERSION=${{ steps.version.outputs.version }} 79 | 80 | 81 | - name: Generate artifact attestation for TGI 82 | uses: actions/attest-build-provenance@v1 83 | with: 84 | subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 85 | subject-digest: ${{ steps.push.outputs.digest }} 86 | push-to-registry: true 87 | 88 | - name: Build and push TGI IE Docker image 89 | id: push-ie 90 | uses: docker/build-push-action@v6 91 | with: 92 | context: . 93 | file: text-generation-inference/docker/Dockerfile 94 | push: true 95 | tags: ${{ steps.meta-ie.outputs.tags }} 96 | labels: ${{ steps.meta-ie.outputs.labels }} 97 | build-args: | 98 | VERSION=${{ steps.version.outputs.version }} 99 | target: inference-endpoint 100 | 101 | 102 | - name: Generate artifact attestation for TGI IE 103 | uses: actions/attest-build-provenance@v1 104 | with: 105 | subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 106 | subject-digest: ${{ steps.push-ie.outputs.digest }} 107 | push-to-registry: true 108 | -------------------------------------------------------------------------------- /docs/source/howto/advanced-tgi-serving.mdx: -------------------------------------------------------------------------------- 1 | # Advanced TGI Server Configuration 2 | 3 | ## Jetstream Pytorch and Pytorch XLA backends 4 | 5 | [Jetstream Pytorch](https://github.com/AI-Hypercomputer/jetstream-pytorch) is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. This engine is selected by default if the dependency is available. 6 | 7 | We recommend using Jetstream with TGI for the best performance. If for some reason you want to use the Pytorch/XLA backend instead, you can set the `JETSTREAM_PT_DISABLE=1` environment variable. 8 | 9 | For more information, see our discussion on the [difference between jetstream and pytorch XLA](../conceptual_guides/difference_between_jetstream_and_xla) 10 | 11 | ## Quantization 12 | When using Jetstream Pytorch engine, it is possible to enable quantization to reduce the memory footprint and increase the throughput. To enable quantization, set the `QUANTIZATION=1` environment variable. For instance, on a 2x4 TPU v5e (16GB per chip * 8 = 128 GB per pod), you can serve models up to 70B parameters, such as Llama 3.3-70B. The quantization is done in `int8` on the fly as the weight loads. As with any quantization option, you can expect a small drop in the model accuracy. Without the quantization option enabled, the model is served in bf16. 13 | 14 | ## How to solve memory requirements 15 | 16 | If you encounter `Backend(NotEnoughMemory(2048))`, here are some solutions that could help with reducing memory usage in TGI: 17 | 18 | **Optimum-TPU specific arguments:** 19 | - `-e QUANTIZATION=1`: To enable quantization. 
This should reduce memory requirements by almost half 20 | - `-e MAX_BATCH_SIZE=n`: You can manually reduce the batch size 21 | 22 | **TGI specific arguments:** 23 | - `--max-input-length`: Maximum input sequence length 24 | - `--max-total-tokens`: Maximum combined input and output tokens 25 | - `--max-batch-prefill-tokens`: Maximum tokens for batch processing 26 | - `--max-batch-total-tokens`: Maximum total tokens in a batch 27 | 28 | To reduce memory usage, you can try smaller values for `--max-input-length`, `--max-total-tokens`, `--max-batch-prefill-tokens`, and `--max-batch-total-tokens`. 29 | 30 | 31 | `max-batch-prefill-tokens ≤ max-input-length * max_batch_size`. Otherwise, you will get an error, as the configuration would not make sense: if max-batch-prefill-tokens were bigger, you would not be able to process any request. 32 | 33 | 34 | ## Sharding 35 | Sharding is done automatically by the TGI server, so your model uses all the TPUs that are available. We do tensor parallelism, so the layers are automatically split across all available TPUs. However, the TGI router will only see one shard. 36 | 37 | More information on tensor parallelism can be found here: https://huggingface.co/docs/text-generation-inference/conceptual/tensor_parallelism. 38 | 39 | ## Understanding the configuration 40 | 41 | Key parameters explained: 42 | 43 | **Required parameters** 44 | - `--shm-size 16GB`: Increase default shared memory allocation. 45 | - `--privileged`: Required for TPU access. 46 | - `--net host`: Uses host network mode. 47 | Those are needed to run a TPU container so that the container can properly access the TPU hardware. 48 | 49 | **Optional parameters** 50 | - `-v ~/hf_data:/data`: Volume mount for model storage; this allows you to avoid re-downloading the model weights on each startup. You can use any folder you would like, as long as it maps back to /data. 51 | - `-e SKIP_WARMUP=1`: Disables warmup for quick testing (not recommended for production). 52 | Those are parameters used by TGI and optimum-TPU to configure the server behavior. 53 | 54 | 55 | 56 | `--privileged --shm-size 16GB --net host` is required as specified in https://github.com/pytorch/xla 57 | 58 | 59 | ## Next steps 60 | Please check the [TGI docs](https://huggingface.co/docs/text-generation-inference) for more TGI server configuration options. -------------------------------------------------------------------------------- /optimum/tpu/fsdp_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Utility functions to provide FSDPv2 configuration for TPU training.
17 | """ 18 | from typing import Any, Dict, List, Union 19 | 20 | 21 | PreTrainedModel = Any 22 | # NOTE: instead of the above, modeling_utils.PreTrainedModel should be used, but since the usage is only for type 23 | # hinting, it is not imported here, so to avoid pulling imports of torch_xla. 24 | 25 | 26 | def use_fsdp_v2(): 27 | """ 28 | Enable FSDPv2 for TPU training. 29 | """ 30 | import torch_xla.runtime as xr 31 | 32 | # FSDPv2 requires SPMD to be enabled. 33 | xr.use_spmd() 34 | 35 | 36 | def get_fsdp_config(*cls_to_wrap: Union[str | List[str]]) -> Dict: 37 | """ 38 | Returns the FSDPv2 configuration for a given class to wrap. 39 | 40 | Args: 41 | cls_to_wrap: One or more class names to wrap with FSDPv2. 42 | 43 | Returns: 44 | A dictionary with the FSDPv2 configuration. 45 | """ 46 | return { 47 | "transformer_layer_cls_to_wrap": [*cls_to_wrap], 48 | "xla": True, 49 | "xla_fsdp_v2": True, 50 | "xla_fsdp_grad_ckpt": True, 51 | } 52 | 53 | 54 | def _unwrap_model(model: PreTrainedModel) -> PreTrainedModel: 55 | """ 56 | Unwraps the model from the PeftModel wrapper. 57 | 58 | Args: 59 | model: The model to unwrap. 60 | 61 | Returns: 62 | The unwrapped model. 63 | """ 64 | try: 65 | from peft.peft_model import LoraModel, PeftModel 66 | 67 | if isinstance(model, PeftModel) and isinstance(model.base_model, LoraModel): 68 | return model.base_model.model 69 | return model 70 | except ImportError: 71 | return model 72 | 73 | 74 | def get_fsdp_training_args(model: PreTrainedModel) -> Dict: 75 | """ 76 | Returns the default FSDPv2 training arguments for a model of a known class. 77 | 78 | Args: 79 | model: The model to train with FSDPv2. 80 | 81 | Returns: 82 | A dictionary with the FSDPv2 training arguments. 83 | """ 84 | model = _unwrap_model(model) 85 | model_type = model.config.model_type 86 | matched_model = False 87 | if model_type == "gemma": 88 | from transformers import GemmaForCausalLM as HFGemmaForCausalLLM 89 | 90 | from .modeling_gemma import GemmaForCausalLM 91 | 92 | if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM): 93 | cls_to_wrap = "GemmaDecoderLayer" 94 | matched_model = True 95 | elif model_type == "llama": 96 | from transformers import LlamaForCausalLM as HFLlamaForCausalLLM 97 | 98 | from .modeling_llama import LlamaForCausalLM 99 | 100 | if isinstance(model, LlamaForCausalLM) or isinstance(model, HFLlamaForCausalLLM): 101 | cls_to_wrap = "LlamaDecoderLayer" 102 | matched_model = True 103 | 104 | if not matched_model: 105 | raise ValueError(f"Model {model} configuration cannot be auto-generated, use get_fsdp_config instead.") 106 | 107 | fsdp_training_args = { 108 | "fsdp": "full_shard", 109 | "fsdp_config": get_fsdp_config(cls_to_wrap), 110 | } 111 | return fsdp_training_args 112 | -------------------------------------------------------------------------------- /text-generation-inference/tests/test_decode_jetstream.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | from decode_tests_utils import DecodeTestParams, decode_single_test 4 | 5 | 6 | # All tests in this file are for jetstream 7 | pytestmark = pytest.mark.jetstream 8 | 9 | @pytest.mark.slow 10 | @pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"]) 11 | @pytest.mark.parametrize("params", 12 | [ 13 | DecodeTestParams( 14 | model_id="meta-llama/Llama-2-7b-hf", 15 | sequence_length=256, 16 | expected_text="\nWinston Smith, his chin nuzzled into his breast in an effort to escape", 17 | top_k=100, 18 
| ), 19 | DecodeTestParams( 20 | model_id="meta-llama/Meta-Llama-3-8B", 21 | sequence_length=256, 22 | expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,", 23 | top_k=100, 24 | ), 25 | DecodeTestParams( 26 | model_id="google/gemma-7b", 27 | sequence_length=128, 28 | expected_text="\n\nThe year was 1984.\n\nThe place was Oceania.\n\nThe man was", 29 | ), 30 | DecodeTestParams( 31 | model_id="mistralai/Mixtral-8x7B-v0.1", 32 | sequence_length=1024, 33 | expected_text="\n\nGeorge Orwell, 1984\n\nThe clocks are striking thirteen", 34 | ), 35 | ], 36 | ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b", "Mixtral-8x7B"], 37 | ) 38 | def test_decode_single_jetstream_pytorch_slow(params, do_sample): 39 | params.do_sample = do_sample 40 | decode_single_test(params) 41 | 42 | 43 | @pytest.mark.parametrize("do_sample", [False, True], ids=["greedy", "sample"]) 44 | @pytest.mark.parametrize("params", 45 | [ 46 | DecodeTestParams( 47 | model_id="Maykeye/TinyLLama-v0", 48 | sequence_length=256, 49 | expected_text=" The sun was shining and the sky was shining.\nSuddenly, a big wind came and blew the wind away.", 50 | max_new_tokens=25, 51 | ), 52 | DecodeTestParams( 53 | model_id="google/gemma-2b", 54 | sequence_length=1024, 55 | expected_text="\n\nThe first thing I noticed was the smell of the rain. It was a smell I had never", 56 | ), 57 | DecodeTestParams( 58 | model_id="dacorvo/Mixtral-tiny", # This is a random tiny model, just to test model can be loaded. 59 | sequence_length=512, 60 | expected_text="манaminationVariableßer Rog malesazine longふ Toy Champions enero Facereverse▲verbose prosecut literally disappearedअ", 61 | ), 62 | DecodeTestParams( 63 | # NOTE: this test is interesting because it is a fine-tuned model that requires padding on weights to work. 64 | model_id="Trendyol/Trendyol-LLM-7b-base-v0.1", 65 | sequence_length=512, 66 | expected_text="\nThe clocks were striking thirteen, and the clocks were striking thirteen.", 67 | ), 68 | DecodeTestParams( 69 | model_id="meta-llama/Llama-3.2-1B", 70 | sequence_length=256, 71 | expected_text=" Winston Smith, his chin nuzzled into his breast, stretched, and looked out across the city", 72 | max_new_tokens=20, 73 | ) 74 | ], 75 | ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny", "Trendyol-LLM-7b-base-v0.1", "Llama-3.2-1B"], 76 | ) 77 | def test_decode_single_jetstream_pytorch(params, do_sample): 78 | params.do_sample = do_sample 79 | decode_single_test(params) 80 | 81 | 82 | def test_decode_repetition_penalty_jetstream_pytorch(): 83 | """Test if the repetition penalty generates something without crashing.""" 84 | params = DecodeTestParams( 85 | model_id="Maykeye/TinyLLama-v0", 86 | sequence_length=256, 87 | expected_text=" The sun was shining and it was very hot.\nSuddenly, a big wind came and", 88 | repetition_penalty=1.2 89 | ) 90 | decode_single_test(params) 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | Optimum-TPU 4 | =========================== 5 |

Get the most out of Google Cloud TPUs with the ease of 🤗 transformers

6 | 7 | [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://huggingface.co/docs/optimum/index) 8 | [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE) 9 | [![Optimum TPU / Test TGI on TPU](https://github.com/huggingface/optimum-tpu/actions/workflows/test-pytorch-xla-tpu-tgi.yml/badge.svg)](https://github.com/huggingface/optimum-tpu/actions/workflows/test-pytorch-xla-tpu-tgi.yml) 10 |
11 | 12 | > [!CAUTION] 13 | > **🚧 Optimum-TPU is now in maintenance mode.** 14 | > 15 | > We’ll continue to welcome community contributions for minor bug fixes, documentation improvements, and lightweight maintenance tasks. 16 | > 17 | > Optimum-TPU was created to make it easier to train and run inference on TPUs using 🤗 Transformers and 🤗 Accelerate. Thanks to everyone who has contributed and supported the project! ❤️ 18 | > 19 | > While this repository is no longer under active development, you can continue exploring TPU solutions with: 20 | > • [tpu-inference](https://github.com/vllm-project/tpu-inference) for inference 21 | > • [🤗 Accelerate](https://github.com/huggingface/accelerate) for training 22 | > 23 | > Thank you for being part of the journey! 🚀 24 | 25 | [Tensor Processing Units (TPU)](https://cloud.google.com/tpu) are AI accelerators made by Google to optimize 26 | performance and cost from AI training to inference. 27 | 28 | This repository exposes an interface similar to what the Hugging Face transformers library provides, to interact with 29 | a multitude of models developed by research labs, institutions and the community. 30 | 31 | We aim to provide our users with the best possible performance targeting Google Cloud TPUs for both training and inference, 32 | working closely with Google and Google Cloud to make this a reality. 33 | 34 | 35 | ## Supported Models and Tasks 36 | 37 | We currently support a few LLM models targeting text generation scenarios: 38 | - 💎 Gemma (2b, 7b) 39 | - 🦙 Llama2 (7b) and Llama3 (8b). On Text Generation Inference with Jetstream Pytorch, Llama3.1, Llama3.2 and Llama3.3 (text-only models) are also supported, up to 70B parameters. 40 | - 💨 Mistral (7b) 41 | 42 | 43 | ## Installation 44 | 45 | `optimum-tpu` comes with a handy PyPI-released package compatible with your classic Python dependency management tools. 46 | 47 | `pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html` 48 | 49 | `export PJRT_DEVICE=TPU` 50 | 51 | 52 | ## Inference 53 | 54 | `optimum-tpu` provides a set of dedicated tools and integrations to leverage Cloud TPUs for inference, especially 55 | on the latest TPU versions `v5e` and `v6e`. 56 | 57 | Other TPU versions will be supported along the way. 58 | 59 | ### Text-Generation-Inference 60 | 61 | As part of the integration, we support a [text-generation-inference (TGI)](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference) backend that allows deploying and serving 62 | incoming HTTP requests and executing them on Cloud TPUs. 63 | 64 | Please see the [TGI specific documentation](text-generation-inference) on how to get started. 65 | 66 | ### JetStream Pytorch Engine 67 | 68 | `optimum-tpu` provides optional support for the JetStream Pytorch engine inside of TGI. This support can be installed using the dedicated CLI command: 69 | 70 | ```shell 71 | optimum-tpu install-jetstream-pytorch 72 | ``` 73 | 74 | To enable the support, export the environment variable `JETSTREAM_PT=1`. 75 | 76 | ## Training 77 | 78 | Fine-tuning is supported and tested on the TPU `v5e`. We have tested so far: 79 | 80 | - 🦙 Llama-2 7B, Llama-3 8B and newer; 81 | - 💎 Gemma 2B and 7B.
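
As a quick preview of what such a fine-tuning run looks like, the sketch below condenses the FSDPv2 setup used in the training tutorial later in this documentation. It is a minimal, illustrative sketch rather than a complete training script: the `google/gemma-2b` checkpoint and the bf16 dtype are just the assumptions used in that tutorial, and the resulting arguments are meant to be passed on to a `Trainer`/`SFTTrainer`.

```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments
from optimum.tpu import fsdp_v2

# Enable SPMD-based FSDPv2 sharding before building the trainer.
fsdp_v2.use_fsdp_v2()

# Load a supported model (Gemma or Llama) in bf16.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)

# Derive the FSDPv2 sharding configuration for this model.
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

training_args = TrainingArguments(
    output_dir="./output",
    dataloader_drop_last=True,  # required for FSDPv2
    **fsdp_training_args,
)
```

The complete, runnable workflow is covered in the examples linked below.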
82 | 83 | You can check the examples: 84 | 85 | - [Fine-Tune Gemma on Google TPU](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/gemma_tuning.ipynb) 86 | - The [Llama fine-tuning script](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/llama_tuning.ipynb) 87 | -------------------------------------------------------------------------------- /docs/source/tutorials/training_on_tpu.mdx: -------------------------------------------------------------------------------- 1 | # First TPU Training on Google Cloud 2 | 3 | This tutorial walks you through setting up and running model training on TPU using the `optimum-tpu` package. 4 | 5 | ## Prerequisites 6 | 7 | Before starting, ensure you have a running TPU instance (see [TPU Setup Guide](../tutorials/tpu_setup.mdx)) 8 | 9 | ## Environment Setup 10 | 11 | First, create and activate a virtual environment: 12 | ```bash 13 | python -m venv .venv 14 | source .venv/bin/activate 15 | ``` 16 | 17 | Install the required packages: 18 | ```bash 19 | # Install optimum-tpu with PyTorch/XLA support 20 | pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html 21 | 22 | # Install additional training dependencies 23 | pip install transformers datasets accelerate trl peft evaluate 24 | ``` 25 | 26 | ## Understanding FSDP for TPU Training 27 | To speed up your training on TPU, you can rely on Optimum TPU's integration with FSDP (Fully Sharded Data Parallel). When training large models, FSDP automatically shards (splits) your model across all available TPU workers, providing several key benefits: 28 | 1. Memory efficiency: Each TPU worker only stores a portion of the model parameters, reducing per-device memory requirements 29 | 2. Automatic scaling: FSDP handles the complexity of distributing the model and aggregating gradients 30 | 3. Performance optimization: Optimum TPU's implementation is specifically tuned for TPU hardware 31 | 32 | This sharding happens automatically when you use the `fsdp_v2.get_fsdp_training_args(model)` configuration in your training setup, making it easy to train larger models that wouldn't fit on a single TPU device. 33 | 34 | ## How to Setup FSDP 35 | 36 | The key modification to enable FSDP is just these few lines: 37 | 38 | ```diff 39 | +from optimum.tpu import fsdp_v2 40 | +fsdp_v2.use_fsdp_v2() 41 | +fsdp_training_args = fsdp_v2.get_fsdp_training_args(model) 42 | ``` 43 | 44 | Then include these arguments in your trainer configuration: 45 | 46 | ```diff 47 | trainer = SFTTrainer( 48 | model=model, 49 | train_dataset=dataset, 50 | args=TrainingArguments( 51 | ... 52 | + dataloader_drop_last=True, # Required for FSDPv2 53 | + **fsdp_training_args, 54 | ), 55 | ... 
56 | ) 57 | ``` 58 | 59 | ## Complete example 60 | 61 | Here's a full working example that demonstrates TPU training with FSDP: 62 | 63 | ```python 64 | import torch 65 | from datasets import load_dataset 66 | from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments 67 | from peft import LoraConfig 68 | from trl import SFTTrainer 69 | from optimum.tpu import fsdp_v2 70 | 71 | # Enable FSDPv2 for TPU 72 | fsdp_v2.use_fsdp_v2() 73 | 74 | # Load model and dataset 75 | model_id = "google/gemma-2b" 76 | model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) 77 | dataset = load_dataset("tatsu-lab/alpaca", split="train[:1000]") 78 | 79 | # Get FSDP training arguments 80 | fsdp_training_args = fsdp_v2.get_fsdp_training_args(model) 81 | 82 | # Create trainer with minimal configuration 83 | trainer = SFTTrainer( 84 | model=model, 85 | train_dataset=dataset, 86 | args=TrainingArguments( 87 | output_dir="./output", 88 | dataloader_drop_last=True, # Required for FSDPv2 89 | **fsdp_training_args, 90 | ), 91 | peft_config=LoraConfig( 92 | r=8, 93 | target_modules=["k_proj", "v_proj"], 94 | task_type="CAUSAL_LM", 95 | ), 96 | ) 97 | 98 | # Start training 99 | trainer.train() 100 | ``` 101 | 102 | Save this code as train.py and run it: 103 | 104 | ``` 105 | python train.py 106 | ``` 107 | 108 | You should now see the loss decrease during training. When the training is done, you will have a fine-tuned model. Congrats - you've just trained your first model on TPUs! 🙌 109 | 110 | ## Next Steps 111 | Continue your TPU training journey by exploring: 112 | - More complex training scenarios in our [examples](../howto/more_examples) 113 | - Different [model architectures supported by Optimum TPU](../supported-architectures) 114 | 115 | -------------------------------------------------------------------------------- /optimum/tpu/cli.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import os 3 | import shutil 4 | import subprocess 5 | import sys 6 | from pathlib import Path 7 | 8 | import click 9 | import typer 10 | 11 | 12 | TORCH_VER = "2.5.1" 13 | JETSTREAM_PT_VER = "jetstream-v0.2.4" 14 | DEFAULT_DEPS_PATH = os.path.join(Path.home(), ".jetstream-deps") 15 | 16 | app = typer.Typer() 17 | 18 | 19 | def _check_module(module_name: str): 20 | spec = importlib.util.find_spec(module_name) 21 | return spec is not None 22 | 23 | 24 | def _run(cmd: str): 25 | split_cmd = cmd.split() 26 | subprocess.check_call(split_cmd) 27 | 28 | 29 | def _install_torch_cpu(): 30 | # install torch CPU version to avoid installing CUDA dependencies 31 | _run(sys.executable + f" -m pip install torch=={TORCH_VER} --index-url https://download.pytorch.org/whl/cpu") 32 | 33 | 34 | @app.command() 35 | def install_pytorch_xla( 36 | force: bool = False, 37 | ): 38 | """ 39 | Installs PyTorch XLA with TPU support. 40 | 41 | Args: 42 | force (bool): When set, force reinstalling even if Pytorch XLA is already installed. 43 | """ 44 | if not force and _check_module("torch") and _check_module("torch_xla"): 45 | typer.confirm( 46 | "PyTorch XLA is already installed. 
Do you want to reinstall it?", 47 | default=False, 48 | abort=True, 49 | ) 50 | _install_torch_cpu() 51 | _run( 52 | sys.executable 53 | + f" -m pip install torch-xla[tpu]=={TORCH_VER} -f https://storage.googleapis.com/libtpu-releases/index.html" 54 | ) 55 | click.echo() 56 | click.echo(click.style("PyTorch XLA has been installed.", bold=True)) 57 | 58 | 59 | @app.command() 60 | def install_jetstream_pytorch( 61 | deps_path: str = DEFAULT_DEPS_PATH, 62 | yes: bool = False, 63 | ): 64 | """ 65 | Installs Jetstream Pytorch with TPU support. 66 | 67 | Args: 68 | deps_path (str): Path where Jetstream Pytorch dependencies will be installed. 69 | yes (bool): When set, proceed installing without asking questions. 70 | """ 71 | if not _check_module("torch"): 72 | _install_torch_cpu() 73 | if not yes and _check_module("jetstream_pt") and _check_module("torch_xla2"): 74 | typer.confirm( 75 | "Jetstream Pytorch is already installed. Do you want to reinstall it?", 76 | default=False, 77 | abort=True, 78 | ) 79 | 80 | jetstream_repo_dir = os.path.join(deps_path, "jetstream-pytorch") 81 | if not yes and os.path.exists(jetstream_repo_dir): 82 | typer.confirm( 83 | f"Directory {jetstream_repo_dir} already exists. Do you want to delete it and reinstall Jetstream Pytorch?", 84 | default=False, 85 | abort=True, 86 | ) 87 | shutil.rmtree(jetstream_repo_dir, ignore_errors=True) 88 | # Create the directory if it does not exist 89 | os.makedirs(deps_path, exist_ok=True) 90 | # Clone and install Jetstream Pytorch 91 | os.chdir(deps_path) 92 | _run("git clone https://github.com/google/jetstream-pytorch.git") 93 | os.chdir("jetstream-pytorch") 94 | _run(f"git checkout {JETSTREAM_PT_VER}") 95 | _run("git submodule update --init --recursive") 96 | # We cannot install in a temporary directory because the directory should not be deleted after the script finishes, 97 | # because it will install its dependendencies from that directory. 
98 | _run(sys.executable + " -m pip install -e .") 99 | 100 | _run( 101 | sys.executable 102 | + f" -m pip install torch_xla[pallas]=={TORCH_VER} " 103 | + " -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html" 104 | + " -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html" 105 | + " -f https://storage.googleapis.com/libtpu-releases/index.html" 106 | ) 107 | # Install PyTorch XLA pallas 108 | click.echo() 109 | click.echo(click.style("Jetstream Pytorch has been installed.", bold=True)) 110 | 111 | 112 | if __name__ == "__main__": 113 | sys.exit(app()) 114 | -------------------------------------------------------------------------------- /optimum/tpu/distributed_model.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: E402 2 | import os 3 | from enum import Enum 4 | 5 | from loguru import logger 6 | 7 | 8 | os.environ["PJRT_DEVICE"] = "TPU" 9 | 10 | import torch.multiprocessing as mp 11 | import torch_xla.core.xla_model as xm 12 | import torch_xla.distributed.xla_multiprocessing as xmp 13 | 14 | from optimum.tpu.modeling import AutoModelForCausalLM 15 | 16 | from .xla_mp_comm import AgentMailbox, RootMailbox 17 | 18 | 19 | class ModelCommand(Enum): 20 | LEAVE = 0 21 | PREFILL = 1 22 | DECODE = 2 23 | 24 | 25 | def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable): 26 | device = xm.xla_device() 27 | world_size = xm.xrt_world_size() 28 | # create agent mailbox out of root's one 29 | mailbox = AgentMailbox(root_mailbox) 30 | 31 | logger.debug( 32 | f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} " 33 | + f"world size {world_size}" 34 | ) 35 | 36 | # Model loading and sharding should happen here 37 | model = AutoModelForCausalLM.from_pretrained(model_id) 38 | model = model.eval() 39 | model.to(device) 40 | 41 | def get_next_token(inputs): 42 | # move inputs to device in a new dict to avoid conflicts 43 | model_inputs = {} 44 | for key, value in inputs.items(): 45 | model_inputs[key] = value.to(device) 46 | outputs = model(**model_inputs, return_dict=False)[0] 47 | xm.mark_step() 48 | # consider adding a rendezvous here 49 | if rank == 0: 50 | logger.debug(f"Rank {rank} getting tokens") 51 | next_token = sample_fn(outputs) 52 | xm.mark_step() 53 | logger.debug(f"Rank {rank} sending next_tokens {next_token.shape}") 54 | # Data needs to be moved to CPU before setting it 55 | mailbox.send(next_token.cpu()) 56 | 57 | while True: 58 | if rank == 0: 59 | mailbox.agent_ready.set() 60 | logger.debug(f"Rank {rank} waiting for commands") 61 | mailbox.receive() 62 | # Wait for rank 0 to receive command 63 | xm.rendezvous("start") 64 | 65 | logger.debug(f"Rank {rank} waiting for command at rendezvous") 66 | command, data = mailbox.command_data 67 | inputs = data[0] if data else None 68 | if command == ModelCommand.PREFILL: 69 | logger.debug(f"Rank {rank} PREFILL") 70 | get_next_token(inputs) 71 | elif command == ModelCommand.DECODE: 72 | logger.debug(f"Rank {rank} DECODE") 73 | get_next_token(inputs) 74 | elif command == ModelCommand.LEAVE: 75 | logger.debug(f"Rank {rank} LEAVE") 76 | # Set model to ready 77 | mailbox.agent_ready.set() 78 | break 79 | 80 | 81 | def model_loop_fn(*args): 82 | """Spawn processes in the TPUs forwarding arguments""" 83 | xmp.spawn(_mp_fn, args=(args), join=True, daemon=False) 84 | 85 | 86 | class DistributedModel: 87 | def __init__(self, model_id: str, sample_fn: callable): 88 | manager = mp.Manager() 89 | 
self.mailbox = RootMailbox(manager) 90 | 91 | self.model_loop = mp.Process(target=model_loop_fn, args=(model_id, self.mailbox, sample_fn)) 92 | self.model_loop.start() 93 | 94 | def prefill(self, **model_args): 95 | assert self.mailbox is not None, "DistributedModel is not initialized" 96 | return self.mailbox.send(ModelCommand.PREFILL, model_args)[0] 97 | 98 | def decode(self, **model_args): 99 | assert self.mailbox is not None, "DistributedModel is not initialized" 100 | return self.mailbox.send(ModelCommand.PREFILL, model_args)[0] 101 | 102 | def leave(self): 103 | if self.mailbox is None: 104 | return 105 | self.mailbox.send(ModelCommand.LEAVE) 106 | logger.debug("Joining...") 107 | self.model_loop.join() 108 | logger.debug("Model loop finished") 109 | self.mailbox = None 110 | 111 | def __del__(self): 112 | self.leave() 113 | -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from pathlib import Path 3 | from typing import List 4 | 5 | from grpc import aio 6 | from grpc_reflection.v1alpha import reflection 7 | from loguru import logger 8 | 9 | from .auto_generator import AutoGenerator, Generator 10 | from .interceptor import ExceptionInterceptor 11 | from .pb import generate_pb2, generate_pb2_grpc 12 | 13 | 14 | class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): 15 | def __init__(self, generator: Generator, server_urls: List[str]): 16 | self.generator = generator 17 | self.server_urls = server_urls 18 | 19 | async def Info(self, request, context): 20 | logger.debug("Info") 21 | return self.generator.info 22 | 23 | async def Health(self, request, context): 24 | logger.debug("Health") 25 | return generate_pb2.HealthResponse() 26 | 27 | async def ServiceDiscovery(self, request, context): 28 | logger.debug("ServiceDiscovery") 29 | return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls) 30 | 31 | async def ClearCache(self, request, context): 32 | logger.debug("ClearCache") 33 | if request.HasField("id"): 34 | self.generator.clear(request.id) 35 | else: 36 | self.generator.clear() 37 | return generate_pb2.ClearCacheResponse() 38 | 39 | async def FilterBatch(self, request, context): 40 | logger.debug("FilterBatch") 41 | filtered_batch = self.generator.filter(request.batch_id, request.request_ids) 42 | return generate_pb2.FilterBatchResponse(batch=filtered_batch) 43 | 44 | async def Warmup(self, request, context): 45 | logger.info("Warmup (this can take several minutes)") 46 | max_tokens = self.generator.warmup(request.batch) 47 | ret = generate_pb2.WarmupResponse(max_supported_total_tokens=max_tokens) 48 | logger.info("Warmup done") 49 | return ret 50 | 51 | async def Prefill(self, request, context): 52 | logger.debug("Prefill") 53 | batch = request.batch 54 | generations, batch = self.generator.prefill(request.batch) 55 | return generate_pb2.PrefillResponse(generations=generations, batch=batch) 56 | 57 | async def Decode(self, request, context): 58 | logger.debug("Decode") 59 | generations, batch = self.generator.decode(request.batches) 60 | return generate_pb2.DecodeResponse(generations=generations, batch=batch) 61 | 62 | 63 | def serve( 64 | model_path: str, 65 | revision: str, 66 | max_batch_size: int, 67 | max_sequence_length: int, 68 | max_input_tokens: int, 69 | uds_path: Path, 70 | ): 71 | async def serve_inner(model_path: str): 72 | 
unix_socket_template = "unix://{}-{}" 73 | local_url = unix_socket_template.format(uds_path, 0) 74 | server_urls = [local_url] 75 | 76 | try: 77 | generator = AutoGenerator.from_pretrained( 78 | model_path, 79 | revision=revision, 80 | max_batch_size=max_batch_size, 81 | max_sequence_length=max_sequence_length, 82 | max_input_tokens=max_input_tokens, 83 | ) 84 | except Exception: 85 | logger.exception("Error when initializing model") 86 | raise 87 | 88 | server = aio.server(interceptors=[ExceptionInterceptor()]) 89 | generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( 90 | TextGenerationService(generator, server_urls), server 91 | ) 92 | SERVICE_NAMES = ( 93 | generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name, 94 | reflection.SERVICE_NAME, 95 | ) 96 | reflection.enable_server_reflection(SERVICE_NAMES, server) 97 | server.add_insecure_port(local_url) 98 | 99 | await server.start() 100 | 101 | logger.info("Server started at {}".format(local_url)) 102 | 103 | try: 104 | await server.wait_for_termination() 105 | except KeyboardInterrupt: 106 | logger.info("Signal received. Shutting down") 107 | await server.stop(0) 108 | 109 | asyncio.run(serve_inner(model_path)) 110 | -------------------------------------------------------------------------------- /docs/source/tutorials/inference_on_tpu.mdx: -------------------------------------------------------------------------------- 1 | # First TPU Inference on Google Cloud 2 | 3 | This tutorial guides you through setting up and running inference on TPU using Text Generation Inference (TGI) ([documentation](https://huggingface.co/docs/text-generation-inference)). TGI server is compatible with OpenAI messages API, and it offers an optimized solution for serving models on TPU. 4 | 5 | ## Prerequisites 6 | 7 | Before starting, ensure you have: 8 | - A running TPU instance (see [TPU Setup Guide](../tutorials/tpu_setup)) 9 | - SSH access to your TPU instance 10 | - A HuggingFace account 11 | 12 | ## Step 1: Initial Setup 13 | 14 | ### SSH Access 15 | First, connect to your TPU instance via SSH. 16 | 17 | 18 | ### Install Required Tools 19 | 20 | Install the HuggingFace Hub CLI: 21 | ```bash 22 | pip install huggingface_hub 23 | ``` 24 | 25 | ### Authentication 26 | 27 | Log in to HuggingFace: 28 | ```bash 29 | huggingface-cli login 30 | ``` 31 | 32 | ## Step 2: Model Deployment 33 | 34 | ### Model Selection 35 | 36 | We will use the `gemma-2b-it` model for this tutorial: 37 | 1. Visit https://huggingface.co/google/gemma-2b-it 38 | 2. Accept the model terms and conditions 39 | 3. This enables model download access 40 | 41 | ### Launch TGI Server 42 | 43 | We will use the Optimum-TPU image, a TPU-optimized TGI image provided by HuggingFace. 
44 | 45 | ```bash 46 | docker run -p 8080:80 \ 47 | --shm-size 16GB \ 48 | --privileged \ 49 | --net host \ 50 | -e LOG_LEVEL=text_generation_router=debug \ 51 | -v ~/hf_data:/data \ 52 | -e HF_TOKEN=$(cat ~/.cache/huggingface/token) \ 53 | ghcr.io/huggingface/optimum-tpu:v0.2.3-tgi \ 54 | --model-id google/gemma-2b-it \ 55 | --max-input-length 512 \ 56 | --max-total-tokens 1024 \ 57 | --max-batch-prefill-tokens 512 \ 58 | --max-batch-total-tokens 1024 59 | ``` 60 | 61 | ### Understanding the Configuration 62 | 63 | Key parameters explained: 64 | - `--shm-size 16GB --privileged --net=host`: Required for docker to access the TPU 65 | - `-v ~/hf_data:/data`: Volume mount for model storage 66 | - `--max-input-length`: Maximum input sequence length 67 | - `--max-total-tokens`: Maximum combined input and output tokens 68 | - `--max-batch-prefill-tokens`: Maximum tokens for batch processing 69 | - `--max-batch-total-tokens`: Maximum total tokens in a batch 70 | 71 | ## Step 3: Making Inference Requests 72 | 73 | ### Server Readiness 74 | Wait for the "Connected" message in the logs: 75 | 76 | ``` 77 | 2025-01-11T10:40:00.256056Z INFO text_generation_router::server: router/src/server.rs:2393: Connected 78 | ``` 79 | 80 | Your TGI server is now ready to serve requests. 81 | 82 | ### Testing from the TPU VM 83 | 84 | Query the server from another terminal on the TPU instance: 85 | 86 | ```bash 87 | curl 0.0.0.0:8080/generate \ 88 | -X POST \ 89 | -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ 90 | -H 'Content-Type: application/json' 91 | ``` 92 | 93 | ### Remote Access 94 | 95 | To query from outside the TPU instance: 96 | 97 | ![External IP TPU](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/tpu/get_external_ip_tpu.png) 98 | 99 | 1. Find your TPU's external IP in Google Cloud Console 100 | 2. Replace the IP in the request: 101 | ```bash 102 | curl 34.174.11.242:8080/generate \ 103 | -X POST \ 104 | -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ 105 | -H 'Content-Type: application/json' 106 | ``` 107 | 108 | #### (Optional) Firewall Configuration 109 | 110 | You may need to configure GCP firewall rules to allow remote access: 111 | 1. Use `gcloud compute firewall-rules create` to allow traffic 112 | 2. Ensure port 8080 is accessible 113 | 3. Consider security best practices for production 114 | 115 | ## Request Parameters 116 | 117 | Key parameters for inference requests: 118 | - `inputs`: The prompt text 119 | - `max_new_tokens`: Maximum number of tokens to generate 120 | - Additional parameters available in [TGI documentation](https://huggingface.co/docs/text-generation-inference) 121 | 122 | ## Next Steps 123 | 124 | 1. Please check the [TGI Consuming Guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/consuming_tgi) to learn about how to query your new TGI server. 125 | 2. Check the rest of our documentation for advanced settings that can be used on your new TGI server. 
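
As a complement to the consuming guide above: since the server is compatible with the OpenAI messages API (as noted at the start of this tutorial), you can also query it from Python instead of raw `curl`. The snippet below is a minimal sketch, assuming the `openai` client package is installed (`pip install openai`) and that the TGI version in the image exposes the Messages API route (`/v1/chat/completions`) on port 8080 as configured earlier.

```python
from openai import OpenAI

# Point the client at the TGI server started earlier in this tutorial.
# TGI does not validate the API key, so any placeholder value works.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

response = client.chat.completions.create(
    model="tgi",  # placeholder name; TGI serves the model it was launched with
    messages=[{"role": "user", "content": "What is Deep Learning?"}],
    max_tokens=20,
)
print(response.choices[0].message.content)
```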
126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /text-generation-inference/integration-tests/test_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict 3 | 4 | import Levenshtein 5 | import pytest 6 | 7 | 8 | MODEL_CONFIGS = { 9 | "gpt2": { 10 | "model_id": "openai-community/gpt2", 11 | "sequence_length": 1024, 12 | "expected_greedy_output": "\n\nDeep learning is a new field of research that has been around for a while", 13 | "expected_sampling_output": 'The fundamental concepts of deep learning are the same as those used to train and understand your first language, or your first set of skills', 14 | "expected_batch_output": "\n\nDeep learning is a technique that allows you to learn something from a single source", 15 | "args": [ 16 | "--max-input-length", "512", 17 | "--max-total-tokens", "1024", 18 | "--max-batch-prefill-tokens", "512", 19 | "--max-batch-total-tokens", "1024" 20 | ], 21 | "env_config": { 22 | "MAX_BATCH_SIZE": "4", 23 | "JETSTREAM_PT_DISABLE": "1", 24 | "SKIP_WARMUP": "1", 25 | } 26 | }, 27 | "gemma": { 28 | "model_id": "google/gemma-2b-it", 29 | "sequence_length": 1024, 30 | "expected_greedy_output": "\n\nDeep learning is a subfield of machine learning that allows computers to learn from data", 31 | "expected_sampling_output": "\n\n**Deep learning** is a subfield of machine learning that enables computers to learn from data without explicit programming", 32 | "expected_batch_output": "\n\nDeep learning is a subfield of machine learning that allows computers to learn from data", 33 | "args": [ 34 | "--max-input-length", "512", 35 | "--max-total-tokens", "1024", 36 | "--max-batch-prefill-tokens", "512", 37 | "--max-batch-total-tokens", "1024" 38 | ], 39 | "env_config": { 40 | "MAX_BATCH_SIZE": "4", 41 | "SKIP_WARMUP": "1", 42 | } 43 | } 44 | } 45 | 46 | @pytest.fixture(scope="module", params=MODEL_CONFIGS.keys()) 47 | def model_config(request) -> Dict[str, Any]: 48 | """Fixture that provides model configurations for testing.""" 49 | return MODEL_CONFIGS[request.param] 50 | 51 | @pytest.fixture(scope="module") 52 | def model_name_or_path(model_config): 53 | os.environ["HF_SEQUENCE_LENGTH"] = str(model_config["sequence_length"]) 54 | yield model_config["model_id"] 55 | 56 | @pytest.fixture(scope="module") 57 | def tgi_service(launcher, model_name_or_path): 58 | with launcher(model_name_or_path) as tgi_service: 59 | yield tgi_service 60 | 61 | @pytest.fixture(scope="module") 62 | async def tgi_client(tgi_service): 63 | await tgi_service.health(1000) 64 | return tgi_service.client 65 | 66 | @pytest.fixture(scope="module") 67 | def expected_outputs(model_config): 68 | return { 69 | "greedy": model_config["expected_greedy_output"], 70 | "sampling": model_config["expected_sampling_output"], 71 | "batch": model_config["expected_batch_output"] 72 | } 73 | 74 | @pytest.mark.asyncio 75 | async def test_model_single_request(tgi_client, expected_outputs): 76 | # Bounded greedy decoding without input 77 | response = await tgi_client.generate( 78 | "What is Deep Learning?", 79 | max_new_tokens=17, 80 | decoder_input_details=True, 81 | ) 82 | assert response.details.generated_tokens == 17 83 | assert response.generated_text == expected_outputs["greedy"] 84 | 85 | # Bounded greedy decoding with input 86 | response = await tgi_client.generate( 87 | "What is Deep Learning?", 88 | max_new_tokens=17, 89 | return_full_text=True, 90 | 
decoder_input_details=True, 91 | ) 92 | assert response.details.generated_tokens == 17 93 | assert response.generated_text == f"What is Deep Learning?{expected_outputs['greedy']}" 94 | 95 | # Sampling 96 | response = await tgi_client.generate( 97 | "What is Deep Learning?", 98 | do_sample=True, 99 | top_k=50, 100 | top_p=0.9, 101 | repetition_penalty=1.2, 102 | max_new_tokens=100, 103 | seed=42, 104 | decoder_input_details=True, 105 | ) 106 | 107 | assert expected_outputs["sampling"] in response.generated_text 108 | 109 | @pytest.mark.asyncio 110 | async def test_model_multiple_requests(tgi_client, generate_load, expected_outputs): 111 | num_requests = 4 112 | responses = await generate_load( 113 | tgi_client, 114 | "What is Deep Learning?", 115 | max_new_tokens=17, 116 | n=num_requests, 117 | ) 118 | 119 | assert len(responses) == 4 120 | expected = expected_outputs["batch"] 121 | for r in responses: 122 | assert r.details.generated_tokens == 17 123 | # Compute the similarity with the expectation using the levenshtein distance 124 | # We should not have more than two substitutions or additions 125 | assert Levenshtein.distance(r.generated_text, expected) < 3 126 | -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Optional 4 | 5 | import typer 6 | from loguru import logger 7 | 8 | 9 | app = typer.Typer() 10 | 11 | 12 | @app.command() 13 | def serve( 14 | model_id: str, 15 | revision: Optional[str] = None, 16 | sharded: bool = False, 17 | trust_remote_code: bool = None, 18 | uds_path: str = "/tmp/text-generation-server", 19 | logger_level: str = "INFO", 20 | json_output: bool = False, 21 | otlp_service_name: str = "text-generation-inference.server", 22 | max_input_tokens: Optional[int] = None, 23 | ): 24 | """This is the main entry-point for the server CLI. 25 | 26 | Args: 27 | model_id (`str`): 28 | The *model_id* of a model on the HuggingFace hub or the path to a local model. 29 | revision (`Optional[str]`, defaults to `None`): 30 | The revision of the model on the HuggingFace hub. 31 | sharded (`bool`): 32 | Whether the model must be sharded or not. Kept for compatibility with the 33 | text-generation-launcher, but must be set to False. 34 | trust-remote-code (`bool`): 35 | Kept for compatibility with text-generation-launcher. Ignored. 36 | uds_path (`Union[Path, str]`): 37 | The local path on which the server will expose its google RPC services. 38 | logger_level (`str`): 39 | The server logger level. Defaults to *INFO*. 40 | json_output (`bool`): 41 | Use JSON format for log serialization. 42 | otlp_service_name (`str`): 43 | The name of the OTLP service. For now it is ignored. 44 | max_input_tokens (`Optional[int]`): 45 | The maximum number of tokens allowed in the input. For now it is ignored. 
46 | """ 47 | if sharded: 48 | raise ValueError("Sharding is not supported.") 49 | # Remove default handler 50 | logger.remove() 51 | logger.add( 52 | sys.stdout, 53 | format="{message}", 54 | filter="text_generation_server", 55 | level=logger_level, 56 | serialize=json_output, 57 | backtrace=True, 58 | diagnose=False, 59 | ) 60 | 61 | if trust_remote_code is not None: 62 | logger.warning("'trust_remote_code' argument is not supported and will be ignored.") 63 | 64 | # TODO: these two parameters are used when the server is started, but they are not used yet, so just inform the 65 | # user about that. 66 | logger.info("'otlp_service_name' argument is not supported and will be ignored.") 67 | 68 | # This is a workaround to pass the logger level to other threads, it's only used in 69 | # Pytorch/XLA generator. 70 | os.environ["LOGGER_LEVEL_GENERATOR"] = logger_level 71 | 72 | # Import here after the logger is added to log potential import exceptions 73 | from optimum.tpu.model import fetch_model 74 | 75 | from .server import serve 76 | 77 | # Read environment variables forwarded by the launcher 78 | max_batch_size = int(os.environ.get("MAX_BATCH_SIZE", "4")) 79 | max_total_tokens = int(os.environ.get("MAX_TOTAL_TOKENS", "64")) 80 | 81 | # Start the server 82 | model_path = fetch_model(model_id, revision) 83 | serve( 84 | model_path, 85 | revision=revision, 86 | max_batch_size=max_batch_size, 87 | max_sequence_length=max_total_tokens, 88 | max_input_tokens=max_input_tokens, 89 | uds_path=uds_path 90 | ) 91 | 92 | 93 | @app.command() 94 | def download_weights( 95 | model_id: str, 96 | revision: Optional[str] = None, 97 | logger_level: str = "INFO", 98 | json_output: bool = False, 99 | auto_convert: Optional[bool] = None, 100 | extension: Optional[str] = None, 101 | trust_remote_code: Optional[bool] = None, 102 | merge_lora: Optional[bool] = None, 103 | ): 104 | """Download the model weights. 105 | 106 | This command will be called by text-generation-launcher before serving the model. 107 | """ 108 | # Remove default handler 109 | logger.remove() 110 | logger.add( 111 | sys.stdout, 112 | format="{message}", 113 | filter="text_generation_server", 114 | level=logger_level, 115 | serialize=json_output, 116 | backtrace=True, 117 | diagnose=False, 118 | ) 119 | 120 | if extension is not None: 121 | logger.warning("'extension' argument is not supported and will be ignored.") 122 | if trust_remote_code is not None: 123 | logger.warning("'trust_remote_code' argument is not supported and will be ignored.") 124 | if auto_convert is not None: 125 | logger.warning("'auto_convert' argument is not supported and will be ignored.") 126 | if merge_lora is not None: 127 | logger.warning("'merge_lora' argument is not supported and will be ignored.") 128 | 129 | # Import here after the logger is added to log potential import exceptions 130 | from optimum.tpu.model import fetch_model 131 | 132 | fetch_model(model_id, revision) 133 | -------------------------------------------------------------------------------- /docs/source/howto/deploy_instance_on_ie.mdx: -------------------------------------------------------------------------------- 1 | # How to Deploy a Model on Inference Endpoint (IE) for Serving using TPUs 2 | 3 | Inference Endpoints (IE) is a solution to serve generation using supported models on TPU. It does not require setting up a separate GCP account, and it will offer some pre-configured settings to serve models with Optimum's TPU TGI. 
4 | 5 | You can deploy any of our supported models on Inference Endpoints (see the list of supported models). 6 | Inference Endpoints offer secure production environments by setting up a TGI server that can auto-scale based on demand. 7 | 8 | We have optimized Inference Endpoints on TPU to ensure each model achieves optimal performance. 9 | 10 | ## 1. Create a New Endpoint 11 | 12 | Click the "New Endpoint" button to get started at https://endpoints.huggingface.co 13 | 14 | ![Create new endpoint](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/tpu/ie_create_new_endpoint.png) 15 | 16 | ## 2. Configure the New Endpoint 17 | 18 | Configure your endpoint by selecting from the list of TPU-supported models. 19 | Note: If you choose a model unsupported on TPU, the TPU option will not be visible. This is by design to prevent starting unsupported models on TPU. 20 | 21 | Let's use google/gemma-2b-it as an example. The TPU tab is selectable, so we can confirm TPU compatibility. Note that this model is unavailable on CPU, as indicated by the greyed-out CPU option. 22 | 23 | ![Configure endpoint](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/tpu/ie_config_new_endpoint.png) 24 | 25 | Note: We automatically select the optimal hardware and configuration for each model. For google/gemma-2b-it, being a smaller model, we select a 1-chip TPU (TPU v5e-1) since 16GB of HBM (High Bandwidth Memory) is sufficient to serve the 2B model. This ensures cost-efficient resource allocation without unnecessary computing expenses. 26 | 27 | We extensively test and optimize TGI configurations to maximize hardware performance. Parameters such as Max Input Length, Max Number of Tokens, and Max Batch Prefill Tokens are pre-configured by the optimum-tpu team based on each model's requirements. If you set the model to google/gemma-7b-it, you will see that those values in "container configuration" are different and optimized for the 7B model instead. 28 | 29 | **Note**: You can set up advanced TGI features like quantization by accessing the environment variables section of the interface. You can, for example, set the key "QUANTIZATION" to the value "1" to enable quantization. You can view all of those advanced TGI options in our advanced TGI serving guide (./howto/advanced-tgi-serving). 30 | 31 | Once you've completed the configuration, click the "Create Endpoint" button. 32 | 33 | ## 3. Using Your Endpoint 34 | 35 | The endpoint requires initialization, during which you can monitor the logs. In the logs section, you will observe the model undergoing warmup to compile for optimal performance. Endpoint startup typically takes between 5 and 30 minutes, depending on the model size. This warmup period triggers multiple compilations to ensure peak serving performance. 36 | 37 | ![IE init](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/tpu/ie_endpoint_initalizing.png) 38 | 39 | After the endpoint completes "Initializing," you can query it through the GUI or API. 40 | 41 | ![IE running](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/tpu/ie_endpoint_running.png) 42 | 43 | 44 | Query your endpoint using either the playground or curl commands. 45 | 46 | ### 3.1 Query via Playground 47 | 48 | Use the GUI to write and execute queries against your endpoint.
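The screenshot below shows the playground view. If you prefer to call the same OpenAI-compatible API from code, here is a minimal sketch using the `openai` Python client; the base URL and token below are placeholders that you must replace with your own endpoint's values. The curl equivalent is shown in the next subsection.

```python
# Minimal sketch: call the endpoint's OpenAI-compatible chat route from Python.
# The base_url and api_key below are placeholders for your own endpoint values.
from openai import OpenAI

client = OpenAI(
    base_url="https://{INSTANCE_ID}.{REGION}.gcp.endpoints.huggingface.cloud/v1",
    api_key="hf_XXXXX",  # your Hugging Face token
)

response = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "What is deep learning?"}],
    max_tokens=150,
)
print(response.choices[0].message.content)
```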
49 | 50 | ![IE playground openAI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/tpu/ie_playground_openapi.png) 51 | 52 | ### 3.2 Query via curl 53 | 54 | Alternatively, use curl commands to query the endpoint. 55 | 56 | ![IE playground curl](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/optimum/tpu/ie_playground_curl.png) 57 | 58 | ```bash 59 | curl "https://{INSTANCE_ID}.{REGION}.gcp.endpoints.huggingface.cloud/v1/chat/completions" \ 60 | -X POST \ 61 | -H "Authorization: Bearer hf_XXXXX" \ 62 | -H "Content-Type: application/json" \ 63 | -d '{ 64 | "model": "tgi", 65 | "messages": [ 66 | { 67 | "role": "user", 68 | "content": "What is deep learning?" 69 | } 70 | ], 71 | "max_tokens": 150, 72 | "stream": true 73 | }' 74 | ``` 75 | 76 | You will need to replace {INSTANCE_ID} and {REGION} with the values from your own deployment. 77 | 78 | ## Next Steps 79 | 80 | - There are numerous ways to interact with your new Inference Endpoint. Review the Inference Endpoints documentation to explore the different options: 81 | https://huggingface.co/docs/inference-endpoints/index 82 | - Consult our advanced parameter guide for TGI to learn about the advanced TGI options you can use on Inference Endpoints (./howto/advanced-tgi-serving) 83 | - You can explore the full list of TPU-compatible models on the [Inference Endpoints TPU catalog page](https://endpoints.huggingface.co/catalog?accelerator=tpu) -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | SHELL := /bin/bash 15 | CURRENT_DIR = $(shell pwd) 16 | DEFAULT_CLONE_URL := https://github.com/huggingface/optimum-tpu.git 17 | # If CLONE_URL is empty, revert to DEFAULT_CLONE_URL 18 | REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL)) 19 | 20 | .PHONY: build_dist style style_check clean 21 | 22 | TGI_VERSION ?= v3.0.0 23 | 24 | rwildcard=$(wildcard $1) $(foreach d,$1,$(call rwildcard,$(addsuffix /$(notdir $d),$(wildcard $(dir $d)*)))) 25 | 26 | VERSION := $(shell awk '/__version__ = "(.*)"/{print $$3}' optimum/tpu/version.py | sed 's/"//g') 27 | 28 | PACKAGE_DIST = dist/optimum-tpu-$(VERSION).tar.gz 29 | PACKAGE_WHEEL = dist/optimum_tpu-$(VERSION)-py3-none-any.whl 30 | PACKAGE_PYTHON_FILES = $(call rwildcard, optimum/*.py) 31 | PACKAGE_FILES = $(PACKAGE_PYTHON_FILES) \ 32 | setup.cfg \ 33 | pyproject.toml \ 34 | README.md \ 35 | MANIFEST.in 36 | 37 | # Package build recipe 38 | $(PACKAGE_DIST) $(PACKAGE_WHEEL): $(PACKAGE_FILES) 39 | python -m build 40 | 41 | clean: 42 | rm -rf dist deps 43 | make -C text-generation-inference/server/ clean 44 | 45 | # normal usage: make tpu-tgi 46 | # ci usage: make tpu-tgi NETWORK=host, to build the docker image with the network host option 47 | tpu-tgi: 48 | docker build --rm -f text-generation-inference/docker/Dockerfile \ 49 | --build-arg VERSION=$(VERSION) \ 50 | --build-arg TGI_VERSION=$(TGI_VERSION) \ 51 | --ulimit nofile=100000:100000 \ 52 | $(if $(NETWORK),--network $(NETWORK),) \ 53 | -t huggingface/optimum-tpu:$(VERSION)-tgi . 54 | docker tag huggingface/optimum-tpu:$(VERSION)-tgi huggingface/optimum-tpu:latest 55 | 56 | tpu-tgi-ie: 57 | docker build --rm -f text-generation-inference/docker/Dockerfile \ 58 | --target inference-endpoint \ 59 | --build-arg VERSION=$(VERSION) \ 60 | --build-arg TGI_VERSION=$(TGI_VERSION) \ 61 | --ulimit nofile=100000:100000 \ 62 | -t huggingface/optimum-tpu:$(VERSION)-tgi . 63 | docker tag huggingface/optimum-tpu:$(VERSION)-tgi huggingface/optimum-tpu:latest-ie 64 | 65 | tpu-tgi-gcp: 66 | docker build --rm -f text-generation-inference/docker/Dockerfile \ 67 | --target google-cloud-containers \ 68 | --build-arg ENABLE_GCP_INTEGRATION=1 \ 69 | --ulimit nofile=100000:100000 \ 70 | -t huggingface/optimum-tpu:$(VERSION)-tgi-gcp . 71 | docker tag huggingface/optimum-tpu:$(VERSION)-tgi-gcp huggingface/optimum-tpu:latest-gcp 72 | 73 | # Run code quality checks 74 | style_check: 75 | ruff check . 76 | 77 | style: 78 | ruff check . 
--fix 79 | 80 | # Utilities to release to PyPi 81 | build_dist_install_tools: 82 | python -m pip install build 83 | python -m pip install twine 84 | 85 | build_dist: ${PACKAGE_DIST} ${PACKAGE_WHEEL} 86 | 87 | pypi_upload: ${PACKAGE_DIST} ${PACKAGE_WHEEL} 88 | python -m twine upload ${PACKAGE_DIST} ${PACKAGE_WHEEL} 89 | 90 | # Tests 91 | test_installs: 92 | python -m pip install -r requirements.txt 93 | python -m pip install .[tests] -f https://storage.googleapis.com/libtpu-releases/index.html 94 | 95 | tests: test_installs 96 | python -m pytest -sv tests 97 | 98 | # Stand-alone TGI server for unit tests outside of TGI container 99 | tgi_server: 100 | python -m pip install -r text-generation-inference/server/build-requirements.txt 101 | make -C text-generation-inference/server clean 102 | VERSION=${VERSION} TGI_VERSION=${TGI_VERSION} make -C text-generation-inference/server gen-server 103 | 104 | jetstream_requirements: test_installs 105 | python optimum/tpu/cli.py install-jetstream-pytorch --yes 106 | 107 | tgi_test_jetstream: test_installs jetstream_requirements tgi_server 108 | find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \ 109 | -exec python -m pip install --force-reinstall {} \; 110 | python -m pytest -sv text-generation-inference/tests -m jetstream 111 | 112 | tgi_test: test_installs tgi_server 113 | find text-generation-inference -name "text_generation_server-$(VERSION)-py3-none-any.whl" \ 114 | -exec python -m pip install --force-reinstall {} \; 115 | python -m pytest -sv text-generation-inference/tests -m torch_xla 116 | 117 | tgi_docker_test: 118 | python -m pip install -r text-generation-inference/integration-tests/requirements.txt 119 | python -m pytest -sv text-generation-inference/integration-tests 120 | 121 | preview_doc: 122 | doc-builder preview optimum-tpu docs/source --not_python_module 123 | -------------------------------------------------------------------------------- /examples/text-generation/generation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | import argparse 4 | import datetime 5 | import os 6 | import platform 7 | import time 8 | from typing import List 9 | 10 | import torch 11 | import torch_xla.core.xla_model as xm 12 | from transformers import AutoTokenizer, StaticCache 13 | 14 | from optimum.tpu.modeling import AutoModelForCausalLM 15 | 16 | 17 | os.environ["PJRT_DEVICE"] = "TPU" 18 | 19 | 20 | def sample_greedy(logits): 21 | next_logits = logits[:, -1] 22 | next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int() 23 | return next_token_id 24 | 25 | 26 | def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values): 27 | logits = model( 28 | cur_token, 29 | position_ids=input_pos, 30 | cache_position=cache_position, 31 | return_dict=False, 32 | use_cache=True, 33 | past_key_values=past_key_values, 34 | )[0] 35 | new_token = sample_greedy(logits) 36 | return new_token 37 | 38 | 39 | def conditional_compile(func): 40 | if "DBG_COMPILE" in os.environ: 41 | compiled = torch.compile(func, backend="openxla") 42 | return compiled 43 | return func 44 | 45 | 46 | def summary(values: List[float]): 47 | values.sort() 48 | n = len(values) 49 | if n % 2 == 0: 50 | median = (values[n // 2 - 1] + values[n // 2]) / 2 51 | else: 52 | median = values[n // 2] 53 | total = sum(values) 54 | mean = sum(values) / n 55 | print(f"Decode time: {total}, average: {mean}, median: {median}") 56 | 57 | 58 | def main(): 59 | parser = 
argparse.ArgumentParser(description="Text generation example") 60 | parser.add_argument("--model_id", type=str, 61 | default="google/gemma-2b", 62 | help="Model ID (e.g.: google/gemma-2b, mistralai/Mistral-7B-v0.3)") 63 | parser.add_argument("--max_new_tokens", type=int, default=20, help="Number of tokens to generate") 64 | parser.add_argument("--max_cache_length", type=int, default=256, help="Maximum cache length for the model") 65 | args = parser.parse_args() 66 | 67 | prg_start = time.time() 68 | print(f"⏳ Loading model {args.model_id}...") 69 | model_id = args.model_id 70 | torch_dtype = torch.bfloat16 71 | 72 | model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype) 73 | device = model.device 74 | model = model.eval() 75 | print(f"✅ Model loaded in {time.time() - prg_start} seconds.") 76 | 77 | tokenizer = AutoTokenizer.from_pretrained(model_id) 78 | # Set pad token for cases where it is None, e.g. for Mistral 79 | if tokenizer.pad_token_id is None: 80 | tokenizer.pad_token = tokenizer.eos_token 81 | prompts = ["Here's a funny thing:", "Once upon a time,"] 82 | inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device) 83 | batch_size, sequence_length = inputs["input_ids"].shape 84 | max_cache_length = args.max_cache_length  # honor the --max_cache_length CLI argument 85 | max_new_tokens = args.max_new_tokens 86 | 87 | # setup static cache 88 | past_key_values = StaticCache( 89 | config=model.config, 90 | max_batch_size=batch_size, 91 | max_cache_len=max_cache_length, 92 | device=model.device, 93 | dtype=model.dtype, 94 | ) 95 | start = time.time() 96 | cache_position = torch.arange(sequence_length, device=device) 97 | generated_ids = torch.zeros( 98 | (batch_size, sequence_length + max_new_tokens + 1), 99 | dtype=torch.int, 100 | device=device, 101 | ) 102 | generated_ids[:, cache_position] = inputs["input_ids"].to(torch.int) 103 | 104 | # prefill here 105 | attention_mask = inputs["attention_mask"] 106 | pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0) 107 | logits = model( 108 | **inputs, 109 | cache_position=cache_position, 110 | return_dict=False, 111 | use_cache=True, 112 | position_ids=pos_ids, 113 | past_key_values=past_key_values, 114 | )[0] 115 | next_token = sample_greedy(logits) 116 | xm.mark_step() 117 | generated_ids[:, sequence_length] = next_token[:, 0] 118 | end = time.time() 119 | print(f"Prefill took {end - start} seconds.") 120 | 121 | pos_ids = pos_ids.max(axis=-1)[0].unsqueeze(1) + 1 122 | 123 | model = conditional_compile(model) 124 | cache_position = torch.tensor([sequence_length], device=device) 125 | decode_times = [] 126 | for i in range(max_new_tokens): 127 | step_start = time.time() 128 | next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position, past_key_values) 129 | cache_position += 1 130 | generated_ids[:, cache_position] = next_token 131 | pos_ids += 1 132 | xm.mark_step() 133 | step_end = time.time() 134 | step_time = step_end - step_start 135 | decode_times.append(step_time) 136 | print(f"Step {i} took {step_time} seconds.") 137 | summary(decode_times) 138 | 139 | print(f"Decoding start at {datetime.datetime.now()}") 140 | 141 | decoded_texts = tokenizer.batch_decode(generated_ids) 142 | for i, text in enumerate(decoded_texts): 143 | print(i, text) 144 | 145 | end = time.time() 146 | print(f"Program run in {end - prg_start} seconds.
Device: {device} System: {platform.system()}") 147 | 148 | 149 | if __name__ == "__main__": 150 | with torch.no_grad(): 151 | main() 152 | -------------------------------------------------------------------------------- /text-generation-inference/README.md: -------------------------------------------------------------------------------- 1 | # Text-generation-inference docker image for Pytorch/XLA 2 | 3 | This docker image integrates the following components into a base image: 4 | 5 | - the [Text Generation Inference](https://github.com/huggingface/text-generation-inference) launcher and scheduling front-end, 6 | - an XLA specific inference server for text-generation. 7 | 8 | ## Features 9 | 10 | The basic features of the [Text Generation Inference](https://github.com/huggingface/text-generation-inference) product are supported: 11 | 12 | - continuous batching, 13 | - token streaming, 14 | - greedy search and multinomial sampling using [transformers](https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation). 15 | 16 | The main differences with the standard service for CUDA and CPU backends are that: 17 | 18 | - the service uses a single internal static batch, 19 | - new requests are inserted in the static batch during prefill, 20 | - the static KV cache is rebuilt entirely during prefill. 21 | 22 | ## License 23 | 24 | This docker image is released under [HFOIL 1.0](https://github.com/huggingface/text-generation-inference/blob/bde25e62b33b05113519e5dbf75abda06a03328e/LICENSE). 25 | 26 | HFOIL stands for Hugging Face Optimized Inference License, and it has been specifically designed for our optimized inference solutions. While the source code remains accessible, HFOIL is not a true open source license because we added a restriction: to sell a hosted or managed service built on top of TGI, we require a separate agreement. 27 | 28 | Please refer to [this reference documentation](https://github.com/huggingface/text-generation-inference/issues/726) to see if the HFOIL 1.0 restrictions apply to your deployment. 29 | 30 | ## Deploy the service 31 | 32 | The service is launched simply by running the tpu-tgi container with two sets of parameters: 33 | 34 | ``` 35 | docker run ghcr.io/huggingface/tpu-tgi:latest 36 | ``` 37 | 38 | - system parameters are used to map ports, volumes and devices between the host and the service, 39 | - service parameters are forwarded to the `text-generation-launcher`. 40 | 41 | ### Common system parameters 42 | 43 | You might want to export the `HF_TOKEN` environment variable if you need to access gated repositories.
44 | 45 | Here is an example of a service instantiation on a single host TPU: 46 | 47 | ``` 48 | docker run -p 8080:80 \ 49 | --net=host --privileged \ 50 | -v $(pwd)/data:/data \ 51 | -e HF_TOKEN=${HF_TOKEN} \ 52 | ghcr.io/huggingface/tpu-tgi:latest \ 53 | 54 | ``` 55 | 56 | 57 | 58 | ### Using a standard model from the 🤗 [HuggingFace Hub](https://huggingface.co/models) 59 | 60 | 61 | The snippet below shows how you can deploy a service from a hub standard model: 62 | 63 | ``` 64 | docker run -p 8080:80 \ 65 | --net=host --privileged \ 66 | -v $(pwd)/data:/data \ 67 | -e HF_TOKEN=${HF_TOKEN} \ 68 | -e MAX_BATCH_SIZE=4 \ 69 | -e HF_SEQUENCE_LENGTH=1024 \ 70 | ghcr.io/huggingface/tpu-tgi:latest \ 71 | --model-id mistralai/Mistral-7B-v0.1 \ 72 | --max-concurrent-requests 1 \ 73 | --max-input-length 512 \ 74 | --max-total-tokens 1024 \ 75 | --max-batch-prefill-tokens 512 \ 76 | --max-batch-total-tokens 1024 77 | ``` 78 | 79 | 80 | ### Choosing service parameters 81 | 82 | Use the following command to list the available service parameters: 83 | 84 | ``` 85 | docker run ghcr.io/huggingface/tpu-tgi --help 86 | ``` 87 | 88 | The configuration of an inference endpoint is always a compromise between throughput and latency: serving more requests in parallel will allow a higher throughput, but it will increase the latency. 89 | 90 | The models for now work with static input dimensions `[batch_size, max_length]`. 91 | 92 | It leads to a maximum number of tokens of `max_tokens = batch_size * max_length`. 93 | 94 | This adds several restrictions to the following parameters: 95 | 96 | - `--max-concurrent-requests` must be set to `batch size`, 97 | - `--max-input-length` must be lower than `max_length`, 98 | - `--max-total-tokens` must be set to `max_length` (it is per-request), 99 | - `--max-batch-prefill-tokens` must be set to `batch_size * max_input_length`, 100 | - `--max-batch-total-tokens` must be set to `max_tokens`. 101 | 102 | ### Choosing the correct batch size 103 | 104 | As seen in the previous paragraph, model static batch size has a direct influence on the endpoint latency and throughput. 105 | 106 | Please refer to [text-generation-inference](https://github.com/huggingface/text-generation-inference) for optimization hints. 107 | 108 | Note that the main constraint is to be able to fit the model for the specified `batch_size` within the total device memory available 109 | on your instance. 110 | 111 | ## Query the service 112 | 113 | You can query the model using either the `/generate` or `/generate_stream` routes: 114 | 115 | ``` 116 | curl 127.0.0.1:8080/generate \ 117 | -X POST \ 118 | -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ 119 | -H 'Content-Type: application/json' 120 | ``` 121 | 122 | ``` 123 | curl 127.0.0.1:8080/generate_stream \ 124 | -X POST \ 125 | -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ 126 | -H 'Content-Type: application/json' 127 | ``` 128 | 129 | ## Build your own image 130 | 131 | The image must be built from the top directory 132 | 133 | ``` 134 | make tpu-tgi 135 | ``` 136 | -------------------------------------------------------------------------------- /docs/source/index.mdx: -------------------------------------------------------------------------------- 1 | 16 | 17 | # 🤗 Optimum TPU 18 | 19 | 20 | **🚧 Optimum-TPU is now in maintenance mode.** 21 | 22 | We’ll continue to welcome community contributions for minor bug fixes, documentation improvements, and lightweight maintenance tasks. 
23 | 24 | While this project is no longer under active development, you can continue exploring TPU solutions with: 25 | 26 | - [TPU inference](./tutorials/inference_on_tpu) for inference 27 | - [🤗 Accelerate](https://github.com/huggingface/accelerate) for training 28 | 29 | 30 | 31 | 32 | Optimum TPU provides all the necessary machinery to leverage and optimize AI workloads running on [Google Cloud TPU devices](https://cloud.google.com/tpu/docs). Optimum-TPU is a HuggingFace solution to optimize HuggingFace products for the TPU platform. This allows users to use HuggingFace features and easy-to-use libraries on TPU with the best performance. We currently optimize transformers and TGI and integrate [HuggingFace hub](https://huggingface.co/models) so you can access HuggingFace's large library of models. 33 | 34 | If you are here to start using HuggingFace products on TPU, then you are in the right place 35 | 36 | The API provides the overall same user-experience as HuggingFace transformers with the minimum amount of changes required to target performance for inference and training. 37 | 38 | Optimum TPU is meant to reduce as much as possible the friction in order to leverage Google Cloud TPU accelerators. 39 | As such, we provide a pip installable package to make sure everyone can get easily started. 40 | 41 | ```bash 42 | pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html 43 | ``` 44 | 45 | ## Why Choose TPUs 46 | TPUs excel at large-scale machine learning workloads with matrix computations, extended training periods, and large batch sizes. In contrast, GPUs offer more flexibility for models with custom operations or mixed CPU/GPU workloads. TPUs aren't ideal for workloads needing frequent branching, high-precision arithmetic, or custom training loop operations. More information can be found at https://cloud.google.com/tpu/docs/intro-to-tpu#when_to_use_tpus 47 | 48 | ## Why Choose Optimum-TPU 49 | Optimum-TPU serves as the bridge between the HuggingFace ecosystem and Google Cloud TPU hardware. It dramatically simplifies what would otherwise be a complex integration process, providing an intuitive interface that abstracts away TPU-specific implementation details while maintaining high performance. Through automated optimizations, efficient batching strategies, intelligent memory management and more, Optimum-TPU ensures your models run at peak efficiency on TPU hardware. The framework's deep integration with the HuggingFace Hub catalog of models and datasets enables easy deployment and fine-tuning of state-of-the-art models with the familiar ease of use of HuggingFace libraries while maximizing TPU hardware capabilities. 50 | 51 | -------------------------------------------------------------------------------- /text-generation-inference/server/text_generation_server/jetstream_pt_support/models/llama_model_exportable_hf.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import math 3 | from contextlib import contextmanager 4 | from functools import partial 5 | 6 | import torch 7 | from jetstream_pt.third_party.llama import model_exportable 8 | from jetstream_pt.third_party.llama.model_exportable import Transformer, model_args 9 | from transformers import GenerationConfig, GenerationMixin, LlamaConfig 10 | 11 | 12 | # TODO: it would be better to have RoPE scaling code in Jetstream Pytorch, but until that is not done, 13 | # we add it here. 
Note that this is the reason why we define a new class RopeScalingArgs, instead of using the 14 | # config from transformers. 15 | @dataclasses.dataclass 16 | class RopeScalingArgs: 17 | """Rope scaling configuration parameters.""" 18 | 19 | factor: float = 8.0 20 | low_freq_factor: float = 1.0 21 | high_freq_factor: float = 4.0 22 | original_max_position_embeddings: int = 8192 23 | 24 | 25 | def apply_scaling(freqs: torch.Tensor, config: RopeScalingArgs): 26 | # Values obtained from grid search 27 | scale_factor = config.factor 28 | low_freq_factor = config.low_freq_factor 29 | high_freq_factor = config.high_freq_factor 30 | old_context_len = config.original_max_position_embeddings 31 | 32 | low_freq_wavelen = old_context_len / low_freq_factor 33 | high_freq_wavelen = old_context_len / high_freq_factor 34 | new_freqs = [] 35 | for freq in freqs: 36 | wavelen = 2 * math.pi / freq 37 | if wavelen < high_freq_wavelen: 38 | new_freqs.append(freq) 39 | elif wavelen > low_freq_wavelen: 40 | new_freqs.append(freq / scale_factor) 41 | else: 42 | assert low_freq_wavelen != high_freq_wavelen 43 | smooth = (old_context_len / wavelen - low_freq_factor) / ( 44 | high_freq_factor - low_freq_factor 45 | ) 46 | new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) 47 | return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) 48 | 49 | 50 | def precompute_freqs_cis( 51 | dim: int, 52 | end: int, 53 | theta: float = 10000.0, 54 | rope_scaling_config: RopeScalingArgs = None, 55 | ): 56 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 57 | t = torch.arange(end, device=freqs.device, dtype=torch.float32) 58 | if rope_scaling_config is not None: 59 | freqs = apply_scaling(freqs, rope_scaling_config) 60 | freqs = torch.outer(t, freqs) 61 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 62 | return freqs_cis 63 | 64 | 65 | @contextmanager 66 | def patch_precompute_freqs_cis(rope_scaling_js: RopeScalingArgs): 67 | # NOTE: This is a workaround to pass the rope scaling configuration when it is called in the original 68 | # Jetstream/Pytorch model. The function is monkey-patched to include the rope scaling configuration. 69 | original_precompute_freqs_cis = model_exportable.precompute_freqs_cis 70 | precompute_freqs_cis_partial = partial(precompute_freqs_cis, rope_scaling_config=rope_scaling_js) 71 | model_exportable.precompute_freqs_cis = precompute_freqs_cis_partial 72 | 73 | yield 74 | 75 | # Original function is restored. 76 | model_exportable.precompute_freqs_cis = original_precompute_freqs_cis 77 | 78 | 79 | class TransformerHf(Transformer, GenerationMixin): 80 | """Transformer module that uses HF LlamaConfig instead of Jetstream Pytorch ModelArgs + device. 81 | 82 | Note that this class also derives from GenerationMixin, so that we can use its methods. 83 | """ 84 | 85 | def __init__( 86 | self, 87 | config: LlamaConfig, 88 | device, 89 | env, 90 | ): 91 | self.config = config 92 | self.generation_config = GenerationConfig.from_model_config(config) 93 | 94 | # NOTE: these parameters are deduced from the config's intermediate_size and hidden_size, so to be compatible 95 | # with the original Jestream/Pytorch model. 96 | ffn_dim_multiplier = config.intermediate_size / int(8 * config.hidden_size / 3) 97 | multiple_of = 1 98 | 99 | if config.mlp_bias: 100 | raise ValueError("MLP bias is not supported in the on Jetstream Pytorch." 
101 | + "If your model requires it, you can open an issue.") 102 | 103 | rope_scaling_js = None 104 | rope_scaling = config.rope_scaling 105 | # The original Llama2 and Llama3 models do not have rope scaling configuration, while newer models do. 106 | if rope_scaling is not None: 107 | # Some models use "type" instead of "rope_type" in the configuration for historical reasons. 108 | rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) 109 | if rope_type != "llama3": 110 | raise ValueError(f"Unsupported rope type {rope_type} in rope scaling configuration.") 111 | 112 | rope_scaling_js = RopeScalingArgs( 113 | factor=rope_scaling["factor"], 114 | low_freq_factor=rope_scaling["low_freq_factor"], 115 | high_freq_factor=rope_scaling["high_freq_factor"], 116 | original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], 117 | ) 118 | 119 | args = model_args.ModelArgs( 120 | dim=config.hidden_size, 121 | n_layers=config.num_hidden_layers, 122 | n_heads=config.num_attention_heads, 123 | n_kv_heads=config.num_key_value_heads, 124 | vocab_size=config.vocab_size, 125 | multiple_of=multiple_of, 126 | ffn_dim_multiplier=ffn_dim_multiplier, 127 | norm_eps=config.rms_norm_eps, 128 | max_seq_len=env.cache_len, 129 | bf16_enable=env.bf16_enable, 130 | rope_theta=config.rope_theta, 131 | ) 132 | args.device = device 133 | 134 | with patch_precompute_freqs_cis(rope_scaling_js): 135 | super().__init__(args, env) 136 | 137 | 138 | @classmethod 139 | def from_config(cls, config, env): 140 | device = "meta" 141 | model = cls(config, device, env) 142 | return model 143 | -------------------------------------------------------------------------------- /text-generation-inference/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Fetch and extract the TGI sources 2 | FROM alpine AS tgi 3 | # TGI version 3.0.0 by default 4 | ARG TGI_VERSION=v3.0.0 5 | RUN test -n ${TGI_VERSION:?} 6 | RUN mkdir -p /tgi 7 | ADD https://github.com/huggingface/text-generation-inference/archive/${TGI_VERSION}.tar.gz /tgi/sources.tar.gz 8 | RUN tar -C /tgi -xf /tgi/sources.tar.gz --strip-components=1 9 | 10 | # Build cargo components (adapted from TGI original Dockerfile) 11 | # Note: we cannot use the cargo-chef base image as it uses python 3.11 12 | FROM ubuntu:22.04 AS chef 13 | 14 | RUN apt-get update -y \ 15 | && apt-get install -y --no-install-recommends \ 16 | curl ca-certificates build-essential \ 17 | && rm -rf /var/lib/apt/lists/* \ 18 | && apt-get clean 19 | 20 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.80.1 --profile minimal -y 21 | ENV PATH="/root/.cargo/bin:${PATH}" 22 | RUN cargo install cargo-chef --locked 23 | 24 | WORKDIR /usr/src 25 | 26 | ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse 27 | 28 | FROM chef AS planner 29 | COPY text-generation-inference/Cargo.toml Cargo.toml 30 | COPY --from=tgi /tgi/Cargo.lock Cargo.lock 31 | COPY --from=tgi /tgi/rust-toolchain.toml rust-toolchain.toml 32 | COPY --from=tgi /tgi/proto proto 33 | COPY --from=tgi /tgi/router router 34 | COPY --from=tgi /tgi/backends backends 35 | COPY --from=tgi /tgi/launcher launcher 36 | RUN cargo chef prepare --recipe-path recipe.json 37 | 38 | FROM chef AS builder 39 | ARG ENABLE_GOOGLE_FEATURE 40 | RUN echo "Google Feature Status: ${ENABLE_GOOGLE_FEATURE}" 41 | 42 | RUN apt-get update -y \ 43 | && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 44 | unzip python3-dev 
libssl-dev pkg-config \ 45 | && rm -rf /var/lib/apt/lists/* \ 46 | && apt-get clean 47 | 48 | RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ 49 | curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ 50 | unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ 51 | unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ 52 | rm -f $PROTOC_ZIP 53 | 54 | COPY text-generation-inference/Cargo.toml Cargo.toml 55 | COPY --from=planner /usr/src/recipe.json recipe.json 56 | RUN cargo chef cook --profile release-opt --recipe-path recipe.json 57 | 58 | COPY --from=tgi /tgi/Cargo.lock Cargo.lock 59 | COPY --from=tgi /tgi/rust-toolchain.toml rust-toolchain.toml 60 | COPY --from=tgi /tgi/proto proto 61 | COPY --from=tgi /tgi/router router 62 | COPY --from=tgi /tgi/backends backends 63 | COPY --from=tgi /tgi/launcher launcher 64 | RUN if [ -n "$ENABLE_GOOGLE_FEATURE" ]; then \ 65 | cargo build --profile release-opt --features google; \ 66 | else \ 67 | cargo build --profile release-opt; \ 68 | fi 69 | 70 | # Python base image 71 | FROM ubuntu:22.04 AS base 72 | 73 | RUN apt-get update -y \ 74 | && apt-get install -y --no-install-recommends \ 75 | python3-pip \ 76 | python3-setuptools \ 77 | python-is-python3 \ 78 | git \ 79 | && rm -rf /var/lib/apt/lists/* \ 80 | && apt-get clean 81 | RUN pip3 --no-cache-dir install --upgrade pip 82 | 83 | ARG ENABLE_GOOGLE_FEATURE 84 | ARG VERSION='0.2.3.dev0' 85 | RUN test -n ${VERSION:?} 86 | 87 | FROM base AS optimum-tpu-installer 88 | 89 | COPY . /tmp/src 90 | 91 | RUN if [ -n "$ENABLE_GOOGLE_FEATURE" ]; then \ 92 | # If we are building for GCP, we need to clone the optimum-tpu repo as this is built from the huggingface/Google-Cloud-Containers repository and not the huggingface/optimum-tpu repository 93 | git clone https://github.com/huggingface/optimum-tpu.git /opt/optimum-tpu && \ 94 | cd /opt/optimum-tpu && git checkout v${VERSION}; \ 95 | fi && \ 96 | # Check if the optimum-tpu repo is cloned properly 97 | cp -a /tmp/src /opt/optimum-tpu && \ 98 | if [ ! -d "/opt/optimum-tpu/optimum" ]; then \ 99 | echo "Error: Building from incorrect repository. This build must be run from optimum-tpu repo. 
If building from google-cloud-containers repo, set ENABLE_GOOGLE_FEATURE=1 to automatically clone optimum-tpu" && \ 100 | exit 1; \ 101 | fi 102 | 103 | 104 | # Python server build image 105 | FROM base AS pyserver 106 | 107 | RUN apt-get update -y \ 108 | && apt-get install -y --no-install-recommends \ 109 | make \ 110 | python3-venv \ 111 | && rm -rf /var/lib/apt/lists/* \ 112 | && apt-get clean 113 | 114 | RUN install -d /pyserver 115 | WORKDIR /pyserver 116 | COPY --from=optimum-tpu-installer /opt/optimum-tpu/text-generation-inference/server server 117 | COPY --from=tgi /tgi/proto proto 118 | RUN pip3 install -r server/build-requirements.txt 119 | RUN VERBOSE=1 BUILDDIR=/pyserver/build PROTODIR=/pyserver/proto VERSION=${VERSION} make -C server gen-server 120 | 121 | # TPU base image (used for deployment) 122 | FROM base AS tpu_base 123 | 124 | ARG VERSION=${VERSION} 125 | 126 | # Install system prerequisites 127 | RUN apt-get update -y \ 128 | && apt-get install -y --no-install-recommends \ 129 | libpython3.10 \ 130 | git \ 131 | gnupg2 \ 132 | wget \ 133 | curl \ 134 | && rm -rf /var/lib/apt/lists/* \ 135 | && apt-get clean 136 | 137 | # Update pip 138 | RUN pip install --upgrade pip 139 | 140 | # Install HuggingFace packages 141 | ARG TRANSFORMERS_VERSION='4.46.3' 142 | ARG ACCELERATE_VERSION='1.1.1' 143 | ARG SAFETENSORS_VERSION='0.4.5' 144 | 145 | ARG ENABLE_GOOGLE_FEATURE 146 | 147 | ENV HF_HUB_ENABLE_HF_TRANSFER=1 148 | ENV VERSION=${VERSION} 149 | 150 | ENV PORT=${ENABLE_GOOGLE_FEATURE:+8080} 151 | ENV PORT=${PORT:-80} 152 | 153 | ENV HF_HOME=${ENABLE_GOOGLE_FEATURE:+/tmp} 154 | ENV HF_HOME=${HF_HOME:-/data} 155 | 156 | # Install requirements for optimum-tpu, then for TGI then optimum-tpu 157 | RUN python3 -m pip install hf_transfer safetensors==${SAFETENSORS_VERSION} typer 158 | COPY --from=optimum-tpu-installer /opt/optimum-tpu /opt/optimum-tpu 159 | RUN python3 /opt/optimum-tpu/optimum/tpu/cli.py install-jetstream-pytorch --yes 160 | RUN python3 -m pip install -e /opt/optimum-tpu \ 161 | -f https://storage.googleapis.com/libtpu-releases/index.html 162 | 163 | 164 | # Install router 165 | COPY --from=builder /usr/src/target/release-opt/text-generation-router-v2 /usr/local/bin/text-generation-router 166 | # Install launcher 167 | COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher 168 | # Install python server 169 | COPY --from=pyserver /pyserver/build/dist dist 170 | RUN pip install dist/text_generation_server*.tar.gz 171 | 172 | 173 | # TPU compatible image for Inference Endpoints 174 | FROM tpu_base AS inference-endpoint 175 | 176 | COPY text-generation-inference/docker/entrypoint.sh entrypoint.sh 177 | RUN chmod +x entrypoint.sh 178 | ENTRYPOINT ["./entrypoint.sh"] 179 | 180 | FROM tpu_base AS google-cloud-containers 181 | 182 | # Install Google specific components if ENABLE_GOOGLE_FEATURE is set 183 | RUN if [ -n "$ENABLE_GOOGLE_FEATURE" ]; then \ 184 | apt-get update && \ 185 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 186 | ca-certificates \ 187 | curl \ 188 | git && \ 189 | rm -rf /var/lib/apt/lists/* && \ 190 | echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \ 191 | | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \ 192 | curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \ 193 | | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - && \ 194 | apt-get update -y && \ 195 | apt-get 
install google-cloud-sdk -y; \ 196 | fi 197 | 198 | # Custom entrypoint for Google 199 | COPY --chmod=775 containers/tgi/tpu/${VERSION}/entrypoint.sh* entrypoint.sh 200 | ENTRYPOINT ["./entrypoint.sh"] 201 | 202 | # TPU compatible image 203 | FROM tpu_base 204 | 205 | ENTRYPOINT ["text-generation-launcher"] 206 | --------------------------------------------------------------------------------