├── .azure └── gpu-tests.yml ├── .codecov.yml ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ ├── documentation.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md ├── stale.yml └── workflows │ ├── ci-checks.yml │ ├── ci-minimal-dependency-check.yml │ ├── ci-parity.yml │ ├── ci-testing.yml │ └── release-pypi.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── _requirements ├── perf.txt └── test.txt ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── setup.py ├── src └── litserve │ ├── __about__.py │ ├── __init__.py │ ├── __main__.py │ ├── api.py │ ├── callbacks │ ├── __init__.py │ ├── base.py │ └── defaults │ │ ├── __init__.py │ │ └── metric_callback.py │ ├── cli.py │ ├── connector.py │ ├── docker_builder.py │ ├── loggers.py │ ├── loops │ ├── __init__.py │ ├── base.py │ ├── continuous_batching_loop.py │ ├── loops.py │ ├── simple_loops.py │ └── streaming_loops.py │ ├── middlewares.py │ ├── python_client.py │ ├── schema │ ├── __init__.py │ └── image.py │ ├── server.py │ ├── specs │ ├── __init__.py │ ├── base.py │ ├── openai.py │ └── openai_embedding.py │ ├── test_examples │ ├── __init__.py │ ├── openai_embedding_spec_example.py │ ├── openai_spec_example.py │ └── simple_example.py │ ├── transport │ ├── __init__.py │ ├── base.py │ ├── factory.py │ ├── process_transport.py │ ├── zmq_queue.py │ └── zmq_transport.py │ └── utils.py └── tests ├── __init__.py ├── conftest.py ├── e2e ├── default_api.py ├── default_async_streaming.py ├── default_batched_streaming.py ├── default_batching.py ├── default_openai_embedding_spec.py ├── default_openai_with_batching.py ├── default_openaispec.py ├── default_openaispec_response_format.py ├── default_openaispec_tools.py ├── default_single_streaming.py ├── default_spec.py ├── openai_embedding_with_batching.py └── test_e2e.py ├── minimal_run.py ├── parity_fastapi ├── benchmark.py ├── fastapi-server.py ├── ls-server.py └── main.py ├── perf_test ├── bert │ ├── benchmark.py │ ├── data.py │ ├── run_test.sh │ ├── server.py │ └── utils.py └── stream │ ├── run_test.sh │ └── stream_speed │ ├── benchmark.py │ └── server.py ├── simple_server.py ├── simple_server_diff_port.py ├── test_auth.py ├── test_batch.py ├── test_callbacks.py ├── test_cli.py ├── test_compression.py ├── test_connector.py ├── test_docker_builder.py ├── test_examples.py ├── test_form.py ├── test_lit_server.py ├── test_litapi.py ├── test_logger.py ├── test_logging.py ├── test_loops.py ├── test_middlewares.py ├── test_multiple_endpoints.py ├── test_openai_embedding.py ├── test_pydantic.py ├── test_readme.py ├── test_request_handlers.py ├── test_schema.py ├── test_simple.py ├── test_specs.py ├── test_torch.py ├── test_transport.py ├── test_utils.py └── test_zmq_queue.py /.azure/gpu-tests.yml: -------------------------------------------------------------------------------- 1 | # Create and test a Python package on multiple PyTorch versions. 
2 | 3 | trigger: 4 | tags: 5 | include: 6 | - "*" 7 | branches: 8 | include: 9 | - main 10 | - release/* 11 | - refs/tags/* 12 | pr: 13 | - main 14 | - release/* 15 | 16 | jobs: 17 | - job: serve_GPU 18 | # how long to run the job before automatically cancelling 19 | timeoutInMinutes: "20" 20 | # how much time to give 'run always even if cancelled tasks' before stopping them 21 | cancelTimeoutInMinutes: "2" 22 | 23 | pool: "lit-rtx-3090" 24 | 25 | variables: 26 | DEVICES: $( python -c 'name = "$(Agent.Name)" ; gpus = name.split("_")[-1] if "_" in name else "0,1"; print(gpus)' ) 27 | # these two caches assume to run repetitively on the same set of machines 28 | # see: https://github.com/microsoft/azure-pipelines-agent/issues/4113#issuecomment-1439241481 29 | TORCH_HOME: "/var/tmp/torch" 30 | PIP_CACHE_DIR: "/var/tmp/pip" 31 | 32 | container: 33 | image: "pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime" 34 | options: "--gpus=all --shm-size=8g -v /var/tmp:/var/tmp" 35 | 36 | workspace: 37 | clean: all 38 | 39 | steps: 40 | - bash: | 41 | echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)" 42 | CUDA_version=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p') 43 | CUDA_version_mm="${CUDA_version//'.'/''}" 44 | echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$CUDA_version_mm" 45 | echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${CUDA_version_mm}/torch_stable.html" 46 | displayName: "set Env. vars" 47 | 48 | - bash: | 49 | whoami && id 50 | lspci | egrep 'VGA|3D' 51 | whereis nvidia 52 | nvidia-smi 53 | echo $CUDA_VISIBLE_DEVICES 54 | echo $TORCH_URL 55 | python --version 56 | pip --version 57 | pip cache dir 58 | pip list 59 | displayName: "Image info & NVIDIA" 60 | 61 | - bash: | 62 | pip install . 
-U --prefer-binary \ 63 | -r ./_requirements/test.txt --find-links=${TORCH_URL} 64 | displayName: "Install environment" 65 | 66 | - bash: | 67 | pip list 68 | python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'found GPUs: {mgpu}'" 69 | displayName: "Sanity check" 70 | 71 | - bash: | 72 | pip install -q py-tree 73 | py-tree /var/tmp/torch 74 | displayName: "Show caches" 75 | 76 | - bash: | 77 | coverage run --source litserve -m pytest src tests -v 78 | displayName: "Testing" 79 | 80 | - bash: | 81 | python -m coverage report 82 | python -m coverage xml 83 | python -m codecov --token=$(CODECOV_TOKEN) --name="GPU-coverage" \ 84 | --commit=$(Build.SourceVersion) --flags=gpu,unittest --env=linux,azure 85 | ls -l 86 | displayName: "Statistics" 87 | 88 | - bash: | 89 | pip install torch torchvision -U -q --find-links=${TORCH_URL} -r _requirements/perf.txt 90 | export PYTHONPATH=$PWD && python tests/parity_fastapi/main.py 91 | displayName: "Run FastAPI parity tests" 92 | 93 | - bash: | 94 | pip install gpustat wget -U -q 95 | bash tests/perf_test/bert/run_test.sh 96 | displayName: "Run GPU perf test" 97 | -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | # see https://docs.codecov.io/docs/codecov-yaml 2 | # Validation check: 3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate 4 | 5 | # https://docs.codecov.io/docs/codecovyml-reference 6 | codecov: 7 | bot: "codecov-io" 8 | strict_yaml_branch: "yaml-config" 9 | require_ci_to_pass: yes 10 | notify: 11 | # after_n_builds: 2 12 | wait_for_ci: yes 13 | 14 | coverage: 15 | precision: 0 # 2 = xx.xx%, 0 = xx% 16 | round: nearest # how coverage is rounded: down/up/nearest 17 | range: 40...100 # custom range of coverage colors from red -> yellow -> green 18 | status: 19 | # https://codecov.readme.io/v1.0/docs/commit-status 20 | project: 21 | default: 22 | informational: true 23 | target: 99% # specify the target coverage for each commit status 24 | threshold: 30% # allow this little decrease on project 25 | # https://github.com/codecov/support/wiki/Filtering-Branches 26 | # branches: main 27 | if_ci_failed: error 28 | # https://github.com/codecov/support/wiki/Patch-Status 29 | patch: 30 | default: 31 | informational: true 32 | target: 50% # specify the target "X%" coverage to hit 33 | # threshold: 50% # allow this much decrease on patch 34 | changes: false 35 | 36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations 37 | github_checks: 38 | annotations: false 39 | 40 | parsers: 41 | gcov: 42 | branch_detection: 43 | conditional: true 44 | loop: true 45 | macro: false 46 | method: false 47 | javascript: 48 | enable_partials: false 49 | 50 | comment: 51 | layout: header, diff 52 | require_changes: false 53 | behavior: default # update if exists else create new 54 | # branches: * 55 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Each line is a file pattern followed by one or more owners. 2 | 3 | # These owners will be the default owners for everything in the repo. Unless a later match takes precedence, 4 | # @global-owner1 and @global-owner2 will be requested for review when someone opens a pull request. 
5 | * @lantiga @aniketmaurya @ethanwharris @borda @justusschock @tchaton @ali-alshaar7 @k223kim @KaelanDt 6 | 7 | # CI/CD and configs 8 | /.github/ @borda 9 | *.yaml @borda 10 | *.yml @borda 11 | 12 | # devel requirements 13 | /_requirements/ @borda @lantiga 14 | 15 | /README.md @williamfalcon @lantiga 16 | /requirements.txt @williamfalcon 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🐛 Bug 10 | 11 | 12 | 13 | ### To Reproduce 14 | 15 | Attach a [Lightning Studio](https://lightning.ai/studios) which is fully reproducible (code, dependencies, environment, etc...) to reproduce this: 16 | 17 | 1. Create a [Studio](https://lightning.ai/studios). 18 | 2. Reproduce the issue in the Studio. 19 | 3. [Publish the Studio](https://lightning.ai/docs/overview/studios/publishing#how-to-publish). 20 | 4. Paste the Studio link here. 21 | 22 | 23 | 24 | #### Code sample 25 | 26 | 28 | 29 | ### Expected behavior 30 | 31 | 32 | 33 | ### Environment 34 | If you published a Studio with your bug report, we can automatically get this information. Otherwise, please describe: 35 | 36 | - PyTorch/Jax/Tensorflow Version (e.g., 1.0): 37 | - OS (e.g., Linux): 38 | - How you installed PyTorch (`conda`, `pip`, source): 39 | - Build command you used (if compiling from source): 40 | - Python version: 41 | - CUDA/cuDNN version: 42 | - GPU models and configuration: 43 | - Any other relevant information: 44 | 45 | ### Additional context 46 | 47 | 48 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: ❓ Ask a Question 4 | url: https://www.reddit.com/r/lightningAI/ 5 | about: Ask and answer Lightning related questions. 6 | - name: 💬 Chat with us 7 | url: https://discord.gg/VptPCZkGNa 8 | about: Live chat with experts, engineers, and users in our Discord community. 9 | - name: 📖 Read the documentation 10 | url: https://lightning.ai/litserve/ 11 | about: Please consult the documentation before opening any issues! 12 | - name: 🙋 Get help with Lightning Studio 13 | url: https://lightning.ai 14 | about: Contact the Lightning.ai sales team. 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Typos and doc fixes 3 | about: Typos and doc fixes 4 | title: '' 5 | labels: documentation 6 | assignees: '' 7 | --- 8 | 9 | ## 📚 Documentation 10 | 11 | For typos and doc fixes, please go ahead and: 12 | 13 | - For a simple typo or fix, please send directly a PR (no need to create an issue) 14 | - If you are not sure about the proper solution, please describe here your finding... 15 | 16 | Thanks! 
17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | --- 8 | 9 | 22 | 23 | ---- 24 | 25 | ## 🚀 Feature 26 | 27 | 28 | 29 | ### Motivation 30 | 31 | 36 | 37 | ### Pitch 38 | 39 | 40 | 41 | ### Alternatives 42 | 43 | 44 | 45 | ### Additional context 46 | 47 | 48 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | Fixes # (issue). 4 | 5 | 6 |
7 | Before submitting 8 | 9 | - [ ] Was this discussed/agreed via a Github issue? (no need for typos and docs improvements) 10 | - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section? 11 | - [ ] Did you make sure to update the docs? 12 | - [ ] Did you write any new necessary tests? 13 | 14 |
15 | 16 | 29 | 30 | 31 | 32 | ## PR review 33 | 34 | Anyone in the community is free to review the PR once the tests have passed. 35 | If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged. 36 | 37 | ## Did you have fun? 38 | 39 | Make sure you had fun coding 🙃 40 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # https://github.com/marketplace/stale 2 | 3 | # Number of days of inactivity before an issue becomes stale 4 | daysUntilStale: 60 5 | # Number of days of inactivity before a stale issue is closed 6 | daysUntilClose: 14 7 | # Issues with these labels will never be considered stale 8 | exemptLabels: 9 | - pinned 10 | - security 11 | # Label to use when marking an issue as stale 12 | staleLabel: won't fix 13 | # Comment to post when marking an issue as stale. Set to `false` to disable 14 | markComment: > 15 | This issue has been automatically marked as stale because it has not had 16 | recent activity. It will be closed if no further activity occurs. Thank you 17 | for your contributions. 18 | # Comment to post when closing a stale issue. Set to `false` to disable 19 | closeComment: false 20 | 21 | # Set to true to ignore issues in a project (defaults to false) 22 | exemptProjects: true 23 | # Set to true to ignore issues in a milestone (defaults to false) 24 | exemptMilestones: true 25 | # Set to true to ignore issues with an assignee (defaults to false) 26 | exemptAssignees: true 27 | -------------------------------------------------------------------------------- /.github/workflows/ci-checks.yml: -------------------------------------------------------------------------------- 1 | name: General checks 2 | 3 | on: 4 | push: 5 | branches: [main, "release/*"] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: ${{ ! 
(github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }} 12 | 13 | jobs: 14 | # check-typing: 15 | # uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@main 16 | # with: 17 | # actions-ref: main 18 | 19 | check-schema: 20 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main 21 | with: 22 | azure-dir: "" 23 | 24 | check-package: 25 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main 26 | with: 27 | actions-ref: main 28 | import-name: "litserve" 29 | artifact-name: dist-packages-${{ github.sha }} 30 | testing-matrix: | 31 | { 32 | "os": ["ubuntu-latest", "macos-latest", "windows-latest"], 33 | "python-version": ["3.10"], 34 | } 35 | 36 | # check-docs: 37 | # uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@main 38 | # with: 39 | # requirements-file: "_requirements/docs.txt" 40 | -------------------------------------------------------------------------------- /.github/workflows/ci-minimal-dependency-check.yml: -------------------------------------------------------------------------------- 1 | name: Minimal dependency check 2 | 3 | on: 4 | push: 5 | branches: [main, "release/*"] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | defaults: 10 | run: 11 | shell: bash 12 | 13 | jobs: 14 | minimal-test: 15 | runs-on: ubuntu-latest 16 | 17 | timeout-minutes: 30 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up Python 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: "3.11" 25 | 26 | - name: Install LitServe 27 | run: | 28 | pip --version 29 | pip install . psutil -U -q 30 | pip list 31 | 32 | - name: Tests 33 | run: python tests/minimal_run.py 34 | -------------------------------------------------------------------------------- /.github/workflows/ci-parity.yml: -------------------------------------------------------------------------------- 1 | name: Run parity tests 2 | 3 | on: 4 | push: 5 | branches: [main, "release/*"] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | defaults: 10 | run: 11 | shell: bash 12 | 13 | jobs: 14 | pytester: 15 | runs-on: ubuntu-latest 16 | 17 | timeout-minutes: 30 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up Python 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: "3.11" 25 | 26 | - name: Install LitServe 27 | run: | 28 | pip --version 29 | pip install . 
torchvision jsonargparse uvloop tenacity -U -q -r _requirements/test.txt -U -q 30 | pip list 31 | 32 | - name: Parity test 33 | run: export PYTHONPATH=$PWD && python tests/parity_fastapi/main.py 34 | 35 | - name: Streaming speed test 36 | run: bash tests/perf_test/stream/run_test.sh 37 | -------------------------------------------------------------------------------- /.github/workflows/ci-testing.yml: -------------------------------------------------------------------------------- 1 | name: CI testing 2 | 3 | on: 4 | push: 5 | branches: [main, "release/*"] 6 | pull_request: 7 | branches: [main, "release/*"] 8 | 9 | defaults: 10 | run: 11 | shell: bash 12 | 13 | jobs: 14 | pytester: 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | os: ["ubuntu-latest"] 20 | python-version: ["3.9", "3.10", "3.11"] # todo, "3.12" 21 | include: 22 | - { os: "macos-latest", python-version: "3.12" } 23 | - { os: "windows-latest", python-version: "3.11" } 24 | - { os: "ubuntu-22.04", python-version: "3.9", requires: "oldest" } 25 | 26 | timeout-minutes: 35 27 | env: 28 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" 29 | 30 | steps: 31 | - uses: actions/checkout@v4 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v5 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | cache: "pip" 37 | 38 | - name: Set min. dependencies 39 | if: matrix.requires == 'oldest' 40 | run: | 41 | pip install 'lightning-utilities[cli]' 42 | python -m lightning_utilities.cli requirements set-oldest --req_files='["requirements.txt", "_requirements/test.txt"]' 43 | 44 | - name: Install package & dependencies 45 | run: | 46 | pip --version 47 | pip install -e '.[test]' -U -q --find-links $TORCH_URL 48 | pip list 49 | 50 | - name: Tests 51 | timeout-minutes: 10 52 | run: | 53 | python -m pytest --cov=litserve src/ tests/ -v -s 54 | 55 | - name: Statistics 56 | run: | 57 | coverage report 58 | coverage xml 59 | 60 | - name: Upload coverage to Codecov 61 | uses: codecov/codecov-action@v4 62 | with: 63 | token: ${{ secrets.CODECOV_TOKEN }} 64 | file: ./coverage.xml 65 | flags: unittests 66 | env_vars: OS,PYTHON 67 | name: codecov-umbrella 68 | fail_ci_if_error: false 69 | 70 | tests-guardian: 71 | runs-on: ubuntu-latest 72 | needs: pytester 73 | if: always() 74 | steps: 75 | - run: echo "${{ needs.pytester.result }}" 76 | - name: failing... 77 | if: needs.pytester.result == 'failure' 78 | run: exit 1 79 | - name: cancelled or skipped... 
80 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result) 81 | timeout-minutes: 1 82 | run: sleep 90 83 | -------------------------------------------------------------------------------- /.github/workflows/release-pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | 3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: # Trigger the workflow on push or pull request, but only for the main branch 5 | push: 6 | branches: [main] 7 | release: 8 | types: [published] 9 | pull_request: 10 | branches: [main] 11 | paths: 12 | - ".github/workflows/release-pypi.yml" 13 | 14 | # based on https://github.com/pypa/gh-action-pypi-publish 15 | 16 | jobs: 17 | build-package: 18 | runs-on: ubuntu-latest 19 | timeout-minutes: 10 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: actions/setup-python@v5 23 | with: 24 | python-version: "3.x" 25 | - name: Pull reusable 🤖 actions️ 26 | uses: actions/checkout@v4 27 | with: 28 | ref: main 29 | path: .cicd 30 | repository: Lightning-AI/utilities 31 | 32 | - name: Prepare build env. 33 | run: pip install -r ./.cicd/requirements/gha-package.txt 34 | - name: Create 📦 package 35 | uses: ./.cicd/.github/actions/pkg-create 36 | - uses: actions/upload-artifact@v4 37 | with: 38 | name: pypi-packages-${{ github.sha }} 39 | path: dist 40 | 41 | publish-package: 42 | needs: build-package 43 | runs-on: ubuntu-latest 44 | timeout-minutes: 5 45 | steps: 46 | - uses: actions/download-artifact@v4 47 | with: 48 | name: pypi-packages-${{ github.sha }} 49 | path: dist 50 | - run: ls -lh dist/ 51 | 52 | - name: Publish distribution 📦 to PyPI 53 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 54 | uses: pypa/gh-action-pypi-publish@v1.12.3 55 | with: 56 | user: __token__ 57 | password: ${{ secrets.pypi_password }} 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | *.egg-info/ 29 | src.egg-info/ 30 | 31 | # Lightning /research 32 | test_tube_exp/ 33 | tests/tests_tt_dir/ 34 | tests/save_dir 35 | default/ 36 | data/ 37 | test_tube_logs/ 38 | test_tube_data/ 39 | datasets/ 40 | model_weights/ 41 | tests/save_dir 42 | tests/tests_tt_dir/ 43 | processed/ 44 | raw/ 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | 122 | # mypy 123 | .mypy_cache/ 124 | 125 | # IDEs 126 | .idea 127 | .vscode 128 | 129 | # seed project 130 | lightning_logs/ 131 | MNIST 132 | .DS_Store 133 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions" 7 | autoupdate_schedule: "monthly" 8 | # submodules: true 9 | 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v5.0.0 13 | hooks: 14 | - id: end-of-file-fixer 15 | - id: trailing-whitespace 16 | exclude: '.*\.md$' 17 | - id: check-case-conflict 18 | - id: check-yaml 19 | - id: check-toml 20 | - id: check-json 21 | - id: check-added-large-files 22 | - id: check-docstring-first 23 | - id: detect-private-key 24 | 25 | - repo: https://github.com/codespell-project/codespell 26 | rev: v2.4.1 27 | hooks: 28 | - id: codespell 29 | additional_dependencies: [tomli] 30 | #args: ["--write-changes"] 31 | 32 | - repo: https://github.com/pre-commit/mirrors-prettier 33 | rev: v3.1.0 34 | hooks: 35 | - id: prettier 36 | files: \.(json|yml|yaml|toml) 37 | # https://prettier.io/docs/en/options.html#print-width 38 | args: ["--print-width=120"] 39 | 40 | - repo: https://github.com/PyCQA/docformatter 41 | rev: v1.7.7 42 | hooks: 43 | - id: docformatter 44 | additional_dependencies: [tomli] 45 | args: ["--in-place"] 46 | 47 | - repo: https://github.com/astral-sh/ruff-pre-commit 48 | rev: v0.11.12 49 | hooks: 50 | - id: ruff-format 51 | args: ["--preview"] 52 | - id: ruff 53 | args: ["--fix"] 54 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Manifest syntax https://docs.python.org/2/distutils/sourcedist.html 2 | graft wheelhouse 3 | 4 | recursive-exclude __pycache__ *.py[cod] *.orig 5 | 6 | # Include the README and CHANGELOG 7 | include *.md 8 | recursive-include src *.md 9 | 10 | # Include the license file 11 | include LICENSE 12 | 13 | # Exclude build configs 14 | exclude *.sh 15 | exclude *.toml 16 | exclude *.svg 17 | exclude *.yml 18 | exclude *.yaml 19 | 20 | # exclude tests from package 21 | recursive-exclude tests * 22 | recursive-exclude site * 
23 | exclude tests 24 | 25 | # Exclude the documentation files 26 | recursive-exclude docs * 27 | exclude docs 28 | 29 | # Include the Requirements 30 | include requirements.txt 31 | recursive-include _requirements *.txt 32 | 33 | # Exclude Makefile 34 | exclude Makefile 35 | 36 | prune .git 37 | prune .github 38 | prune temp* 39 | prune test* 40 | -------------------------------------------------------------------------------- /_requirements/perf.txt: -------------------------------------------------------------------------------- 1 | uvloop 2 | tenacity 3 | jsonargparse 4 | -------------------------------------------------------------------------------- /_requirements/test.txt: -------------------------------------------------------------------------------- 1 | httpx>=0.27.0 2 | coverage[toml] >=7.5.3 3 | pytest >=8.0 4 | pytest-cov 5 | mypy ==1.11.2 6 | pytest-asyncio 7 | asgi-lifespan 8 | python-multipart 9 | psutil 10 | requests 11 | lightning >2.0.0 12 | torch >2.0.0 13 | transformers 14 | openai>=1.12.0 15 | pillow 16 | numpy <2.0 17 | pytest-retry>=1.6.3 18 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_file = "LICENSE" 3 | description-file = "README.md" 4 | 5 | [build-system] 6 | requires = [ 7 | "setuptools", 8 | "wheel", 9 | ] 10 | 11 | 12 | [tool.check-manifest] 13 | ignore = [ 14 | "*.yml", 15 | ".github", 16 | ".github/*" 17 | ] 18 | 19 | 20 | [tool.pytest.ini_options] 21 | norecursedirs = [ 22 | ".git", 23 | ".github", 24 | "dist", 25 | "build", 26 | "docs", 27 | ] 28 | addopts = [ 29 | "--strict-markers", 30 | "--doctest-modules", 31 | "--color=yes", 32 | "--disable-pytest-warnings", 33 | ] 34 | filterwarnings = [ 35 | "error::FutureWarning", 36 | ] 37 | xfail_strict = true 38 | junit_duration_report = "call" 39 | 40 | [tool.coverage.report] 41 | exclude_lines = [ 42 | "pragma: no cover", 43 | "pass", 44 | ] 45 | 46 | [tool.codespell] 47 | #skip = '*.py' 48 | quiet-level = 3 49 | # comma separated list of words; waiting for: 50 | # https://github.com/codespell-project/codespell/issues/2839#issuecomment-1731601603 51 | # also adding links until they ignored by its: nature 52 | # https://github.com/codespell-project/codespell/issues/2243#issuecomment-1732019960 53 | #ignore-words-list = "" 54 | 55 | [tool.docformatter] 56 | recursive = true 57 | wrap-summaries = 120 58 | wrap-descriptions = 120 59 | blank = true 60 | 61 | 62 | #[tool.mypy] 63 | #files = [ 64 | # "src", 65 | #] 66 | #install_types = true 67 | #non_interactive = true 68 | #disallow_untyped_defs = true 69 | #ignore_missing_imports = true 70 | #show_error_codes = true 71 | #warn_redundant_casts = true 72 | #warn_unused_configs = true 73 | #warn_unused_ignores = true 74 | #allow_redefinition = true 75 | ## disable this rule as the Trainer attributes are defined in the connectors, not in its __init__ 76 | #disable_error_code = "attr-defined" 77 | ## style choices 78 | #warn_no_return = false 79 | 80 | 81 | [tool.ruff] 82 | line-length = 120 83 | target-version = "py38" 84 | # Enable Pyflakes `E` and `F` codes by default. 
85 | lint.select = [ 86 | "E", "W", # see: https://pypi.org/project/pycodestyle 87 | "F", # see: https://pypi.org/project/pyflakes 88 | "N", # see: https://pypi.org/project/pep8-naming 89 | # "D", # see: https://pypi.org/project/pydocstyle 90 | "I", # implementation for isort 91 | "UP", # implementation for pyupgrade 92 | "RUF100", # implementation for yesqa 93 | ] 94 | lint.extend-select = [ 95 | "C4", # see: https://pypi.org/project/flake8-comprehensions 96 | "PT", # see: https://pypi.org/project/flake8-pytest-style 97 | "RET", # see: https://pypi.org/project/flake8-return 98 | "SIM", # see: https://pypi.org/project/flake8-simplify 99 | ] 100 | lint.ignore = [ 101 | "E731", # Do not assign a lambda expression, use a def 102 | ] 103 | # Exclude a variety of commonly ignored directories. 104 | exclude = [ 105 | ".eggs", 106 | ".git", 107 | ".mypy_cache", 108 | ".ruff_cache", 109 | "__pypackages__", 110 | "_build", 111 | "build", 112 | "dist", 113 | "docs" 114 | ] 115 | 116 | [tool.ruff.lint.per-file-ignores] 117 | "setup.py" = ["D100", "SIM115"] 118 | "__about__.py" = ["D100"] 119 | "__init__.py" = ["D100"] 120 | 121 | [tool.ruff.lint.pydocstyle] 122 | # Use Google-style docstrings. 123 | convention = "google" 124 | 125 | #[tool.ruff.pycodestyle] 126 | #ignore-overlong-task-comments = true 127 | 128 | [tool.ruff.lint.mccabe] 129 | # Unlike Flake8, default to a complexity level of 10. 130 | max-complexity = 10 131 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | python_files = test_*.py 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi >=0.100 2 | uvicorn[standard] >=0.29.0 3 | pyzmq >=22.0.0 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright The Lightning AI team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import glob 16 | import os 17 | from importlib.util import module_from_spec, spec_from_file_location 18 | from pathlib import Path 19 | 20 | from pkg_resources import parse_requirements 21 | from setuptools import find_packages, setup 22 | 23 | _PATH_ROOT = os.path.dirname(__file__) 24 | _PATH_SOURCE = os.path.join(_PATH_ROOT, "src") 25 | _PATH_REQUIRES = os.path.join(_PATH_ROOT, "_requirements") 26 | 27 | 28 | def _load_py_module(fname, pkg="litserve"): 29 | spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_SOURCE, pkg, fname)) 30 | py = module_from_spec(spec) 31 | spec.loader.exec_module(py) 32 | return py 33 | 34 | 35 | def _load_requirements(path_dir: str = _PATH_ROOT, file_name: str = "requirements.txt") -> list: 36 | reqs = parse_requirements(open(os.path.join(path_dir, file_name)).readlines()) 37 | return list(map(str, reqs)) 38 | 39 | 40 | about = _load_py_module("__about__.py") 41 | with open(os.path.join(_PATH_ROOT, "README.md"), encoding="utf-8") as fopen: 42 | readme = fopen.read() 43 | 44 | 45 | def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple = ("devel.txt", "docs.txt")) -> dict: 46 | # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras 47 | # Define package extras. These are only installed if you specify them. 48 | # From remote, use like `pip install pytorch-lightning[dev, docs]` 49 | # From local copy of repo, use like `pip install ".[dev, docs]"` 50 | req_files = [Path(p) for p in glob.glob(os.path.join(requirements_dir, "*.txt"))] 51 | extras = { 52 | p.stem: _load_requirements(file_name=p.name, path_dir=str(p.parent)) 53 | for p in req_files 54 | if p.name not in skip_files 55 | } 56 | # todo: eventually add some custom aggregations such as `develop` 57 | extras = {name: sorted(set(reqs)) for name, reqs in extras.items()} 58 | print("The extras are: ", extras) 59 | return extras 60 | 61 | 62 | # https://packaging.python.org/discussions/install-requires-vs-requirements / 63 | # keep the meta-data here for simplicity in reading this file... it's not obvious 64 | # what happens and to non-engineers they won't know to look in init ... 65 | # the goal of the project is simplicity for researchers, don't want to add too much 66 | # engineer specific practices 67 | setup( 68 | name="litserve", 69 | version=about.__version__, 70 | description=about.__docs__, 71 | author=about.__author__, 72 | author_email=about.__author_email__, 73 | url=about.__homepage__, 74 | download_url="https://github.com/Lightning-AI/litserve", 75 | license=about.__license__, 76 | packages=find_packages(where="src"), 77 | package_dir={"": "src"}, 78 | long_description=readme, 79 | long_description_content_type="text/markdown", 80 | include_package_data=True, 81 | zip_safe=False, 82 | keywords=["deep learning", "pytorch", "AI"], 83 | python_requires=">=3.8", 84 | setup_requires=["wheel"], 85 | install_requires=_load_requirements(), 86 | extras_require=_prepare_extras(), 87 | project_urls={ 88 | "Bug Tracker": "https://github.com/Lightning-AI/litserve/issues", 89 | "Documentation": "https://lightning-ai.github.io/litserve/", 90 | "Source Code": "https://github.com/Lightning-AI/litserve", 91 | }, 92 | classifiers=[ 93 | "Environment :: Console", 94 | "Natural Language :: English", 95 | # How mature is this project? 
Common values are 96 | # 3 - Alpha, 4 - Beta, 5 - Production/Stable 97 | "Development Status :: 3 - Alpha", 98 | # Indicate who your project is intended for 99 | "Intended Audience :: Developers", 100 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 101 | "Topic :: Scientific/Engineering :: Information Analysis", 102 | # Pick your license as you wish 103 | "License :: OSI Approved :: Apache Software License", 104 | "Operating System :: OS Independent", 105 | # Specify the Python versions you support here. In particular, ensure 106 | # that you indicate whether you support Python 2, Python 3 or both. 107 | "Programming Language :: Python :: 3", 108 | "Programming Language :: Python :: 3.8", 109 | "Programming Language :: Python :: 3.9", 110 | "Programming Language :: Python :: 3.10", 111 | "Programming Language :: Python :: 3.11", 112 | ], 113 | entry_points={ 114 | "console_scripts": ["litserve=litserve.__main__:main", "lightning=litserve.cli:main"], 115 | }, 116 | ) 117 | -------------------------------------------------------------------------------- /src/litserve/__about__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | __version__ = "0.2.11" 15 | __author__ = "Lightning-AI et al." 16 | __author_email__ = "community@lightning.ai" 17 | __license__ = "Apache-2.0" 18 | __copyright__ = f"Copyright (c) 2024, {__author__}." 19 | __homepage__ = "https://github.com/Lightning-AI/litserve" 20 | __docs__ = "Lightweight AI server." 21 | 22 | __all__ = [ 23 | "__author__", 24 | "__author_email__", 25 | "__copyright__", 26 | "__docs__", 27 | "__homepage__", 28 | "__license__", 29 | "__version__", 30 | ] 31 | -------------------------------------------------------------------------------- /src/litserve/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from litserve import test_examples 15 | from litserve.__about__ import * # noqa: F403 16 | from litserve.api import LitAPI 17 | from litserve.callbacks import Callback 18 | from litserve.loggers import Logger 19 | from litserve.server import LitServer, Request, Response 20 | from litserve.specs import OpenAIEmbeddingSpec, OpenAISpec 21 | from litserve.utils import configure_logging, set_trace, set_trace_if_debug 22 | 23 | configure_logging() 24 | 25 | __all__ = [ 26 | "LitAPI", 27 | "LitServer", 28 | "Request", 29 | "Response", 30 | "OpenAISpec", 31 | "OpenAIEmbeddingSpec", 32 | "test_examples", 33 | "Callback", 34 | "Logger", 35 | "set_trace", 36 | "set_trace_if_debug", 37 | ] 38 | -------------------------------------------------------------------------------- /src/litserve/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawTextHelpFormatter 15 | 16 | from litserve.docker_builder import dockerize 17 | 18 | 19 | class LitFormatter(ArgumentDefaultsHelpFormatter, RawTextHelpFormatter): ... 20 | 21 | 22 | def main(): 23 | parser = ArgumentParser(description="CLI for LitServe", formatter_class=LitFormatter) 24 | subparsers = parser.add_subparsers( 25 | dest="command", 26 | title="Commands", 27 | ) 28 | 29 | # dockerize sub-command 30 | dockerize_parser = subparsers.add_parser( 31 | "dockerize", 32 | help="Generate a Dockerfile for the given server code.", 33 | description="Generate a Dockerfile for the given server code.\nExample usage:" 34 | " litserve dockerize server.py --port 8000 --gpu", 35 | formatter_class=LitFormatter, 36 | ) 37 | dockerize_parser.add_argument( 38 | "server_filename", 39 | type=str, 40 | help="The path to the server file. 
Example: server.py or app.py.", 41 | ) 42 | dockerize_parser.add_argument( 43 | "--port", 44 | type=int, 45 | default=8000, 46 | help="The port to expose in the Docker container.", 47 | ) 48 | dockerize_parser.add_argument( 49 | "--gpu", 50 | default=False, 51 | action="store_true", 52 | help="Whether to use a GPU-enabled Docker image.", 53 | ) 54 | dockerize_parser.set_defaults(func=lambda args: dockerize(args.server_filename, args.port, args.gpu)) 55 | args = parser.parse_args() 56 | 57 | if hasattr(args, "func"): 58 | args.func(args) 59 | else: 60 | parser.print_help() 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /src/litserve/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from litserve.callbacks.base import Callback, CallbackRunner, EventTypes, NoopCallback 2 | 3 | __all__ = ["Callback", "CallbackRunner", "EventTypes", "NoopCallback"] 4 | -------------------------------------------------------------------------------- /src/litserve/callbacks/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC 3 | from enum import Enum 4 | from typing import List, Union 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class EventTypes(Enum): 10 | BEFORE_SETUP = "on_before_setup" 11 | AFTER_SETUP = "on_after_setup" 12 | BEFORE_DECODE_REQUEST = "on_before_decode_request" 13 | AFTER_DECODE_REQUEST = "on_after_decode_request" 14 | BEFORE_ENCODE_RESPONSE = "on_before_encode_response" 15 | AFTER_ENCODE_RESPONSE = "on_after_encode_response" 16 | BEFORE_PREDICT = "on_before_predict" 17 | AFTER_PREDICT = "on_after_predict" 18 | ON_SERVER_START = "on_server_start" 19 | ON_SERVER_END = "on_server_end" 20 | ON_REQUEST = "on_request" 21 | ON_RESPONSE = "on_response" 22 | 23 | 24 | class Callback(ABC): 25 | def on_before_setup(self, *args, **kwargs): 26 | """Called before setup is started.""" 27 | 28 | def on_after_setup(self, *args, **kwargs): 29 | """Called after setup is completed.""" 30 | 31 | def on_before_decode_request(self, *args, **kwargs): 32 | """Called before request decoding is started.""" 33 | 34 | def on_after_decode_request(self, *args, **kwargs): 35 | """Called after request decoding is completed.""" 36 | 37 | def on_before_encode_response(self, *args, **kwargs): 38 | """Called before response encoding is started.""" 39 | 40 | def on_after_encode_response(self, *args, **kwargs): 41 | """Called after response encoding is completed.""" 42 | 43 | def on_before_predict(self, *args, **kwargs): 44 | """Called before prediction is started.""" 45 | 46 | def on_after_predict(self, *args, **kwargs): 47 | """Called after prediction is completed.""" 48 | 49 | def on_server_start(self, *args, **kwargs): 50 | """Called before server starts.""" 51 | 52 | def on_server_end(self, *args, **kwargs): 53 | """Called when server terminates.""" 54 | 55 | def on_request(self, *args, **kwargs): 56 | """Called when request enters the endpoint function.""" 57 | 58 | def on_response(self, *args, **kwargs): 59 | """Called when response is generated from the worker and ready to return to the client.""" 60 | 61 | # Adding a new hook? 
Register it with the EventTypes dataclass too, 62 | 63 | 64 | class CallbackRunner: 65 | def __init__(self, callbacks: Union[Callback, List[Callback]] = None): 66 | self._callbacks = [] 67 | if callbacks: 68 | self._add_callbacks(callbacks) 69 | 70 | def _add_callbacks(self, callbacks: Union[Callback, List[Callback]]): 71 | if not isinstance(callbacks, list): 72 | callbacks = [callbacks] 73 | for callback in callbacks: 74 | if not isinstance(callback, Callback): 75 | raise ValueError(f"Invalid callback type: {callback}") 76 | self._callbacks.extend(callbacks) 77 | 78 | def trigger_event(self, event_name, *args, **kwargs): 79 | """Triggers an event, invoking all registered callbacks for that event.""" 80 | for callback in self._callbacks: 81 | try: 82 | getattr(callback, event_name)(*args, **kwargs) 83 | except Exception: 84 | # Handle exceptions to prevent one callback from disrupting others 85 | logger.exception(f"Error in callback '{callback}' during event '{event_name}'") 86 | 87 | 88 | class NoopCallback(Callback): 89 | """This callback does nothing.""" 90 | -------------------------------------------------------------------------------- /src/litserve/callbacks/defaults/__init__.py: -------------------------------------------------------------------------------- 1 | from litserve.callbacks.defaults.metric_callback import PredictionTimeLogger 2 | 3 | __all__ = ["PredictionTimeLogger"] 4 | -------------------------------------------------------------------------------- /src/litserve/callbacks/defaults/metric_callback.py: -------------------------------------------------------------------------------- 1 | import time 2 | import typing 3 | 4 | from litserve.callbacks.base import Callback 5 | 6 | if typing.TYPE_CHECKING: 7 | from litserve import LitAPI 8 | 9 | 10 | class PredictionTimeLogger(Callback): 11 | def on_before_predict(self, lit_api: "LitAPI"): 12 | self._start_time = time.perf_counter() 13 | 14 | def on_after_predict(self, lit_api: "LitAPI"): 15 | elapsed = time.perf_counter() - self._start_time 16 | print(f"Prediction took {elapsed:.2f} seconds", flush=True) 17 | 18 | 19 | class RequestTracker(Callback): 20 | def on_request(self, active_requests: int, **kwargs): 21 | print(f"Active requests: {active_requests}", flush=True) 22 | -------------------------------------------------------------------------------- /src/litserve/cli.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import subprocess 3 | import sys 4 | 5 | 6 | def _ensure_lightning_installed(): 7 | if not importlib.util.find_spec("lightning_sdk"): 8 | print("Lightning CLI not found. 
Installing...") 9 | subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "lightning-sdk"]) 10 | 11 | 12 | def main(): 13 | _ensure_lightning_installed() 14 | 15 | try: 16 | # Import the correct entry point for lightning_sdk 17 | from lightning_sdk.cli.entrypoint import main_cli 18 | 19 | # Call the lightning CLI's main function directly with our arguments 20 | # This bypasses the command-line entry point completely 21 | sys.argv[0] = "lightning" # Make it think it was called as "lightning" 22 | main_cli() 23 | except ImportError as e: 24 | # If there's an issue importing or finding the right module 25 | print(f"Error importing lightning_sdk CLI: {e}") 26 | print("Please ensure `lightning-sdk` is installed correctly.") 27 | sys.exit(1) 28 | -------------------------------------------------------------------------------- /src/litserve/connector.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import platform 16 | import subprocess 17 | import sys 18 | from functools import lru_cache 19 | from typing import List, Optional, Union 20 | 21 | 22 | class _Connector: 23 | def __init__(self, accelerator: str = "auto", devices: Union[List[int], int, str] = "auto"): 24 | accelerator = self._sanitize_accelerator(accelerator) 25 | if accelerator in ("cpu", "cuda", "mps"): 26 | self._accelerator = accelerator 27 | elif accelerator == "auto": 28 | self._accelerator = self._choose_auto_accelerator() 29 | elif accelerator == "gpu": 30 | self._accelerator = self._choose_gpu_accelerator_backend() 31 | 32 | if devices == "auto": 33 | self._devices = self._accelerator_device_count() 34 | else: 35 | self._devices = devices 36 | 37 | self.check_devices_and_accelerators() 38 | 39 | def check_devices_and_accelerators(self): 40 | """Check if the devices are in a valid format and raise an error if they are not.""" 41 | if self._accelerator in ("cuda", "mps"): 42 | if not isinstance(self._devices, int) and not ( 43 | isinstance(self._devices, list) and all(isinstance(device, int) for device in self._devices) 44 | ): 45 | raise ValueError( 46 | "devices must be an integer or a list of integers when using 'cuda' or 'mps', " 47 | f"instead got {self._devices}" 48 | ) 49 | elif self._accelerator != "cpu": 50 | raise ValueError(f"accelerator must be one of (cuda, mps, cpu), instead got {self._accelerator}") 51 | 52 | @property 53 | def accelerator(self): 54 | return self._accelerator 55 | 56 | @property 57 | def devices(self): 58 | return self._devices 59 | 60 | @staticmethod 61 | def _sanitize_accelerator(accelerator: Optional[str]): 62 | if isinstance(accelerator, str): 63 | accelerator = accelerator.lower() 64 | 65 | if accelerator not in ["auto", "cpu", "mps", "cuda", "gpu", None]: 66 | raise ValueError(f"accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'. 
Found: {accelerator}") 67 | 68 | if accelerator is None: 69 | return "auto" 70 | return accelerator 71 | 72 | def _choose_auto_accelerator(self): 73 | gpu_backend = self._choose_gpu_accelerator_backend() 74 | if "torch" in sys.modules and gpu_backend: 75 | return gpu_backend 76 | return "cpu" 77 | 78 | def _accelerator_device_count(self) -> int: 79 | if self._accelerator == "cuda": 80 | return check_cuda_with_nvidia_smi() 81 | return 1 82 | 83 | @staticmethod 84 | def _choose_gpu_accelerator_backend(): 85 | if check_cuda_with_nvidia_smi() > 0: 86 | return "cuda" 87 | 88 | try: 89 | import torch 90 | 91 | if torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"): 92 | return "mps" 93 | except ImportError: 94 | return None 95 | 96 | return None 97 | 98 | 99 | @lru_cache(maxsize=1) 100 | def check_cuda_with_nvidia_smi() -> int: 101 | """Checks if CUDA is installed using the `nvidia-smi` command-line tool. 102 | 103 | Returns count of visible devices. 104 | 105 | """ 106 | try: 107 | nvidia_smi_output = subprocess.check_output(["nvidia-smi", "-L"]).decode("utf-8").strip() 108 | devices = [el for el in nvidia_smi_output.split("\n") if el.startswith("GPU")] 109 | devices = [el.split(":")[0].split()[1] for el in devices] 110 | visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES") 111 | if visible_devices: 112 | # we need check the intersection of devices and visible devices, since 113 | # using CUDA_VISIBLE_DEVICES=0,25 on a 4-GPU machine will yield 114 | # torch.cuda.device_count() == 1 115 | devices = [el for el in devices if el in visible_devices.split(",")] 116 | return len(devices) 117 | except (subprocess.CalledProcessError, FileNotFoundError): 118 | return 0 119 | -------------------------------------------------------------------------------- /src/litserve/docker_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import logging 15 | import os 16 | import warnings 17 | from pathlib import Path 18 | 19 | import litserve as ls 20 | 21 | logger = logging.getLogger(__name__) 22 | logger.setLevel(logging.INFO) 23 | logger.propagate = False 24 | console_handler = logging.StreamHandler() 25 | console_handler.setLevel(logging.INFO) 26 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") 27 | console_handler.setFormatter(formatter) 28 | logger.addHandler(console_handler) 29 | 30 | # COLOR CODES 31 | RESET = "\u001b[0m" 32 | RED = "\u001b[31m" 33 | GREEN = "\u001b[32m" 34 | BLUE = "\u001b[34m" 35 | MAGENTA = "\u001b[35m" 36 | BG_MAGENTA = "\u001b[45m" 37 | 38 | # ACTION CODES 39 | BOLD = "\u001b[1m" 40 | UNDERLINE = "\u001b[4m" 41 | INFO = f"{BOLD}{BLUE}[INFO]" 42 | WARNING = f"{BOLD}{RED}[WARNING]" 43 | 44 | 45 | def color(text, color_code, action_code=None): 46 | if action_code: 47 | return f"{action_code} {color_code}{text}{RESET}" 48 | return f"{color_code}{text}{RESET}" 49 | 50 | 51 | REQUIREMENTS_FILE = "requirements.txt" 52 | DOCKERFILE_TEMPLATE = """ARG PYTHON_VERSION=3.12 53 | FROM python:$PYTHON_VERSION-slim 54 | 55 | ####### Add your own installation commands here ####### 56 | # RUN pip install some-package 57 | # RUN wget https://path/to/some/data/or/weights 58 | # RUN apt-get update && apt-get install -y 59 | 60 | WORKDIR /app 61 | COPY . /app 62 | 63 | # Install litserve and requirements 64 | RUN pip install --no-cache-dir litserve=={version} {requirements} 65 | EXPOSE {port} 66 | CMD ["python", "/app/{server_filename}"] 67 | """ 68 | 69 | CUDA_DOCKER_TEMPLATE = """# Change CUDA and cuDNN version here 70 | FROM nvidia/cuda:12.4.1-base-ubuntu22.04 71 | ARG PYTHON_VERSION=3.12 72 | 73 | ENV DEBIAN_FRONTEND=noninteractive 74 | RUN apt-get update && apt-get install -y --no-install-recommends \\ 75 | software-properties-common \\ 76 | wget \\ 77 | && add-apt-repository ppa:deadsnakes/ppa \\ 78 | && apt-get update && apt-get install -y --no-install-recommends \\ 79 | python$PYTHON_VERSION \\ 80 | python$PYTHON_VERSION-dev \\ 81 | python$PYTHON_VERSION-venv \\ 82 | && wget https://bootstrap.pypa.io/get-pip.py -O get-pip.py \\ 83 | && python$PYTHON_VERSION get-pip.py \\ 84 | && rm get-pip.py \\ 85 | && ln -sf /usr/bin/python$PYTHON_VERSION /usr/bin/python \\ 86 | && ln -sf /usr/local/bin/pip$PYTHON_VERSION /usr/local/bin/pip \\ 87 | && python --version \\ 88 | && pip --version \\ 89 | && apt-get purge -y --auto-remove software-properties-common \\ 90 | && apt-get clean \\ 91 | && rm -rf /var/lib/apt/lists/* 92 | 93 | ####### Add your own installation commands here ####### 94 | # RUN pip install some-package 95 | # RUN wget https://path/to/some/data/or/weights 96 | # RUN apt-get update && apt-get install -y 97 | 98 | WORKDIR /app 99 | COPY . 
/app 100 | 101 | # Install litserve and requirements 102 | RUN pip install --no-cache-dir litserve=={version} {requirements} 103 | EXPOSE {port} 104 | CMD ["python", "/app/{server_filename}"] 105 | """ 106 | 107 | # Link our documentation at the bottom of this msg 108 | SUCCESS_MSG = """{BOLD}{MAGENTA}Dockerfile created successfully{RESET} 109 | Update {UNDERLINE}{dockerfile_path}{RESET} to add any additional dependencies or commands.{RESET} 110 | 111 | {BOLD}Build the container with:{RESET} 112 | > {UNDERLINE}docker build -t litserve-model .{RESET} 113 | 114 | {BOLD}To run the Docker container on the machine:{RESET} 115 | > {UNDERLINE}{RUN_CMD}{RESET} 116 | 117 | {BOLD}To push the container to a registry:{RESET} 118 | > {UNDERLINE}docker push litserve-model{RESET} 119 | """ 120 | 121 | 122 | def dockerize(server_filename: str, port: int = 8000, gpu: bool = False): 123 | """Generate a Dockerfile for the given server code. 124 | 125 | Example usage: 126 | litserve dockerize server.py --port 8000 --gpu 127 | 128 | Args: 129 | server_filename (str): The path to the server file. Example: server.py or app.py. 130 | port (int, optional): The port to expose in the Docker container. 131 | gpu (bool, optional): Whether to use a GPU-enabled Docker image. 132 | 133 | """ 134 | requirements = "" 135 | if os.path.exists(REQUIREMENTS_FILE): 136 | requirements = f"-r {REQUIREMENTS_FILE}" 137 | else: 138 | warnings.warn( 139 | f"requirements.txt not found at {os.getcwd()}. " 140 | f"Make sure to install the required packages in the Dockerfile.", 141 | UserWarning, 142 | ) 143 | 144 | current_dir = Path.cwd() 145 | if not (current_dir / server_filename).is_file(): 146 | raise FileNotFoundError(f"Server file `{server_filename}` must be in the current directory: {os.getcwd()}") 147 | 148 | version = ls.__version__ 149 | if gpu: 150 | run_cmd = f"docker run --gpus all -p {port}:{port} litserve-model:latest" 151 | docker_template = CUDA_DOCKER_TEMPLATE 152 | else: 153 | run_cmd = f"docker run -p {port}:{port} litserve-model:latest" 154 | docker_template = DOCKERFILE_TEMPLATE 155 | dockerfile_content = docker_template.format( 156 | server_filename=server_filename, 157 | port=port, 158 | version=version, 159 | requirements=requirements, 160 | ) 161 | with open("Dockerfile", "w") as f: 162 | f.write(dockerfile_content) 163 | success_msg = SUCCESS_MSG.format( 164 | dockerfile_path=os.path.abspath("Dockerfile"), 165 | RUN_CMD=run_cmd, 166 | BOLD=BOLD, 167 | MAGENTA=MAGENTA, 168 | GREEN=GREEN, 169 | BLUE=BLUE, 170 | UNDERLINE=UNDERLINE, 171 | BG_MAGENTA=BG_MAGENTA, 172 | RESET=RESET, 173 | ) 174 | print(success_msg) 175 | -------------------------------------------------------------------------------- /src/litserve/loggers.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import functools 15 | import logging 16 | import multiprocessing as mp 17 | import pickle 18 | from abc import ABC, abstractmethod 19 | from typing import TYPE_CHECKING, List, Optional, Union 20 | 21 | from starlette.types import ASGIApp 22 | 23 | module_logger = logging.getLogger(__name__) 24 | 25 | if TYPE_CHECKING: # pragma: no cover 26 | from litserve import LitServer 27 | 28 | 29 | class Logger(ABC): 30 | def __init__(self): 31 | self._config = {} 32 | 33 | def mount(self, path: str, app: ASGIApp) -> None: 34 | """Mount an ASGI app endpoint to LitServer. Use this method when you want to add an additional endpoint to the 35 | server such as /metrics endpoint for prometheus metrics. 36 | 37 | Args: 38 | path (str): The path to mount the app to. 39 | app (ASGIApp): The ASGI app to mount. 40 | 41 | """ 42 | self._config.update({"mount": {"path": path, "app": app}}) 43 | 44 | @abstractmethod 45 | def process(self, key, value): 46 | """Process a log entry from the log queue. 47 | 48 | This method should be implemented to define the specific logic for processing 49 | log entries. 50 | 51 | Args: 52 | key (str): The key associated with the log entry, typically indicating the type or category of the log. 53 | value (Any): The value associated with the log entry, containing the actual log data. 54 | 55 | Raises: 56 | NotImplementedError: This method must be overridden by subclasses. If not, calling this method will raise 57 | a NotImplementedError. 58 | 59 | Example: 60 | Here is an example of a Logger that logs monitoring metrics using Prometheus: 61 | 62 | from prometheus_client import Counter 63 | 64 | class PrometheusLogger(Logger): 65 | def __init__(self): 66 | super().__init__() 67 | self._metric_counter = Counter('log_entries', 'Count of log entries') 68 | 69 | def process(self, key, value): 70 | # Increment the Prometheus counter for each log entry 71 | self._metric_counter.inc() 72 | print(f"Logged {key}: {value}") 73 | 74 | """ 75 | raise NotImplementedError # pragma: no cover 76 | 77 | 78 | class _LoggerProxy: 79 | def __init__(self, logger_class): 80 | self.logger_class = logger_class 81 | 82 | def create_logger(self): 83 | return self.logger_class() 84 | 85 | 86 | class _LoggerConnector: 87 | """_LoggerConnector is responsible for connecting Logger instances with the LitServer and managing their lifecycle. 88 | 89 | This class handles the following tasks: 90 | - Manages a queue (multiprocessing.Queue) where log data is placed using the LitAPI.log method. 91 | - Initiates a separate process to consume the log queue and process the log data using the associated 92 | Logger instances. 
93 | 94 | """ 95 | 96 | def __init__(self, lit_server: "LitServer", loggers: Optional[Union[List[Logger], Logger]] = None): 97 | self._loggers = [] 98 | self._lit_server = lit_server 99 | if loggers is None: 100 | return # No loggers to add 101 | if isinstance(loggers, list): 102 | for logger in loggers: 103 | if not isinstance(logger, Logger): 104 | raise ValueError("Logger must be an instance of litserve.Logger") 105 | self.add_logger(logger) 106 | elif isinstance(loggers, Logger): 107 | self.add_logger(loggers) 108 | else: 109 | raise ValueError("loggers must be a list or an instance of litserve.Logger") 110 | 111 | def _mount(self, path: str, app: ASGIApp) -> None: 112 | self._lit_server.app.mount(path, app) 113 | 114 | def add_logger(self, logger: Logger): 115 | self._loggers.append(logger) 116 | if "mount" in logger._config: 117 | self._mount(logger._config["mount"]["path"], logger._config["mount"]["app"]) 118 | 119 | @staticmethod 120 | def _is_picklable(obj): 121 | try: 122 | pickle.dumps(obj) 123 | return True 124 | except (pickle.PicklingError, TypeError, AttributeError): 125 | module_logger.warning(f"Logger {obj.__class__.__name__} is not pickleable and might not work properly.") 126 | return False 127 | 128 | @staticmethod 129 | def _process_logger_queue(logger_proxies: List[_LoggerProxy], queue): 130 | loggers = [proxy if isinstance(proxy, Logger) else proxy.create_logger() for proxy in logger_proxies] 131 | while True: 132 | key, value = queue.get() 133 | for logger in loggers: 134 | try: 135 | logger.process(key, value) 136 | except Exception as e: 137 | module_logger.error( 138 | f"{logger.__class__.__name__} ran into an error while processing log for entry " 139 | f"with key {key} and value {value}: {e}" 140 | ) 141 | 142 | @functools.cache # Run once per LitServer instance 143 | def run(self, lit_server: "LitServer"): 144 | queue = lit_server.logger_queue 145 | lit_server.litapi_connector.set_logger_queue(queue) 146 | 147 | # Disconnect the logger connector from the LitServer to avoid pickling issues 148 | self._lit_server = None 149 | 150 | if not self._loggers: 151 | return 152 | 153 | # Create proxies for loggers 154 | logger_proxies = [] 155 | for logger in self._loggers: 156 | if self._is_picklable(logger): 157 | logger_proxies.append(logger) 158 | else: 159 | logger_proxies.append(_LoggerProxy(logger.__class__)) 160 | 161 | module_logger.debug(f"Starting logger process with {len(logger_proxies)} loggers") 162 | ctx = mp.get_context("spawn") 163 | process = ctx.Process( 164 | target=_LoggerConnector._process_logger_queue, 165 | args=( 166 | logger_proxies, 167 | queue, 168 | ), 169 | ) 170 | process.start() 171 | -------------------------------------------------------------------------------- /src/litserve/loops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
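As a usage sketch for the logger machinery above (not part of the sources): a custom `Logger` receives the `(key, value)` pairs that a `LitAPI` emits via `LitAPI.log`, and they are consumed in the separate process started by `_LoggerConnector`. This assumes `LitServer` accepts a `loggers=` argument that it forwards to the connector, as the connector signature suggests.

```python
import time

import litserve as ls
from litserve.loggers import Logger


class StdoutLogger(Logger):
    def process(self, key, value):
        # Runs in the logger process spawned by _LoggerConnector.run.
        print(f"[metric] {key}={value}")


class TimedAPI(ls.test_examples.SimpleLitAPI):
    def predict(self, x):
        start = time.perf_counter()
        result = super().predict(x)
        # LitAPI.log places the entry on the logger queue (see the connector docstring above).
        self.log("predict_time_s", time.perf_counter() - start)
        return result


if __name__ == "__main__":
    server = ls.LitServer(TimedAPI(), loggers=StdoutLogger())
    server.run(port=8000)
```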
14 | import multiprocessing as mp 15 | 16 | from litserve.loops.base import LitLoop, _BaseLoop 17 | from litserve.loops.continuous_batching_loop import ContinuousBatchingLoop, Output 18 | from litserve.loops.loops import ( 19 | get_default_loop, 20 | inference_worker, 21 | ) 22 | from litserve.loops.simple_loops import BatchedLoop, SingleLoop 23 | from litserve.loops.streaming_loops import BatchedStreamingLoop, StreamingLoop 24 | 25 | mp.allow_connection_pickling() 26 | 27 | __all__ = [ 28 | "_BaseLoop", 29 | "LitLoop", 30 | "get_default_loop", 31 | "inference_worker", 32 | "Output", 33 | "SingleLoop", 34 | "BatchedLoop", 35 | "StreamingLoop", 36 | "BatchedStreamingLoop", 37 | "ContinuousBatchingLoop", 38 | ] 39 | -------------------------------------------------------------------------------- /src/litserve/loops/loops.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import logging 15 | from queue import Queue 16 | from typing import Dict 17 | 18 | from litserve import LitAPI 19 | from litserve.callbacks import CallbackRunner, EventTypes 20 | from litserve.loops.base import LitLoop, _BaseLoop 21 | from litserve.loops.simple_loops import BatchedLoop, SingleLoop 22 | from litserve.loops.streaming_loops import BatchedStreamingLoop, StreamingLoop 23 | from litserve.transport.base import MessageTransport 24 | from litserve.utils import WorkerSetupStatus 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | def get_default_loop(stream: bool, max_batch_size: int, enable_async: bool = False) -> _BaseLoop: 30 | """Get the default loop based on the stream flag, batch size, and async support. 31 | 32 | Args: 33 | stream: Whether streaming is enabled 34 | max_batch_size: Maximum batch size 35 | enable_async: Whether async support is enabled (supports both coroutines and async generators) 36 | 37 | Returns: 38 | The appropriate loop implementation 39 | 40 | Raises: 41 | ValueError: If async and batching are enabled together (not supported) 42 | 43 | """ 44 | if enable_async: 45 | if max_batch_size > 1: 46 | raise ValueError("Async batching is not supported. 
Please use enable_async=False with batching.") 47 | if stream: 48 | return StreamingLoop() # StreamingLoop now supports async 49 | return SingleLoop() # Only SingleLoop supports async currently 50 | 51 | if stream: 52 | if max_batch_size > 1: 53 | return BatchedStreamingLoop() 54 | return StreamingLoop() 55 | 56 | if max_batch_size > 1: 57 | return BatchedLoop() 58 | return SingleLoop() 59 | 60 | 61 | def inference_worker( 62 | lit_api: LitAPI, 63 | device: str, 64 | worker_id: int, 65 | request_queue: Queue, 66 | transport: MessageTransport, 67 | workers_setup_status: Dict[int, str], 68 | callback_runner: CallbackRunner, 69 | ): 70 | print("workers_setup_status", workers_setup_status) 71 | lit_spec = lit_api.spec 72 | loop: LitLoop = lit_api.loop 73 | stream = lit_api.stream 74 | 75 | endpoint = lit_api.api_path.split("/")[-1] 76 | 77 | callback_runner.trigger_event(EventTypes.BEFORE_SETUP.value, lit_api=lit_api) 78 | try: 79 | lit_api.setup(device) 80 | except Exception: 81 | logger.exception(f"Error setting up worker {worker_id}.") 82 | workers_setup_status[f"{endpoint}_{worker_id}"] = WorkerSetupStatus.ERROR 83 | return 84 | lit_api.device = device 85 | callback_runner.trigger_event(EventTypes.AFTER_SETUP.value, lit_api=lit_api) 86 | 87 | print(f"Setup complete for worker {f'{endpoint}_{worker_id}'}.") 88 | 89 | if workers_setup_status: 90 | workers_setup_status[f"{endpoint}_{worker_id}"] = WorkerSetupStatus.READY 91 | 92 | if lit_spec: 93 | logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec") 94 | 95 | if loop == "auto": 96 | loop = get_default_loop(stream, lit_api.max_batch_size, lit_api.enable_async) 97 | 98 | loop( 99 | lit_api, 100 | device, 101 | worker_id, 102 | request_queue, 103 | transport, 104 | workers_setup_status, 105 | callback_runner, 106 | ) 107 | -------------------------------------------------------------------------------- /src/litserve/middlewares.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
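The dispatch performed by `get_default_loop` above can be summarized with a quick sketch; everything used here is exported from `litserve.loops`:

```python
from litserve.loops import BatchedLoop, BatchedStreamingLoop, SingleLoop, StreamingLoop, get_default_loop

assert isinstance(get_default_loop(stream=False, max_batch_size=1), SingleLoop)
assert isinstance(get_default_loop(stream=False, max_batch_size=8), BatchedLoop)
assert isinstance(get_default_loop(stream=True, max_batch_size=1), StreamingLoop)
assert isinstance(get_default_loop(stream=True, max_batch_size=8), BatchedStreamingLoop)

# Async LitAPIs reuse the single/streaming loops; combining async with batching raises.
assert isinstance(get_default_loop(stream=True, max_batch_size=1, enable_async=True), StreamingLoop)
try:
    get_default_loop(stream=False, max_batch_size=8, enable_async=True)
except ValueError as err:
    print(err)  # "Async batching is not supported. ..."
```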
14 | import logging 15 | import multiprocessing 16 | from typing import Optional 17 | 18 | from fastapi import HTTPException 19 | from starlette.middleware.base import BaseHTTPMiddleware 20 | from starlette.types import ASGIApp, Message, Receive, Scope, Send 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class MaxSizeMiddleware(BaseHTTPMiddleware): 26 | """Rejects requests with a payload that is too large.""" 27 | 28 | def __init__( 29 | self, 30 | app: ASGIApp, 31 | *, 32 | max_size: Optional[int] = None, 33 | ) -> None: 34 | self.app = app 35 | self.max_size = max_size 36 | 37 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 38 | if scope["type"] != "http": 39 | await self.app(scope, receive, send) 40 | return 41 | 42 | total_size = 0 43 | 44 | async def rcv() -> Message: 45 | nonlocal total_size 46 | message = await receive() 47 | chunk_size = len(message.get("body", b"")) 48 | total_size += chunk_size 49 | if self.max_size is not None and total_size > self.max_size: 50 | raise HTTPException(413, "Payload too large") 51 | return message 52 | 53 | await self.app(scope, rcv, send) 54 | 55 | 56 | class RequestCountMiddleware(BaseHTTPMiddleware): 57 | """Adds a header to the response with the number of active requests.""" 58 | 59 | def __init__(self, app: ASGIApp, active_counter: multiprocessing.Value) -> None: 60 | self.app = app 61 | self.active_counter = active_counter 62 | 63 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 64 | if scope["type"] != "http" or (scope["type"] == "http" and scope["path"] in ["/", "/health", "/metrics"]): 65 | await self.app(scope, receive, send) 66 | return 67 | 68 | self.active_counter.value += 1 69 | await self.app(scope, receive, send) 70 | self.active_counter.value -= 1 71 | -------------------------------------------------------------------------------- /src/litserve/python_client.py: -------------------------------------------------------------------------------- 1 | client_template = """# This file is auto-generated by LitServe. 2 | # Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`. 
3 | 4 | import requests 5 | 6 | response = requests.post("http://127.0.0.1:{PORT}/predict", json={{"input": 4.0}}) 7 | print(f"Status: {{response.status_code}}\\nResponse:\\n {{response.text}}") 8 | """ 9 | -------------------------------------------------------------------------------- /src/litserve/schema/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-AI/LitServe/ea6aca44640d4d5691ba44fb27e76796d83482de/src/litserve/schema/__init__.py -------------------------------------------------------------------------------- /src/litserve/schema/image.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from io import BytesIO 3 | from typing import TYPE_CHECKING, Any, Optional 4 | 5 | from pydantic import BaseModel, field_serializer, model_validator 6 | 7 | if TYPE_CHECKING: 8 | from PIL import Image 9 | 10 | 11 | class ImageInput(BaseModel): 12 | image_data: Optional[str] = None 13 | 14 | @model_validator(mode="after") 15 | def validate_base64(self) -> "ImageInput": 16 | """Ensure the string is a valid Base64.""" 17 | model_dump = self.model_dump() 18 | for key, value in model_dump.items(): 19 | if value: 20 | try: 21 | base64.b64decode(value) 22 | except base64.binascii.Error: 23 | raise ValueError("Invalid Base64 string.") 24 | return self 25 | 26 | def get_image(self, key: Optional[str] = None) -> "Image.Image": 27 | """Decode the Base64 string and return a PIL Image object.""" 28 | if key is None: 29 | key = "image_data" 30 | image_data = self.model_dump().get(key) 31 | if not image_data: 32 | raise ValueError(f"Missing image data for key '{key}'") 33 | try: 34 | from PIL import Image, UnidentifiedImageError 35 | except ImportError: 36 | raise ImportError("Pillow is required to use the ImageInput schema. Install it with `pip install Pillow`.") 37 | try: 38 | decoded_data = base64.b64decode(image_data) 39 | return Image.open(BytesIO(decoded_data)) 40 | except UnidentifiedImageError as e: 41 | raise ValueError(f"Error loading image from decoded data: {e}") 42 | 43 | 44 | class ImageOutput(BaseModel): 45 | image: Any 46 | 47 | @field_serializer("image") 48 | def serialize_image(self, image: Any, _info): 49 | """ 50 | Serialize a PIL Image into a base64 string. 51 | Args: 52 | image (Any): The image object to serialize. 53 | _info: Metadata passed during serialization (not used here). 54 | 55 | Returns: 56 | str: Base64-encoded image string. 57 | """ 58 | try: 59 | from PIL import Image 60 | except ImportError: 61 | raise ImportError("Pillow is required to use the ImageOutput schema. 
Install it with `pip install Pillow`.") 62 | 63 | if not isinstance(image, Image.Image): 64 | raise TypeError(f"Expected a PIL Image, got {type(image)}") 65 | 66 | # Save the image to a BytesIO buffer 67 | buffer = BytesIO() 68 | image.save(buffer, format="PNG") # Default format is PNG 69 | buffer.seek(0) 70 | 71 | # Encode the buffer content to base64 72 | base64_bytes = base64.b64encode(buffer.read()) 73 | 74 | # Decode to string for JSON serialization 75 | return base64_bytes.decode("utf-8") 76 | -------------------------------------------------------------------------------- /src/litserve/specs/__init__.py: -------------------------------------------------------------------------------- 1 | from litserve.specs.openai import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, OpenAISpec 2 | from litserve.specs.openai_embedding import EmbeddingRequest, EmbeddingResponse, OpenAIEmbeddingSpec 3 | 4 | __all__ = [ 5 | "OpenAISpec", 6 | "OpenAIEmbeddingSpec", 7 | "EmbeddingRequest", 8 | "EmbeddingResponse", 9 | "ChatCompletionRequest", 10 | "ChatCompletionResponse", 11 | "ChatCompletionChunk", 12 | ] 13 | -------------------------------------------------------------------------------- /src/litserve/specs/base.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from abc import abstractmethod 15 | from typing import TYPE_CHECKING, Callable, List 16 | 17 | if TYPE_CHECKING: 18 | from litserve import LitAPI, LitServer 19 | 20 | 21 | class LitSpec: 22 | """Spec will have its own encode, and decode.""" 23 | 24 | def __init__(self): 25 | self._endpoints = [] 26 | self.api_path = None 27 | self._server: LitServer = None 28 | self._max_batch_size = 1 29 | self.response_buffer = None 30 | self.request_queue = None 31 | self.response_queue_id = None 32 | 33 | @property 34 | def stream(self): 35 | return False 36 | 37 | def pre_setup(self, lit_api: "LitAPI"): 38 | pass 39 | 40 | def setup(self, server: "LitServer"): 41 | """This method is called by the server to connect the spec to the server.""" 42 | self.response_buffer = server.response_buffer 43 | self.request_queue = server._get_request_queue(self.api_path) 44 | self.data_streamer = server.data_streamer 45 | 46 | def add_endpoint(self, path: str, endpoint: Callable, methods: List[str]): 47 | """Register an endpoint in the spec.""" 48 | self._endpoints.append((path, endpoint, methods)) 49 | 50 | @property 51 | def endpoints(self): 52 | return self._endpoints.copy() 53 | 54 | @abstractmethod 55 | def decode_request(self, request, meta_kwargs): 56 | """Convert the request payload to your model input.""" 57 | pass 58 | 59 | @abstractmethod 60 | def encode_response(self, output, meta_kwargs): 61 | """Convert the model output to a response payload. 62 | 63 | To enable streaming, it should yield the output. 
64 | 65 | """ 66 | pass 67 | -------------------------------------------------------------------------------- /src/litserve/test_examples/__init__.py: -------------------------------------------------------------------------------- 1 | from litserve.test_examples.openai_spec_example import ( 2 | OpenAIBatchContext, 3 | TestAPI, 4 | TestAPIWithCustomEncode, 5 | TestAPIWithStructuredOutput, 6 | TestAPIWithToolCalls, 7 | ) 8 | from litserve.test_examples.simple_example import SimpleBatchedAPI, SimpleLitAPI, SimpleStreamAPI, SimpleTorchAPI 9 | 10 | __all__ = [ 11 | "SimpleLitAPI", 12 | "SimpleBatchedAPI", 13 | "SimpleTorchAPI", 14 | "TestAPI", 15 | "TestAPIWithCustomEncode", 16 | "TestAPIWithStructuredOutput", 17 | "TestAPIWithToolCalls", 18 | "OpenAIBatchContext", 19 | "SimpleStreamAPI", 20 | ] 21 | -------------------------------------------------------------------------------- /src/litserve/test_examples/openai_embedding_spec_example.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | from litserve.api import LitAPI 6 | 7 | 8 | class TestEmbedAPI(LitAPI): 9 | def setup(self, device): 10 | self.model = None 11 | 12 | def predict(self, x) -> List[List[float]]: 13 | n = len(x) if isinstance(x, list) else 1 14 | return np.random.rand(n, 768).tolist() 15 | 16 | def encode_response(self, output) -> dict: 17 | return {"embeddings": output} 18 | 19 | 20 | class TestEmbedBatchedAPI(TestEmbedAPI): 21 | def predict(self, batch) -> List[List[List[float]]]: 22 | return [np.random.rand(len(x), 768).tolist() for x in batch] 23 | 24 | 25 | class TestEmbedAPIWithUsage(TestEmbedAPI): 26 | def encode_response(self, output) -> dict: 27 | return {"embeddings": output, "prompt_tokens": 10, "total_tokens": 10} 28 | 29 | 30 | class TestEmbedAPIWithYieldPredict(TestEmbedAPI): 31 | def predict(self, x): 32 | yield from np.random.rand(768).tolist() 33 | 34 | 35 | class TestEmbedAPIWithYieldEncodeResponse(TestEmbedAPI): 36 | def encode_response(self, output): 37 | yield {"embeddings": output} 38 | 39 | 40 | class TestEmbedAPIWithNonDictOutput(TestEmbedAPI): 41 | def encode_response(self, output): 42 | return output 43 | 44 | 45 | class TestEmbedAPIWithMissingEmbeddings(TestEmbedAPI): 46 | def encode_response(self, output): 47 | return {"output": output} 48 | -------------------------------------------------------------------------------- /src/litserve/test_examples/openai_spec_example.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
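A sketch of using the `ImageInput`/`ImageOutput` schemas above inside a `LitAPI`. It assumes Pillow is installed and that LitServe validates the annotated pydantic request model and serializes a pydantic response (behaviour suggested by `tests/test_schema.py` and `tests/test_pydantic.py`); the "model" here is a stand-in rotation.

```python
import litserve as ls
from litserve.schema.image import ImageInput, ImageOutput


class ImageEchoAPI(ls.LitAPI):
    def setup(self, device):
        self.model = lambda image: image.rotate(90)  # stand-in "model"

    def decode_request(self, request: ImageInput):
        # ImageInput validates the base64 payload and returns a PIL image.
        return request.get_image()

    def predict(self, image):
        return self.model(image)

    def encode_response(self, output) -> ImageOutput:
        # ImageOutput serializes the PIL image back to a base64-encoded PNG string.
        return ImageOutput(image=output)


if __name__ == "__main__":
    ls.LitServer(ImageEchoAPI()).run(port=8000)
```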
14 | import time 15 | 16 | from litserve.api import LitAPI 17 | from litserve.specs.openai import ChatMessage 18 | 19 | 20 | class TestAPI(LitAPI): 21 | def setup(self, device): 22 | self.model = None 23 | 24 | def predict(self, x): 25 | yield "This is a generated output" 26 | 27 | 28 | class TestAPIWithCustomEncode(TestAPI): 29 | def encode_response(self, output): 30 | yield ChatMessage(role="assistant", content="This is a custom encoded output") 31 | 32 | 33 | class TestAPIWithToolCalls(TestAPI): 34 | def encode_response(self, output): 35 | yield ChatMessage( 36 | role="assistant", 37 | content="", 38 | tool_calls=[ 39 | { 40 | "id": "call_1", 41 | "type": "function", 42 | "function": {"name": "function_1", "arguments": '{"arg_1": "arg_1_value"}'}, 43 | } 44 | ], 45 | ) 46 | 47 | 48 | class TestAPIWithStructuredOutput(TestAPI): 49 | def encode_response(self, output): 50 | yield ChatMessage( 51 | role="assistant", 52 | content='{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}', 53 | ) 54 | 55 | 56 | class OpenAIBatchContext(LitAPI): 57 | def setup(self, device: str) -> None: 58 | self.model = None 59 | 60 | def batch(self, inputs): 61 | return inputs 62 | 63 | def predict(self, inputs, context): 64 | n = len(inputs) 65 | assert isinstance(context, list) 66 | for ctx in context: 67 | ctx["temperature"] = 1.0 68 | output = [ 69 | "Hi!", 70 | "It's", 71 | "nice", 72 | "to", 73 | "meet", 74 | "you.", 75 | "Is", 76 | "there", 77 | "something", 78 | "I", 79 | "can", 80 | "help", 81 | "you", 82 | "with", 83 | "or", 84 | "would", 85 | "you", 86 | "like", 87 | "to", 88 | "chat?", 89 | ] 90 | for out in output: 91 | time.sleep(0.01) # fake delay 92 | yield [out + " "] * n 93 | 94 | def unbatch(self, output): 95 | return output 96 | 97 | def encode_response(self, output_stream, context): 98 | for outputs in output_stream: 99 | yield [{"role": "assistant", "content": output} for output in outputs] 100 | for ctx in context: 101 | assert ctx["temperature"] == 1.0, f"context {ctx} is not 1.0" 102 | 103 | 104 | class OpenAIWithUsage(LitAPI): 105 | def setup(self, device): 106 | self.model = None 107 | 108 | def predict(self, x): 109 | yield { 110 | "role": "assistant", 111 | "content": "10 + 6 is equal to 16.", 112 | "prompt_tokens": 25, 113 | "completion_tokens": 10, 114 | "total_tokens": 35, 115 | } 116 | 117 | 118 | class OpenAIWithUsageEncodeResponse(LitAPI): 119 | def setup(self, device): 120 | self.model = None 121 | 122 | def predict(self, x): 123 | # streaming tokens 124 | yield from ["10", " +", " ", "6", " is", " equal", " to", " ", "16", "."] 125 | 126 | def encode_response(self, output): 127 | for out in output: 128 | yield {"role": "assistant", "content": out} 129 | 130 | yield {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35} 131 | 132 | 133 | class OpenAIBatchingWithUsage(OpenAIWithUsage): 134 | def batch(self, inputs): 135 | return inputs 136 | 137 | def predict(self, x): 138 | n = len(x) 139 | yield ["10 + 6 is equal to 16."] * n 140 | 141 | def encode_response(self, output_stream_batch, context): 142 | n = len(context) 143 | for output_batch in output_stream_batch: 144 | yield [{"role": "assistant", "content": out} for out in output_batch] 145 | 146 | yield [ 147 | {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35} 148 | ] * n 149 | 150 | def unbatch(self, output): 151 | return output 152 | 
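For reference, the `TestAPI` example above is typically served behind `OpenAISpec` and queried like any OpenAI-compatible endpoint. A sketch, assuming the spec registers the standard `/v1/chat/completions` route and returns the usual chat-completion JSON:

```python
import litserve as ls
from litserve.test_examples.openai_spec_example import TestAPI

if __name__ == "__main__":
    server = ls.LitServer(TestAPI(), spec=ls.OpenAISpec())
    server.run(port=8000)

# Then, from another session:
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8000/v1/chat/completions",
#       json={"model": "test", "messages": [{"role": "user", "content": "Hello"}]},
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
```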
-------------------------------------------------------------------------------- /src/litserve/test_examples/simple_example.py: -------------------------------------------------------------------------------- 1 | from litserve.api import LitAPI 2 | 3 | 4 | class SimpleLitAPI(LitAPI): 5 | def setup(self, device): 6 | # Set up the model, so it can be called in `predict`. 7 | self.model = lambda x: x**2 8 | 9 | def decode_request(self, request): 10 | # Convert the request payload to your model input. 11 | return request["input"] 12 | 13 | def predict(self, x): 14 | # Run the model on the input and return the output. 15 | return self.model(x) 16 | 17 | def encode_response(self, output): 18 | # Convert the model output to a response payload. 19 | return {"output": output} 20 | 21 | 22 | class SimpleBatchedAPI(LitAPI): 23 | def setup(self, device) -> None: 24 | self.model = lambda x: x**2 25 | 26 | def decode_request(self, request): 27 | import numpy as np 28 | 29 | return np.asarray(request["input"]) 30 | 31 | def predict(self, x): 32 | return self.model(x) 33 | 34 | def encode_response(self, output): 35 | return {"output": output} 36 | 37 | 38 | class SimpleTorchAPI(LitAPI): 39 | def setup(self, device): 40 | # move the model to the correct device 41 | # keep track of the device for moving data accordingly 42 | import torch.nn as nn 43 | 44 | class Linear(nn.Module): 45 | def __init__(self): 46 | super().__init__() 47 | self.linear = nn.Linear(1, 1) 48 | self.linear.weight.data.fill_(2.0) 49 | self.linear.bias.data.fill_(1.0) 50 | 51 | def forward(self, x): 52 | return self.linear(x) 53 | 54 | self.model = Linear().to(device) 55 | 56 | def decode_request(self, request): 57 | import torch 58 | 59 | # get the input and create a 1D tensor on the correct device 60 | content = request["input"] 61 | return torch.tensor([content], device=self.device) 62 | 63 | def predict(self, x): 64 | # the model expects a batch dimension, so create it 65 | return self.model(x[None, :]) 66 | 67 | def encode_response(self, output): 68 | # float will take the output value directly onto CPU memory 69 | return {"output": float(output)} 70 | 71 | 72 | class SimpleStreamAPI(LitAPI): 73 | """ 74 | Run as: 75 | ``` 76 | server = ls.LitServer(SimpleStreamAPI(), stream=True) 77 | server.run(port=8000) 78 | ``` 79 | Then, in a new Python session, retrieve the responses as follows: 80 | ``` 81 | import requests 82 | url = "http://127.0.0.1:8000/predict" 83 | resp = requests.post(url, json={"input": "Hello world"}, headers=None, stream=True) 84 | for line in resp.iter_content(5000): 85 | if line: 86 | print(line.decode("utf-8")) 87 | ``` 88 | """ 89 | 90 | def setup(self, device) -> None: 91 | self.model = lambda x, y: f"{x}: {y}" 92 | 93 | def decode_request(self, request): 94 | return request["input"] 95 | 96 | def predict(self, x): 97 | for i in range(3): 98 | yield self.model(i, x) 99 | 100 | def encode_response(self, output_stream): 101 | for output in output_stream: 102 | yield {"output": output} 103 | -------------------------------------------------------------------------------- /src/litserve/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .process_transport import MPQueueTransport 2 | from .zmq_transport import ZMQTransport 3 | 4 | __all__ = ["ZMQTransport", "MPQueueTransport"] 5 | -------------------------------------------------------------------------------- /src/litserve/transport/base.py: 
-------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Optional 3 | 4 | 5 | class MessageTransport(ABC): 6 | @abstractmethod 7 | def send(self, item: Any, consumer_id: int) -> None: 8 | """Send a message to a consumer in the main process.""" 9 | pass 10 | 11 | @abstractmethod 12 | async def areceive(self, timeout: Optional[int] = None, consumer_id: Optional[int] = None) -> dict: 13 | """Receive a message from model workers or any publisher.""" 14 | pass 15 | 16 | def close(self) -> None: 17 | """Clean up resources if needed (e.g., sockets, processes).""" 18 | pass 19 | -------------------------------------------------------------------------------- /src/litserve/transport/factory.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Manager 2 | from typing import Literal, Optional 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | from litserve.transport.process_transport import MPQueueTransport 7 | from litserve.transport.zmq_queue import Broker 8 | from litserve.transport.zmq_transport import ZMQTransport 9 | 10 | 11 | class TransportConfig(BaseModel): 12 | transport_type: Literal["mp", "zmq"] = "mp" 13 | num_consumers: int = Field(1, ge=1) 14 | manager: Optional[Manager] = None 15 | consumer_id: Optional[int] = None 16 | frontend_address: Optional[str] = None 17 | backend_address: Optional[str] = None 18 | 19 | 20 | def _create_zmq_transport(config: TransportConfig): 21 | broker = Broker() 22 | broker.start() 23 | config.frontend_address = broker.frontend_address 24 | config.backend_address = broker.backend_address 25 | return ZMQTransport(config.frontend_address, config.backend_address) 26 | 27 | 28 | def _create_mp_transport(config: TransportConfig): 29 | queues = [config.manager.Queue() for _ in range(config.num_consumers)] 30 | return MPQueueTransport(config.manager, queues) 31 | 32 | 33 | def create_transport_from_config(config: TransportConfig): 34 | if config.transport_type == "mp": 35 | return _create_mp_transport(config) 36 | if config.transport_type == "zmq": 37 | return _create_zmq_transport(config) 38 | raise ValueError(f"Invalid transport type: {config.transport_type}") 39 | -------------------------------------------------------------------------------- /src/litserve/transport/process_transport.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from contextlib import suppress 3 | from multiprocessing import Manager, Queue 4 | from typing import Any, List, Optional 5 | 6 | from litserve.transport.base import MessageTransport 7 | 8 | 9 | class MPQueueTransport(MessageTransport): 10 | def __init__(self, manager: Manager, queues: List[Queue]): 11 | self._queues = queues 12 | self._closed = False 13 | 14 | def send(self, item: Any, consumer_id: int) -> None: 15 | return self._queues[consumer_id].put(item) 16 | 17 | async def areceive(self, consumer_id: int, timeout: Optional[float] = None, block: bool = True) -> dict: 18 | if self._closed: 19 | raise asyncio.CancelledError("Transport closed") 20 | 21 | actual_timeout = 1 if timeout is None else min(timeout, 1) 22 | 23 | try: 24 | return await asyncio.to_thread(self._queues[consumer_id].get, timeout=actual_timeout, block=True) 25 | except asyncio.CancelledError: 26 | raise 27 | except Exception: 28 | if self._closed: 29 | raise asyncio.CancelledError("Transport closed") 30 | if timeout is not None and timeout <= 
actual_timeout: 31 | raise 32 | return None 33 | 34 | def close(self) -> None: 35 | # Mark the transport as closed 36 | self._closed = True 37 | 38 | # Put sentinel values in the queues as a backup mechanism 39 | for queue in self._queues: 40 | with suppress(Exception): 41 | queue.put(None) 42 | 43 | def __reduce__(self): 44 | return (MPQueueTransport, (None, self._queues)) 45 | -------------------------------------------------------------------------------- /src/litserve/transport/zmq_queue.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import asyncio 15 | import logging 16 | import multiprocessing 17 | import pickle 18 | import threading 19 | import time 20 | from queue import Empty 21 | from typing import Any, Optional 22 | 23 | import zmq 24 | import zmq.asyncio 25 | 26 | from litserve.utils import generate_random_zmq_address 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class Broker: 32 | """Message broker that routes messages between producers and consumers.""" 33 | 34 | def __init__(self, use_process: bool = False): 35 | self.frontend_address = generate_random_zmq_address() 36 | self.backend_address = generate_random_zmq_address() 37 | self._running = False 38 | self._use_process = use_process 39 | self._worker = None 40 | 41 | def start(self): 42 | """Start the broker in a background thread or process.""" 43 | self._running = True 44 | 45 | if self._use_process: 46 | self._worker = multiprocessing.Process(target=self._run) 47 | else: 48 | self._worker = threading.Thread(target=self._run) 49 | 50 | self._worker.daemon = True 51 | self._worker.start() 52 | logger.info( 53 | f"Broker started in {'process' if self._use_process else 'thread'} " 54 | f"on {self.frontend_address} (frontend) and {self.backend_address} (backend)" 55 | ) 56 | time.sleep(0.1) # Give the broker time to start 57 | 58 | def _run(self): 59 | """Main broker loop.""" 60 | context = zmq.Context() 61 | try: 62 | frontend = context.socket(zmq.XPUB) 63 | frontend.bind(self.frontend_address) 64 | 65 | backend = context.socket(zmq.XSUB) 66 | backend.bind(self.backend_address) 67 | 68 | zmq.proxy(frontend, backend) 69 | except zmq.ZMQError as e: 70 | logger.error(f"Broker error: {e}") 71 | finally: 72 | frontend.close(linger=0) 73 | backend.close(linger=0) 74 | context.term() 75 | 76 | def stop(self): 77 | """Stop the broker.""" 78 | self._running = False 79 | if self._worker: 80 | self._worker.join() 81 | 82 | 83 | class Producer: 84 | """Producer class for sending messages to consumers.""" 85 | 86 | def __init__(self, address: str = None): 87 | self._context = zmq.Context() 88 | self._socket = self._context.socket(zmq.PUB) 89 | self._socket.connect(address) 90 | 91 | def wait_for_subscribers(self, timeout: float = 1.0) -> bool: 92 | """Wait for at least one subscriber to be ready. 
93 | 94 | Args: 95 | timeout: Maximum time to wait in seconds 96 | 97 | Returns: 98 | bool: True if subscribers are ready, False if timeout occurred 99 | 100 | """ 101 | start_time = time.time() 102 | while time.time() - start_time < timeout: 103 | # Send a ping message to consumer 0 (special system messages) 104 | try: 105 | self._socket.send(b"0|__ping__", zmq.NOBLOCK) 106 | time.sleep(0.1) # Give time for subscription to propagate 107 | return True 108 | except zmq.ZMQError: 109 | continue 110 | return False 111 | 112 | def put(self, item: Any, consumer_id: int) -> None: 113 | """Send an item to a specific consumer.""" 114 | try: 115 | pickled_item = pickle.dumps(item) 116 | message = f"{consumer_id}|".encode() + pickled_item 117 | self._socket.send(message) 118 | except zmq.ZMQError as e: 119 | logger.error(f"Error sending item: {e}") 120 | raise 121 | except pickle.PickleError as e: 122 | logger.error(f"Error serializing item: {e}") 123 | raise 124 | 125 | def close(self) -> None: 126 | """Clean up resources.""" 127 | if self._socket: 128 | self._socket.close(linger=0) 129 | if self._context: 130 | self._context.term() 131 | 132 | 133 | class BaseConsumer: 134 | """Base class for consumers.""" 135 | 136 | def __init__(self, consumer_id: int, address: str): 137 | self.consumer_id = consumer_id 138 | self.address = address 139 | self._context = None 140 | self._socket = None 141 | self._setup_socket() 142 | 143 | def _setup_socket(self): 144 | """Setup ZMQ socket - to be implemented by subclasses""" 145 | raise NotImplementedError 146 | 147 | def _parse_message(self, message: bytes) -> Any: 148 | """Parse a message received from ZMQ.""" 149 | try: 150 | consumer_id, pickled_data = message.split(b"|", 1) 151 | return pickle.loads(pickled_data) 152 | except pickle.PickleError as e: 153 | logger.error(f"Error deserializing message: {e}") 154 | raise 155 | 156 | def close(self) -> None: 157 | """Clean up resources.""" 158 | if self._socket: 159 | self._socket.close(linger=0) 160 | if self._context: 161 | self._context.term() 162 | 163 | 164 | class AsyncConsumer(BaseConsumer): 165 | """Async consumer class for receiving messages using asyncio.""" 166 | 167 | def _setup_socket(self): 168 | self._context = zmq.asyncio.Context() 169 | self._socket = self._context.socket(zmq.SUB) 170 | self._socket.connect(self.address) 171 | self._socket.setsockopt_string(zmq.SUBSCRIBE, str(self.consumer_id)) 172 | 173 | async def get(self, timeout: Optional[float] = None) -> Any: 174 | """Get an item from the queue asynchronously.""" 175 | try: 176 | if timeout is not None: 177 | message = await asyncio.wait_for(self._socket.recv(), timeout) 178 | else: 179 | message = await self._socket.recv() 180 | 181 | return self._parse_message(message) 182 | except asyncio.TimeoutError: 183 | raise Empty 184 | except zmq.ZMQError: 185 | raise Empty 186 | 187 | def close(self) -> None: 188 | """Clean up resources asynchronously.""" 189 | if self._socket: 190 | self._socket.close(linger=0) 191 | if self._context: 192 | self._context.term() 193 | -------------------------------------------------------------------------------- /src/litserve/transport/zmq_transport.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal, Optional, Union 2 | 3 | import zmq 4 | 5 | from litserve.transport.base import MessageTransport 6 | from litserve.transport.zmq_queue import AsyncConsumer, Producer 7 | 8 | 9 | class ZMQTransport(MessageTransport): 10 | def __init__(self, 
backend_address: str, frontend_address): 11 | self.backend_address = backend_address 12 | self.frontend_address = frontend_address 13 | self._zmq: Union[Producer, AsyncConsumer, None] = None 14 | 15 | def setup(self, operation: Literal[zmq.SUB, zmq.PUB], consumer_id: Optional[int] = None) -> None: 16 | """Must be called in the subprocess to set up the ZMQ transport.""" 17 | if operation == zmq.PUB: 18 | self._zmq = Producer(address=self.backend_address) 19 | self._zmq.wait_for_subscribers() 20 | elif operation == zmq.SUB: 21 | self._zmq = AsyncConsumer(consumer_id=consumer_id, address=self.frontend_address) 22 | else: 23 | raise ValueError(f"Invalid operation {operation}") 24 | 25 | def send(self, item: Any, consumer_id: int) -> None: 26 | if self._zmq is None: 27 | self.setup(zmq.PUB) 28 | return self._zmq.put(item, consumer_id) 29 | 30 | async def areceive(self, consumer_id: Optional[int] = None, timeout=None) -> dict: 31 | if self._zmq is None: 32 | self.setup(zmq.SUB, consumer_id) 33 | return await self._zmq.get(timeout=timeout) 34 | 35 | def close(self) -> None: 36 | if self._zmq: 37 | self._zmq.close() 38 | else: 39 | raise ValueError("ZMQ not initialized, make sure ZMQTransport.setup() is called.") 40 | 41 | def __reduce__(self): 42 | return ZMQTransport, (self.backend_address, self.frontend_address) 43 | -------------------------------------------------------------------------------- /src/litserve/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
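A hand-wired sketch of the ZMQ pieces above: a `Broker` proxies messages from a `Producer` (the worker side of `ZMQTransport`) to an `AsyncConsumer` (the server side), keyed by `consumer_id`. Normally `ZMQTransport.setup()` and `create_transport_from_config` do this wiring across processes; the short sleep here only sidesteps the PUB/SUB slow-joiner race in a single-process demo.

```python
import asyncio
import time

from litserve.transport.zmq_queue import AsyncConsumer, Broker, Producer


async def main():
    # The broker binds XPUB (frontend, consumer side) and XSUB (backend, producer side)
    # and proxies between them in a background thread.
    broker = Broker()
    broker.start()

    consumer = AsyncConsumer(consumer_id=1, address=broker.frontend_address)
    producer = Producer(address=broker.backend_address)
    producer.wait_for_subscribers()
    time.sleep(0.2)  # give the SUB subscription time to propagate to the PUB socket

    producer.put({"status": "OK", "payload": 42}, consumer_id=1)
    print(await consumer.get(timeout=5.0))  # {'status': 'OK', 'payload': 42}

    producer.close()
    consumer.close()


if __name__ == "__main__":
    asyncio.run(main())
```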
14 | import asyncio 15 | import dataclasses 16 | import logging 17 | import os 18 | import pdb 19 | import pickle 20 | import sys 21 | import uuid 22 | from contextlib import contextmanager 23 | from enum import Enum 24 | from typing import TYPE_CHECKING, Any, AsyncIterator, TextIO, Union 25 | 26 | from fastapi import HTTPException 27 | 28 | if TYPE_CHECKING: 29 | from litserve.server import LitServer 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class LitAPIStatus: 35 | OK = "OK" 36 | ERROR = "ERROR" 37 | FINISH_STREAMING = "FINISH_STREAMING" 38 | 39 | 40 | class LoopResponseType(Enum): 41 | STREAMING = "STREAMING" 42 | REGULAR = "REGULAR" 43 | 44 | 45 | class PickleableHTTPException(HTTPException): 46 | @staticmethod 47 | def from_exception(exc: HTTPException): 48 | status_code = exc.status_code 49 | detail = exc.detail 50 | return PickleableHTTPException(status_code, detail) 51 | 52 | def __reduce__(self): 53 | return (HTTPException, (self.status_code, self.detail)) 54 | 55 | 56 | def dump_exception(exception): 57 | if isinstance(exception, HTTPException): 58 | exception = PickleableHTTPException.from_exception(exception) 59 | return pickle.dumps(exception) 60 | 61 | 62 | async def azip(*async_iterables): 63 | iterators = [ait.__aiter__() for ait in async_iterables] 64 | while True: 65 | results = await asyncio.gather(*(ait.__anext__() for ait in iterators), return_exceptions=True) 66 | if any(isinstance(result, StopAsyncIteration) for result in results): 67 | break 68 | yield tuple(results) 69 | 70 | 71 | @contextmanager 72 | def wrap_litserve_start(server: "LitServer"): 73 | """Pytest utility to start the server in a context manager.""" 74 | server.app.response_queue_id = 0 75 | for lit_api in server.litapi_connector: 76 | if lit_api.spec: 77 | lit_api.spec.response_queue_id = 0 78 | 79 | manager = server._init_manager(1) 80 | processes = [] 81 | for lit_api in server.litapi_connector: 82 | processes.extend(server.launch_inference_worker(lit_api)) 83 | server._prepare_app_run(server.app) 84 | try: 85 | yield server 86 | finally: 87 | # First close the transport to signal to the response_queue_to_buffer task that it should stop 88 | server._transport.close() 89 | for p in processes: 90 | p.terminate() 91 | p.join() 92 | manager.shutdown() 93 | 94 | 95 | async def call_after_stream(streamer: AsyncIterator, callback, *args, **kwargs): 96 | try: 97 | async for item in streamer: 98 | yield item 99 | except Exception as e: 100 | logger.exception(f"Error in streamer: {e}") 101 | finally: 102 | callback(*args, **kwargs) 103 | 104 | 105 | @dataclasses.dataclass 106 | class WorkerSetupStatus: 107 | STARTING: str = "starting" 108 | READY: str = "ready" 109 | ERROR: str = "error" 110 | FINISHED: str = "finished" 111 | 112 | 113 | def _get_default_handler(stream, format): 114 | handler = logging.StreamHandler(stream) 115 | formatter = logging.Formatter(format) 116 | handler.setFormatter(formatter) 117 | return handler 118 | 119 | 120 | def configure_logging( 121 | level: Union[str, int] = logging.INFO, 122 | format: str = "%(asctime)s - %(processName)s[%(process)d] - %(name)s - %(levelname)s - %(message)s", 123 | stream: TextIO = sys.stdout, 124 | use_rich: bool = False, 125 | ): 126 | """Configure logging for the entire library with sensible defaults. 
127 | 128 | Args: 129 | level (int): Logging level (default: logging.INFO) 130 | format (str): Log message format string 131 | stream (file-like): Output stream for logs 132 | use_rich (bool): Makes the logs more readable by using rich, useful for debugging. Defaults to False. 133 | 134 | """ 135 | if isinstance(level, str): 136 | level = level.upper() 137 | level = getattr(logging, level) 138 | 139 | # Clear any existing handlers to prevent duplicates 140 | library_logger = logging.getLogger("litserve") 141 | for handler in library_logger.handlers[:]: 142 | library_logger.removeHandler(handler) 143 | 144 | if use_rich: 145 | try: 146 | from rich.logging import RichHandler 147 | from rich.traceback import install 148 | 149 | install(show_locals=True) 150 | handler = RichHandler(rich_tracebacks=True, show_time=True, show_path=True) 151 | except ImportError: 152 | logger.warning("Rich is not installed, using default logging") 153 | handler = _get_default_handler(stream, format) 154 | else: 155 | handler = _get_default_handler(stream, format) 156 | 157 | # Configure library logger 158 | library_logger.setLevel(level) 159 | library_logger.addHandler(handler) 160 | library_logger.propagate = False 161 | 162 | 163 | def set_log_level(level): 164 | """Allow users to set the global logging level for the library.""" 165 | logging.getLogger("litserve").setLevel(level) 166 | 167 | 168 | def add_log_handler(handler): 169 | """Allow users to add custom log handlers. 170 | 171 | Example usage: 172 | file_handler = logging.FileHandler('library_logs.log') 173 | add_log_handler(file_handler) 174 | 175 | """ 176 | logging.getLogger("litserve").addHandler(handler) 177 | 178 | 179 | def generate_random_zmq_address(temp_dir="/tmp"): 180 | """Generate a random IPC address in the /tmp directory. 181 | 182 | Ensures the address is unique. 183 | Returns: 184 | str: A random IPC address suitable for ZeroMQ. 185 | 186 | """ 187 | unique_name = f"zmq-{uuid.uuid4().hex}.ipc" 188 | ipc_path = os.path.join(temp_dir, unique_name) 189 | return f"ipc://{ipc_path}" 190 | 191 | 192 | class ForkedPdb(pdb.Pdb): 193 | # Borrowed from - https://github.com/Lightning-AI/forked-pdb 194 | """ 195 | PDB Subclass for debugging multi-processed code 196 | Suggested in: https://stackoverflow.com/questions/4716533/how-to-attach-debugger-to-a-python-subproccess 197 | """ 198 | 199 | def interaction(self, *args: Any, **kwargs: Any) -> None: 200 | _stdin = sys.stdin 201 | try: 202 | sys.stdin = open("/dev/stdin") # noqa: SIM115 203 | pdb.Pdb.interaction(self, *args, **kwargs) 204 | finally: 205 | sys.stdin = _stdin 206 | 207 | 208 | def set_trace(): 209 | """Set a tracepoint in the code.""" 210 | ForkedPdb().set_trace() 211 | 212 | 213 | def set_trace_if_debug(debug_env_var="LITSERVE_DEBUG", debug_env_var_value="1"): 214 | """Set a tracepoint in the code if the environment variable LITSERVE_DEBUG is set.""" 215 | if os.environ.get(debug_env_var) == debug_env_var_value: 216 | set_trace() 217 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/e2e/default_api.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import litserve as ls 15 | 16 | if __name__ == "__main__": 17 | api = ls.test_examples.SimpleLitAPI() 18 | server = ls.LitServer(api) 19 | server.run(port=8000) 20 | -------------------------------------------------------------------------------- /tests/e2e/default_async_streaming.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import litserve as ls 15 | 16 | 17 | class AsyncAPI(ls.LitAPI): 18 | def setup(self, device) -> None: 19 | self.model = lambda x: x 20 | 21 | async def decode_request(self, request): 22 | return request["input"] 23 | 24 | async def predict(self, x): 25 | for i in range(10): 26 | yield self.model(i) 27 | 28 | async def encode_response(self, output): 29 | for out in output: 30 | yield {"output": out} 31 | 32 | 33 | if __name__ == "__main__": 34 | api = AsyncAPI(enable_async=True) 35 | server = ls.LitServer(api, stream=True) 36 | server.run(port=8000) 37 | -------------------------------------------------------------------------------- /tests/e2e/default_batched_streaming.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import numpy as np 15 | 16 | import litserve as ls 17 | 18 | 19 | class SimpleStreamAPI(ls.LitAPI): 20 | def setup(self, device) -> None: 21 | self.model = lambda x, y: x * y 22 | 23 | def decode_request(self, request): 24 | return np.asarray(request["input"]) 25 | 26 | def predict(self, x): 27 | for i in range(10): 28 | yield self.model(x, i) 29 | 30 | def encode_response(self, output_stream): 31 | for outputs in output_stream: 32 | yield [{"output": output} for output in outputs] 33 | 34 | 35 | if __name__ == "__main__": 36 | server = ls.LitServer(SimpleStreamAPI(), stream=True, max_batch_size=4, batch_timeout=0.2, fast_queue=True) 37 | server.run(port=8000) 38 | -------------------------------------------------------------------------------- /tests/e2e/default_batching.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
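Client-side sketch for the batched streaming server above (`tests/e2e/default_batched_streaming.py`), mirroring the pattern from `SimpleStreamAPI`'s docstring; it assumes the server is already running on port 8000:

```python
import requests

# Each streamed chunk is an encoded {"output": ...} entry produced by encode_response.
resp = requests.post("http://127.0.0.1:8000/predict", json={"input": 3.0}, stream=True)
for chunk in resp.iter_content(chunk_size=5000):
    if chunk:
        print(chunk.decode("utf-8"))
```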
14 | import litserve as ls 15 | 16 | if __name__ == "__main__": 17 | api = ls.test_examples.SimpleBatchedAPI() 18 | server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.05) 19 | server.run(port=8000) 20 | -------------------------------------------------------------------------------- /tests/e2e/default_openai_embedding_spec.py: -------------------------------------------------------------------------------- 1 | import litserve as ls 2 | from litserve import OpenAIEmbeddingSpec 3 | from litserve.test_examples.openai_embedding_spec_example import TestEmbedAPI 4 | 5 | if __name__ == "__main__": 6 | server = ls.LitServer(TestEmbedAPI(), spec=OpenAIEmbeddingSpec(), fast_queue=True) 7 | server.run() 8 | -------------------------------------------------------------------------------- /tests/e2e/default_openai_with_batching.py: -------------------------------------------------------------------------------- 1 | import litserve as ls 2 | from litserve.test_examples.openai_spec_example import OpenAIBatchContext 3 | 4 | if __name__ == "__main__": 5 | api = OpenAIBatchContext() 6 | server = ls.LitServer(api, spec=ls.OpenAISpec(), max_batch_size=2, batch_timeout=0.5, fast_queue=True) 7 | server.run(port=8000) 8 | -------------------------------------------------------------------------------- /tests/e2e/default_openaispec.py: -------------------------------------------------------------------------------- 1 | import litserve as ls 2 | from litserve import OpenAISpec 3 | from litserve.test_examples.openai_spec_example import TestAPI 4 | 5 | if __name__ == "__main__": 6 | server = ls.LitServer(TestAPI(), spec=OpenAISpec()) 7 | server.run() 8 | -------------------------------------------------------------------------------- /tests/e2e/default_openaispec_response_format.py: -------------------------------------------------------------------------------- 1 | import litserve as ls 2 | from litserve import OpenAISpec 3 | from litserve.test_examples.openai_spec_example import TestAPIWithStructuredOutput 4 | 5 | if __name__ == "__main__": 6 | server = ls.LitServer(TestAPIWithStructuredOutput(), spec=OpenAISpec(), fast_queue=True) 7 | server.run() 8 | -------------------------------------------------------------------------------- /tests/e2e/default_openaispec_tools.py: -------------------------------------------------------------------------------- 1 | import litserve as ls 2 | from litserve import OpenAISpec 3 | from litserve.specs.openai import ChatMessage 4 | from litserve.test_examples.openai_spec_example import TestAPI 5 | 6 | 7 | class TestAPIWithToolCalls(TestAPI): 8 | def encode_response(self, output): 9 | yield ChatMessage( 10 | role="assistant", 11 | content="", 12 | tool_calls=[ 13 | { 14 | "id": "call_abc123", 15 | "type": "function", 16 | "function": {"name": "get_current_weather", "arguments": '{\n"location": "Boston, MA"\n}'}, 17 | } 18 | ], 19 | ) 20 | 21 | 22 | if __name__ == "__main__": 23 | server = ls.LitServer(TestAPIWithToolCalls(), spec=OpenAISpec()) 24 | server.run() 25 | -------------------------------------------------------------------------------- /tests/e2e/default_single_streaming.py: -------------------------------------------------------------------------------- 1 | from litserve import LitAPI, LitServer 2 | 3 | 4 | class SimpleStreamingAPI(LitAPI): 5 | def setup(self, device) -> None: 6 | self.model = lambda x, y: x * y 7 | 8 | def decode_request(self, request): 9 | return request["input"] 10 | 11 | def predict(self, x): 12 | for i in range(1, 4): 13 | yield 
self.model(i, x) 14 | 15 | def encode_response(self, output_stream): 16 | for output in output_stream: 17 | yield {"output": output} 18 | 19 | 20 | if __name__ == "__main__": 21 | api = SimpleStreamingAPI() 22 | server = LitServer(api, stream=True) 23 | server.run(port=8000) 24 | -------------------------------------------------------------------------------- /tests/e2e/default_spec.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import litserve as ls 15 | from litserve.specs.openai import OpenAISpec 16 | from litserve.test_examples.openai_spec_example import TestAPI 17 | 18 | if __name__ == "__main__": 19 | server = ls.LitServer(TestAPI(), spec=OpenAISpec()) 20 | server.run(port=8000) 21 | -------------------------------------------------------------------------------- /tests/e2e/openai_embedding_with_batching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import litserve as ls 4 | 5 | 6 | class EmbeddingsAPI(ls.LitAPI): 7 | def setup(self, device): 8 | def model(x): 9 | return np.random.rand(len(x), 768) 10 | 11 | self.model = model 12 | 13 | def predict(self, inputs): 14 | return self.model(inputs) 15 | 16 | 17 | if __name__ == "__main__": 18 | api = EmbeddingsAPI(max_batch_size=10, batch_timeout=2) 19 | server = ls.LitServer(api, spec=ls.OpenAIEmbeddingSpec()) 20 | server.run(port=8000) 21 | -------------------------------------------------------------------------------- /tests/minimal_run.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
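# A minimal client sketch for the OpenAI-compatible embeddings server above
# (openai_embedding_with_batching.py). It assumes the server is running locally on port 8000 and
# that OpenAIEmbeddingSpec exposes the standard /v1/embeddings route; treat both as assumptions.
def _embeddings_request_example() -> None:
    import requests

    payload = {"input": ["hello world", "litserve"], "model": "any-model-name"}
    resp = requests.post("http://127.0.0.1:8000/v1/embeddings", json=payload, timeout=30)
    resp.raise_for_status()
    print(len(resp.json().get("data", [])))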
14 | import json 15 | import subprocess 16 | import time 17 | import urllib.request 18 | 19 | import psutil 20 | 21 | 22 | def main(): 23 | process = subprocess.Popen( 24 | ["python", "tests/simple_server.py"], 25 | ) 26 | print("Waiting for server to start...") 27 | time.sleep(10) 28 | try: 29 | url = "http://127.0.0.1:8000/predict" 30 | data = json.dumps({"input": 4.0}).encode("utf-8") 31 | headers = {"Content-Type": "application/json"} 32 | request = urllib.request.Request(url, data=data, headers=headers, method="POST") 33 | response = urllib.request.urlopen(request) 34 | status_code = response.getcode() 35 | assert status_code == 200 36 | except Exception: 37 | raise 38 | 39 | finally: 40 | parent = psutil.Process(process.pid) 41 | for child in parent.children(recursive=True): 42 | child.kill() 43 | process.kill() 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /tests/parity_fastapi/benchmark.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import concurrent.futures 3 | import random 4 | import time 5 | 6 | import numpy as np 7 | import requests 8 | import torch 9 | from PIL import Image 10 | 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | device = "mps" if torch.backends.mps.is_available() else device 13 | 14 | rand_mat = np.random.rand(2, 224, 224, 3) * 255 15 | Image.fromarray(rand_mat[0].astype("uint8")).convert("RGB").save("image1.jpg") 16 | Image.fromarray(rand_mat[1].astype("uint8")).convert("RGB").save("image2.jpg") 17 | 18 | SERVER_URL = "http://127.0.0.1:{}/predict" 19 | 20 | payloads = [] 21 | for file in ["image1.jpg", "image2.jpg"]: 22 | with open(file, "rb") as image_file: 23 | encoded_string = base64.b64encode(image_file.read()).decode("utf-8") 24 | payloads.append(encoded_string) 25 | 26 | 27 | def send_request(port): 28 | """Function to send a single request and measure the response time.""" 29 | url = SERVER_URL.format(port) 30 | payload = {"image_data": random.choice(payloads)} 31 | start_time = time.time() 32 | response = requests.post(url, json=payload) 33 | end_time = time.time() 34 | return end_time - start_time, response.status_code 35 | 36 | 37 | def benchmark(num_requests=100, concurrency_level=100, port=8000): 38 | """Benchmark the ML server.""" 39 | 40 | start_benchmark_time = time.time() # Start benchmark timing 41 | with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency_level) as executor: 42 | futures = [executor.submit(send_request, port) for _ in range(num_requests)] 43 | response_times = [] 44 | status_codes = [] 45 | 46 | for future in concurrent.futures.as_completed(futures): 47 | response_time, status_code = future.result() 48 | response_times.append(response_time) 49 | status_codes.append(status_code) 50 | 51 | end_benchmark_time = time.time() # End benchmark timing 52 | total_benchmark_time = end_benchmark_time - start_benchmark_time # Time in seconds 53 | 54 | # Analysis 55 | total_time = sum(response_times) # Time in seconds 56 | avg_time = total_time / num_requests # Time in seconds 57 | success_rate = status_codes.count(200) / num_requests * 100 58 | rps = num_requests / total_benchmark_time # Requests per second 59 | 60 | # Create a dictionary with the metrics 61 | metrics = { 62 | "Total Requests": num_requests, 63 | "Concurrency Level": concurrency_level, 64 | "Total Benchmark Time (seconds)": total_benchmark_time, 65 | "Average Response Time (ms)": avg_time * 
1000, 66 | "Success Rate (%)": success_rate, 67 | "Requests Per Second (RPS)": rps, 68 | } 69 | 70 | # Print the metrics 71 | for key, value in metrics.items(): 72 | print(f"{key}: {value}") 73 | print("-" * 50) 74 | 75 | return metrics 76 | 77 | 78 | def run_bench(conf: dict, num_samples: int, port: int): 79 | num_requests = conf[device]["num_requests"] 80 | 81 | results = [] 82 | for _ in range(num_samples): 83 | metric = benchmark(num_requests=num_requests, concurrency_level=num_requests, port=port) 84 | results.append(metric) 85 | return results[2:] # skip warmup step 86 | -------------------------------------------------------------------------------- /tests/parity_fastapi/fastapi-server.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | 4 | import PIL 5 | import torch 6 | import torchvision 7 | from fastapi import FastAPI, HTTPException 8 | from jsonargparse import CLI 9 | from pydantic import BaseModel 10 | 11 | # Set float32 matrix multiplication precision if GPU is available and capable 12 | if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0): 13 | torch.set_float32_matmul_precision("high") 14 | 15 | app = FastAPI() 16 | 17 | 18 | class ImageData(BaseModel): 19 | image_data: str 20 | 21 | 22 | class ImageClassifierAPI: 23 | def __init__(self, device): 24 | self.device = device 25 | weights = torchvision.models.ResNet18_Weights.DEFAULT 26 | self.image_processing = weights.transforms() 27 | self.model = torchvision.models.resnet18(weights=None).eval().to(device) 28 | 29 | def process_image(self, image_data): 30 | image = base64.b64decode(image_data) 31 | pil_image = PIL.Image.open(io.BytesIO(image)).convert("RGB") 32 | processed_image = self.image_processing(pil_image) 33 | return processed_image.unsqueeze(0).to(self.device) # Add batch dimension 34 | 35 | def predict(self, x): 36 | with torch.inference_mode(): 37 | outputs = self.model(x) 38 | _, predictions = torch.max(outputs, 1) 39 | return predictions.item() 40 | 41 | 42 | device = "cuda" if torch.cuda.is_available() else "cpu" 43 | if torch.backends.mps.is_available(): 44 | device = "mps" 45 | api = ImageClassifierAPI(device) 46 | 47 | 48 | @app.get("/health") 49 | async def health(): 50 | return {"status": "ok"} 51 | 52 | 53 | @app.post("/predict") 54 | async def predict(image_data: ImageData): 55 | try: 56 | processed_image = api.process_image(image_data.image_data) 57 | prediction = api.predict(processed_image) 58 | return {"output": prediction} 59 | except Exception as e: 60 | raise HTTPException(status_code=500, detail=str(e)) 61 | 62 | 63 | def main(): 64 | import uvicorn 65 | 66 | uvicorn.run(app, host="0.0.0.0", port=8001, log_level="warning") 67 | 68 | 69 | if __name__ == "__main__": 70 | CLI(main) 71 | -------------------------------------------------------------------------------- /tests/parity_fastapi/ls-server.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import os 4 | from concurrent.futures import ThreadPoolExecutor 5 | 6 | import PIL 7 | import torch 8 | import torchvision 9 | 10 | import litserve as ls 11 | 12 | device = "cuda" if torch.cuda.is_available() else "cpu" 13 | device = "mps" if torch.backends.mps.is_available() else device 14 | conf = { 15 | "cuda": {"batch_size": 16, "workers_per_device": 2}, 16 | "cpu": {"batch_size": 8, "workers_per_device": 2}, 17 | "mps": {"batch_size": 8, "workers_per_device": 2}, 18 | } 19 | 20 | # Set 
float32 matrix multiplication precision if GPU is available and capable 21 | if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0): 22 | torch.set_float32_matmul_precision("high") 23 | 24 | 25 | class ImageClassifierAPI(ls.LitAPI): 26 | def setup(self, device): 27 | print(device) 28 | weights = torchvision.models.ResNet18_Weights.DEFAULT 29 | self.image_processing = weights.transforms() 30 | self.model = torchvision.models.resnet18(weights=None).eval().to(device) 31 | self.pool = ThreadPoolExecutor(os.cpu_count()) 32 | 33 | def decode_request(self, request): 34 | return request["image_data"] 35 | 36 | def batch(self, image_data_list): 37 | def process_image(image_data): 38 | image = base64.b64decode(image_data) 39 | pil_image = PIL.Image.open(io.BytesIO(image)).convert("RGB") 40 | return self.image_processing(pil_image) 41 | 42 | inputs = list(self.pool.map(process_image, image_data_list)) 43 | return torch.stack(inputs).to(self.device) 44 | 45 | def predict(self, x): 46 | with torch.inference_mode(): 47 | outputs = self.model(x) 48 | _, predictions = torch.max(outputs, 1) 49 | return predictions 50 | 51 | def unbatch(self, outputs): 52 | return outputs.tolist() 53 | 54 | def encode_response(self, output): 55 | return {"output": output} 56 | 57 | 58 | def main(batch_size: int, workers_per_device: int): 59 | print(locals()) 60 | api = ImageClassifierAPI() 61 | server = ls.LitServer( 62 | api, 63 | max_batch_size=batch_size, 64 | batch_timeout=0.01, 65 | timeout=10, 66 | workers_per_device=workers_per_device, 67 | fast_queue=True, 68 | ) 69 | server.run(port=8000, log_level="warning") 70 | 71 | 72 | if __name__ == "__main__": 73 | main(**conf[device]) 74 | -------------------------------------------------------------------------------- /tests/parity_fastapi/main.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import time 3 | from functools import wraps 4 | 5 | import psutil 6 | import requests 7 | import torch 8 | from benchmark import run_bench 9 | 10 | CONF = { 11 | "cpu": {"num_requests": 50}, 12 | "mps": {"num_requests": 50}, 13 | "cuda": {"num_requests": 100}, 14 | } 15 | 16 | device = "cuda" if torch.cuda.is_available() else "cpu" 17 | device = "mps" if torch.backends.mps.is_available() else device 18 | 19 | DIFF_FACTOR = { 20 | "cpu": 1, 21 | "cuda": 1.2, 22 | "mps": 1, 23 | } 24 | 25 | 26 | def run_python_script(filename): 27 | def decorator(test_fn): 28 | @wraps(test_fn) 29 | def wrapper(*args, **kwargs): 30 | process = subprocess.Popen( 31 | ["python", filename], 32 | ) 33 | print("Waiting for server to start...") 34 | time.sleep(10) 35 | 36 | try: 37 | return test_fn(*args, **kwargs) 38 | except Exception: 39 | raise 40 | finally: 41 | print("Killing the server") 42 | parent = psutil.Process(process.pid) 43 | for child in parent.children(recursive=True): 44 | child.kill() 45 | process.kill() 46 | 47 | return wrapper 48 | 49 | return decorator 50 | 51 | 52 | def try_health(port): 53 | for i in range(10): 54 | try: 55 | response = requests.get(f"http://127.0.0.1:{port}/health") 56 | if response.status_code == 200: 57 | return 58 | except Exception: 59 | pass 60 | 61 | 62 | @run_python_script("tests/parity_fastapi/fastapi-server.py") 63 | def run_fastapi_benchmark(num_samples): 64 | port = 8001 65 | try_health(port) 66 | return run_bench(CONF, num_samples, port) 67 | 68 | 69 | @run_python_script("tests/parity_fastapi/ls-server.py") 70 | def run_litserve_benchmark(num_samples): 71 | port = 8000 72 | 
try_health(port) 73 | return run_bench(CONF, num_samples, port) 74 | 75 | 76 | def mean(lst): 77 | return sum(lst) / len(lst) 78 | 79 | 80 | def main(): 81 | key = "Requests Per Second (RPS)" 82 | num_samples = 12 83 | print("Running FastAPI benchmark") 84 | fastapi_metrics = run_fastapi_benchmark(num_samples=num_samples) 85 | print("\n\n" + "=" * 50 + "\n\n") 86 | print("Running LitServe benchmark") 87 | ls_metrics = run_litserve_benchmark(num_samples=num_samples) 88 | fastapi_throughput = mean([e[key] for e in fastapi_metrics]) 89 | ls_throughput = mean([e[key] for e in ls_metrics]) 90 | factor = DIFF_FACTOR[device] 91 | msg = f"LitServe should have higher throughput than FastAPI on {device}. {ls_throughput} vs {fastapi_throughput}" 92 | assert ls_throughput > fastapi_throughput * factor, msg 93 | print(f"{ls_throughput} vs {fastapi_throughput}") 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /tests/perf_test/bert/benchmark.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import requests 5 | from tenacity import retry, stop_after_attempt 6 | from utils import benchmark 7 | 8 | # Configuration 9 | SERVER_URL = "http://0.0.0.0:8000/predict" 10 | MAX_SPEED = 390 # Nvidia 3090 11 | 12 | session = requests.Session() 13 | 14 | 15 | def get_average_throughput(num_requests=100, num_samples=10): 16 | key = "Requests Per Second (RPS)" 17 | latency_key = "Latency per Request (ms)" 18 | metric = 0 19 | latency = 0 20 | 21 | # warmup 22 | benchmark(num_requests=50, concurrency_level=10) 23 | for i in range(num_samples): 24 | bnmk = benchmark(num_requests=num_requests, concurrency_level=num_requests) 25 | metric += bnmk[key] 26 | latency += bnmk[latency_key] 27 | avg = metric / num_samples 28 | print("avg RPS:", avg) 29 | print("avg latency:", latency / num_samples) 30 | return avg 31 | 32 | 33 | @retry(stop=stop_after_attempt(10)) 34 | def main(): 35 | for i in range(10): 36 | try: 37 | resp = requests.get("http://localhost:8000/health") 38 | if resp.status_code == 200: 39 | break 40 | except requests.exceptions.ConnectionError as e: 41 | logging.error(f"Error connecting to server: {e}") 42 | time.sleep(10) 43 | 44 | rps = get_average_throughput(100, num_samples=10) 45 | assert rps >= MAX_SPEED, f"Expected RPS to be greater than {MAX_SPEED}, got {rps}" 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /tests/perf_test/bert/data.py: -------------------------------------------------------------------------------- 1 | phrases = [ 2 | "In the midst of a bustling city, amidst the constant hum of traffic and the chatter of countless conversations, " 3 | "there exists a serene park where people come to escape the chaos. Children play on the swings, their laughter " 4 | "echoing through the air, while adults stroll along the winding paths, lost in thought. The trees, tall and" 5 | " majestic, provide a canopy of shade, and the flowers bloom in a riot of colors, adding to the park's charm." 6 | " It's a place where time seems to slow down, offering a moment of peace and reflection in an otherwise hectic" 7 | " world.", 8 | "As the sun sets over the horizon, painting the sky in hues of orange, pink, and purple, a sense of calm descends" 9 | " over the landscape. 
The day has been long and filled with activity, but now, in this magical hour, everything " 10 | "feels different. The birds return to their nests, their evening songs a lullaby to the world. The gentle breeze " 11 | "carries the scent of blooming jasmine, and the stars begin to twinkle in the darkening sky. It's a time for " 12 | "quiet contemplation, for appreciating the beauty of nature, and for feeling a deep connection to the universe.", 13 | "On a remote island, far away from the noise and pollution of modern life, there is a hidden cove where " 14 | "crystal-clear waters lap gently against the shore. The beach, covered in soft, white sand, is a paradise for " 15 | "those seeking solitude and tranquility. Palm trees sway in the breeze, their fronds rustling softly, while " 16 | "the sun casts a warm, golden glow over everything. Here, one can forget the worries of the world and simply " 17 | "exist in the moment, surrounded by the natural beauty of the island and the soothing sounds of the ocean.", 18 | "In an ancient forest, where the trees have stood for centuries, there is a sense of timelessness that" 19 | " envelops everything. The air is cool and crisp, filled with the earthy scent of moss and fallen leaves. Sunlight" 20 | " filters through the dense canopy, creating dappled patterns on the forest floor. Birds call to one another, and" 21 | " small animals scurry through the underbrush. It's a place where one can feel the weight of history, where the " 22 | "presence of the past is almost palpable. Walking through this forest is like stepping back in time, to a world " 23 | "untouched by human hands.", 24 | "At the edge of a vast desert, where the dunes stretch out as far as the eye can see, there is a small oasis " 25 | "that offers a respite from the harsh conditions. A cluster of palm trees provides shade, and a clear, cool spring" 26 | " bubbles up from the ground, a source of life in an otherwise barren landscape. Travelers who come across this " 27 | "oasis are greeted with the sight of lush greenery and the sound of birdsong. It's a place of refuge and renewal," 28 | " where one can rest and recharge before continuing on their journey through the endless sands.", 29 | "High in the mountains, where the air is thin and the landscape is rugged, there is a hidden valley that remains" 30 | " largely untouched by human activity. The valley is a haven for wildlife, with streams that flow with clear," 31 | " cold water and meadows filled with wildflowers. The surrounding peaks, covered in snow even in the summer, " 32 | "stand as silent sentinels. It's a place where one can feel a profound sense of solitude and connection to nature." 33 | " The beauty of the valley, with its pristine environment and abundant life, is a reminder of the importance of" 34 | " preserving wild places.", 35 | "On a quiet country road, far from the bustling cities and noisy highways, there is a small farmhouse surrounded" 36 | " by fields of golden wheat. The farmhouse, with its weathered wooden walls and cozy interior, is a place of warmth" 37 | " and hospitality. The fields, swaying gently in the breeze, are a testament to the hard work and dedication of " 38 | "the farmers who tend them. In the evenings, the sky is filled with stars, and the only sounds are the chirping of" 39 | " crickets and the distant hoot of an owl. It's a place where one can find peace and simplicity.", 40 | "In a quaint village, nestled in the rolling hills of the countryside, life moves at a slower pace. 
The cobblestone" 41 | " streets are lined with charming cottages, each with its own garden bursting with flowers. The village " 42 | "square is the heart of the community, where residents gather to catch up on news and enjoy each other's company." 43 | " There's a timeless quality to the village, where traditions are upheld, and everyone knows their neighbors. " 44 | "It's a place where one can experience the joys of small-town living, with its close-knit community and strong " 45 | "sense of belonging.", 46 | "By the side of a tranquil lake, surrounded by dense forests and towering mountains, there is a small cabin that " 47 | "offers a perfect retreat from the hustle and bustle of everyday life. The cabin, with its rustic charm and cozy " 48 | "interior, is a place to unwind and relax. The lake, calm and mirror-like, reflects the beauty of the surrounding " 49 | "landscape, creating a sense of peace and serenity. It's a place where one can reconnect with nature, spend quiet" 50 | " moments fishing or kayaking, and enjoy the simple pleasures of life in a beautiful, natural setting.", 51 | ] 52 | -------------------------------------------------------------------------------- /tests/perf_test/bert/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Function to clean up server process 4 | cleanup() { 5 | pkill -f "python tests/perf_test/bert/server.py" 6 | } 7 | 8 | # Trap script exit to run cleanup 9 | trap cleanup EXIT 10 | 11 | # Start the server in the background and capture its PID 12 | python tests/perf_test/bert/server.py & 13 | SERVER_PID=$! 14 | 15 | echo "Server started with PID $SERVER_PID" 16 | 17 | # Run your benchmark script 18 | echo "Preparing to run benchmark.py..." 19 | 20 | export PYTHONPATH=$PWD && python tests/perf_test/bert/benchmark.py 21 | 22 | # Check if benchmark.py exited successfully 23 | if [ $? -ne 0 ]; then 24 | echo "benchmark.py failed to run successfully." 25 | exit 1 26 | else 27 | echo "benchmark.py ran successfully." 
28 | fi 29 | -------------------------------------------------------------------------------- /tests/perf_test/bert/server.py: -------------------------------------------------------------------------------- 1 | """A BERT-Large text classification server with batching to be used for benchmarking.""" 2 | 3 | import torch 4 | from jsonargparse import CLI 5 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, BertConfig 6 | 7 | import litserve as ls 8 | 9 | # Set float32 matrix multiplication precision if GPU is available and capable 10 | if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0): 11 | torch.set_float32_matmul_precision("high") 12 | 13 | # set dtype to bfloat16 if CUDA is available 14 | dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 15 | 16 | 17 | class HuggingFaceLitAPI(ls.LitAPI): 18 | def setup(self, device): 19 | print(device) 20 | model_name = "google-bert/bert-large-uncased" 21 | config = BertConfig.from_pretrained(pretrained_model_name_or_path=model_name) 22 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 23 | self.model = AutoModelForSequenceClassification.from_config(config, torch_dtype=dtype) 24 | self.model.to(device) 25 | 26 | def decode_request(self, request: dict): 27 | return request["text"] 28 | 29 | def batch(self, inputs): 30 | return self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True) 31 | 32 | def predict(self, inputs): 33 | inputs = {key: value.to(self.device) for key, value in inputs.items()} 34 | with torch.inference_mode(): 35 | outputs = self.model(**inputs) 36 | logits = outputs.logits 37 | return torch.argmax(logits, dim=1) 38 | 39 | def unbatch(self, outputs): 40 | return outputs.tolist() 41 | 42 | def encode_response(self, output): 43 | return {"label_idx": output} 44 | 45 | 46 | def main( 47 | batch_size: int = 10, 48 | batch_timeout: float = 0.01, 49 | devices: int = 2, 50 | workers_per_device=2, 51 | ): 52 | print(locals()) 53 | api = HuggingFaceLitAPI() 54 | server = ls.LitServer( 55 | api, 56 | max_batch_size=batch_size, 57 | batch_timeout=batch_timeout, 58 | workers_per_device=workers_per_device, 59 | accelerator="auto", 60 | devices=devices, 61 | timeout=200, 62 | fast_queue=True, 63 | ) 64 | server.run(log_level="warning", num_api_servers=4, generate_client_file=False) 65 | 66 | 67 | if __name__ == "__main__": 68 | CLI(main) 69 | -------------------------------------------------------------------------------- /tests/perf_test/bert/utils.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import random 3 | import time 4 | 5 | import gpustat 6 | import psutil 7 | import requests 8 | 9 | from tests.perf_test.bert.data import phrases 10 | 11 | 12 | def create_random_batch(size: int): 13 | result = [] 14 | for i in range(size): 15 | result.append(random.choice(phrases)) 16 | 17 | return result 18 | 19 | 20 | # Configuration 21 | SERVER_URL = "http://0.0.0.0:8000/predict" 22 | 23 | session = requests.Session() 24 | 25 | executor = None 26 | 27 | 28 | def send_request(): 29 | """Function to send a single request and measure the response time.""" 30 | payload = {"text": random.choice(phrases)} 31 | start_time = time.time() 32 | response = session.post(SERVER_URL, json=payload) 33 | end_time = time.time() 34 | return end_time - start_time, response.status_code 35 | 36 | 37 | def benchmark(num_requests=1000, concurrency_level=50, print_metrics=True): 38 | """Benchmark the ML 
server.""" 39 | global executor 40 | if executor is None: 41 | print("creating executor") 42 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=concurrency_level) 43 | 44 | if executor._max_workers < concurrency_level: 45 | print("updating executor") 46 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=concurrency_level) 47 | 48 | start_benchmark_time = time.time() # Start benchmark timing 49 | futures = [executor.submit(send_request) for _ in range(num_requests)] 50 | response_times = [] 51 | status_codes = [] 52 | 53 | for future in concurrent.futures.as_completed(futures): 54 | response_time, status_code = future.result() 55 | response_times.append(response_time) 56 | status_codes.append(status_code) 57 | 58 | end_benchmark_time = time.time() # End benchmark timing 59 | total_benchmark_time = end_benchmark_time - start_benchmark_time # Time in seconds 60 | 61 | # Analysis 62 | total_time = sum(response_times) # Time in seconds 63 | avg_time = total_time / num_requests # Time in seconds 64 | avg_latency_per_request = (total_time / num_requests) * 1000 # Convert to milliseconds 65 | success_rate = status_codes.count(200) / num_requests * 100 66 | rps = num_requests / total_benchmark_time # Requests per second 67 | 68 | # Calculate throughput per concurrent user in requests per second 69 | successful_requests = status_codes.count(200) 70 | throughput_per_user = (successful_requests / total_benchmark_time) / concurrency_level # Requests per second 71 | 72 | # Create a dictionary with the metrics 73 | metrics = { 74 | "Total Requests": num_requests, 75 | "Concurrency Level": concurrency_level, 76 | "Total Benchmark Time (seconds)": total_benchmark_time, 77 | "Average Response Time (ms)": avg_time * 1000, 78 | "Success Rate (%)": success_rate, 79 | "Requests Per Second (RPS)": rps, 80 | "Latency per Request (ms)": avg_latency_per_request, 81 | "Throughput per Concurrent User (requests/second)": throughput_per_user, 82 | } 83 | try: 84 | gpu_stats = gpustat.GPUStatCollection.new_query() 85 | metrics["GPU Utilization"] = sum([gpu.utilization for gpu in gpu_stats.gpus]) # / len(gpu_stats.gpus) 86 | except Exception: 87 | metrics["GPU Utilization"] = -1 88 | metrics["CPU Usage"] = psutil.cpu_percent(0.5) 89 | 90 | # Print the metrics 91 | if print_metrics: 92 | for key, value in metrics.items(): 93 | print(f"{key}: {value}") 94 | print("-" * 50) 95 | 96 | return metrics 97 | -------------------------------------------------------------------------------- /tests/perf_test/stream/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 1. Test server streams data very fast 3 | 4 | # Function to clean up server process 5 | cleanup() { 6 | pkill -f "python tests/perf_test/stream/stream_speed/server.py" 7 | } 8 | 9 | # Trap script exit to run cleanup 10 | trap cleanup EXIT 11 | 12 | # Start the server in the background and capture its PID 13 | python tests/perf_test/stream/stream_speed/server.py & 14 | SERVER_PID=$! 15 | 16 | echo "Server started with PID $SERVER_PID" 17 | 18 | # Run your benchmark script 19 | echo "Preparing to run benchmark.py..." 20 | 21 | export PYTHONPATH=$PWD && python tests/perf_test/stream/stream_speed/benchmark.py 22 | 23 | # Check if benchmark.py exited successfully 24 | if [ $? -ne 0 ]; then 25 | echo "benchmark.py failed to run successfully." 26 | exit 1 27 | else 28 | echo "benchmark.py ran successfully." 
29 | fi 30 | -------------------------------------------------------------------------------- /tests/perf_test/stream/stream_speed/benchmark.py: -------------------------------------------------------------------------------- 1 | """Consume 10K tokens from the stream endpoint and measure the speed.""" 2 | 3 | import logging 4 | import time 5 | 6 | import requests 7 | from tenacity import retry, stop_after_attempt 8 | 9 | logger = logging.getLogger(__name__) 10 | # Configuration 11 | SERVER_URL = "http://0.0.0.0:8000/predict" 12 | TOTAL_TOKENS = 10000 13 | EXPECTED_TTFT = 0.005 # time to first token 14 | 15 | # tokens per second 16 | MAX_SPEED = 3600 # 3600 on GitHub CI, 10000 on M3 Pro 17 | 18 | session = requests.Session() 19 | 20 | 21 | def speed_test(): 22 | start = time.time() 23 | resp = session.post(SERVER_URL, stream=True, json={"input": 1}) 24 | num_tokens = 0 25 | ttft = None # time to first token 26 | for line in resp.iter_lines(): 27 | if not line: 28 | continue 29 | if ttft is None: 30 | ttft = time.time() - start 31 | print(f"Time to first token: {ttft}") 32 | assert ttft < EXPECTED_TTFT, f"Expected time to first token to be less than {EXPECTED_TTFT} seconds but got {ttft}" 33 | num_tokens += 1 34 | end = time.time() 35 | resp.raise_for_status() 36 | assert num_tokens == TOTAL_TOKENS, f"Expected {TOTAL_TOKENS} tokens, got {num_tokens}" 37 | speed = num_tokens / (end - start) 38 | return {"speed": speed, "time": end - start} 39 | 40 | 41 | @retry(stop=stop_after_attempt(10)) 42 | def main(): 43 | for i in range(10): 44 | try: 45 | resp = requests.get("http://localhost:8000/health") 46 | if resp.status_code == 200: 47 | break 48 | except requests.exceptions.ConnectionError as e: 49 | logger.error(f"Error connecting to server: {e}") 50 | time.sleep(10) 51 | data = speed_test() 52 | speed = data["speed"] 53 | print(data) 54 | assert speed >= MAX_SPEED, f"Expected streaming speed to be greater than {MAX_SPEED}, got {speed}" 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /tests/perf_test/stream/stream_speed/server.py: -------------------------------------------------------------------------------- 1 | import litserve as ls 2 | 3 | 4 | class SimpleStreamingAPI(ls.LitAPI): 5 | def setup(self, device) -> None: 6 | self.model = None 7 | 8 | def decode_request(self, request): 9 | return request["input"] 10 | 11 | def predict(self, x): 12 | yield from range(10000) 13 | 14 | def encode_response(self, output_stream): 15 | for output in output_stream: 16 | yield {"output": output} 17 | 18 | 19 | if __name__ == "__main__": 20 | api = SimpleStreamingAPI() 21 | server = ls.LitServer( 22 | api, 23 | stream=True, 24 | ) 25 | server.run(port=8000, generate_client_file=False) 26 | -------------------------------------------------------------------------------- /tests/simple_server.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from fastapi import Request, Response 15 | 16 | from litserve.api import LitAPI 17 | from litserve.server import LitServer 18 | 19 | 20 | class SimpleLitAPI(LitAPI): 21 | def setup(self, device): 22 | self.model = lambda x: x**2 23 | 24 | def decode_request(self, request: Request): 25 | return request["input"] 26 | 27 | def predict(self, x): 28 | return self.model(x) 29 | 30 | def encode_response(self, output) -> Response: 31 | return {"output": output} 32 | 33 | 34 | if __name__ == "__main__": 35 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=10) 36 | server.run() 37 | -------------------------------------------------------------------------------- /tests/simple_server_diff_port.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from litserve.api import LitAPI 15 | from litserve.server import LitServer 16 | 17 | 18 | class SimpleLitAPI(LitAPI): 19 | def setup(self, device): 20 | self.model = lambda x: x**2 21 | 22 | def decode_request(self, request): 23 | return request["input"] 24 | 25 | def predict(self, x): 26 | return self.model(x) 27 | 28 | def encode_response(self, output): 29 | return {"output": output} 30 | 31 | 32 | if __name__ == "__main__": 33 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=10) 34 | server.run(port=8080) 35 | -------------------------------------------------------------------------------- /tests/test_auth.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from fastapi import Depends, HTTPException, Request, Response 16 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer 17 | from fastapi.testclient import TestClient 18 | 19 | import litserve.server 20 | from litserve import LitAPI, LitServer 21 | from litserve.utils import wrap_litserve_start 22 | 23 | 24 | class SimpleAuthedLitAPI(LitAPI): 25 | def setup(self, device): 26 | self.model = lambda x: x**2 27 | 28 | def decode_request(self, request: Request): 29 | return request["input"] 30 | 31 | def predict(self, x): 32 | return self.model(x) 33 | 34 | def encode_response(self, output) -> Response: 35 | return {"output": output} 36 | 37 | def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())): 38 | if auth.scheme != "Bearer" or auth.credentials != "1234": 39 | raise HTTPException(status_code=401, detail="Bad token") 40 | 41 | 42 | def test_authorized_custom(): 43 | server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) 44 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 45 | input = {"input": 4.0} 46 | response = client.post("/predict", headers={"Authorization": "Bearer 1234"}, json=input) 47 | assert response.status_code == 200 48 | 49 | 50 | def test_not_authorized_custom(): 51 | server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) 52 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 53 | input = {"input": 4.0} 54 | response = client.post("/predict", headers={"Authorization": "Bearer wrong"}, json=input) 55 | assert response.status_code == 401 56 | 57 | 58 | class SimpleLitAPI(LitAPI): 59 | def setup(self, device): 60 | self.model = lambda x: x**2 61 | 62 | def decode_request(self, request: Request): 63 | return request["input"] 64 | 65 | def predict(self, x): 66 | return self.model(x) 67 | 68 | def encode_response(self, output) -> Response: 69 | return {"output": output} 70 | 71 | 72 | def test_authorized_api_key(): 73 | litserve.server.LIT_SERVER_API_KEY = "abcd" 74 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) 75 | 76 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 77 | input = {"input": 4.0} 78 | response = client.post("/predict", headers={"X-API-Key": "abcd"}, json=input) 79 | assert response.status_code == 200 80 | 81 | litserve.server.LIT_SERVER_API_KEY = None 82 | 83 | 84 | def test_not_authorized_api_key(): 85 | litserve.server.LIT_SERVER_API_KEY = "abcd" 86 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) 87 | 88 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 89 | input = {"input": 4.0} 90 | response = client.post("/predict", headers={"X-API-Key": "wrong"}, json=input) 91 | assert response.status_code == 401 92 | 93 | litserve.server.LIT_SERVER_API_KEY = None 94 | -------------------------------------------------------------------------------- /tests/test_callbacks.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import re 3 | import time 4 | 5 | import pytest 6 | from asgi_lifespan import LifespanManager 7 | from fastapi.testclient import TestClient 8 | from httpx import ASGITransport, AsyncClient 9 | 10 | import litserve as ls 11 | from litserve.callbacks import CallbackRunner, EventTypes 12 | from litserve.callbacks.defaults import PredictionTimeLogger 13 | from 
litserve.callbacks.defaults.metric_callback import RequestTracker 14 | from litserve.utils import wrap_litserve_start 15 | 16 | 17 | async def run_simple_request(server, num_requests=1): 18 | with wrap_litserve_start(server) as server: 19 | async with LifespanManager(server.app) as manager, AsyncClient( 20 | transport=ASGITransport(app=manager.app), base_url="http://test" 21 | ) as ac: 22 | responses = [ac.post("/predict", json={"input": 4.0}) for _ in range(num_requests)] 23 | responses = await asyncio.gather(*responses) 24 | for response in responses: 25 | assert response.json() == {"output": 16.0}, "Unexpected response" 26 | 27 | 28 | class SlowAPI(ls.test_examples.SimpleLitAPI): 29 | def predict(self, x): 30 | time.sleep(1) 31 | return super().predict(x) 32 | 33 | 34 | def test_callback_runner(): 35 | cb_runner = CallbackRunner() 36 | assert cb_runner._callbacks == [], "Callbacks list must be empty" 37 | 38 | cb = PredictionTimeLogger() 39 | cb_runner._add_callbacks(cb) 40 | assert cb_runner._callbacks == [cb], "Callback not added to runner" 41 | 42 | 43 | def test_callback(capfd): 44 | lit_api = ls.test_examples.SimpleLitAPI() 45 | server = ls.LitServer(lit_api, callbacks=[PredictionTimeLogger()]) 46 | 47 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 48 | response = client.post("/predict", json={"input": 4.0}) 49 | assert response.json() == {"output": 16.0} 50 | 51 | captured = capfd.readouterr() 52 | pattern = r"Prediction took \d+\.\d{2} seconds" 53 | assert re.search(pattern, captured.out), f"Expected pattern not found in output: {captured.out}" 54 | 55 | 56 | def test_metric_logger(capfd): 57 | cb = PredictionTimeLogger() 58 | cb_runner = CallbackRunner() 59 | cb_runner._add_callbacks(cb) 60 | assert cb_runner._callbacks == [cb], "Callback not added to runner" 61 | cb_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=None) 62 | cb_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=None) 63 | 64 | captured = capfd.readouterr() 65 | pattern = r"Prediction took \d+\.\d{2} seconds" 66 | assert re.search(pattern, captured.out), f"Expected pattern not found in output: {captured.out}" 67 | 68 | 69 | @pytest.mark.asyncio 70 | async def test_request_tracker(capfd): 71 | lit_api = SlowAPI() 72 | 73 | server = ls.LitServer(lit_api, track_requests=False, callbacks=[RequestTracker()]) 74 | await run_simple_request(server, 1) 75 | captured = capfd.readouterr() 76 | assert "Active requests: None" in captured.out, f"Expected pattern not found in output: {captured.out}" 77 | 78 | server = ls.LitServer(lit_api, track_requests=True, callbacks=[RequestTracker()]) 79 | await run_simple_request(server, 4) 80 | captured = capfd.readouterr() 81 | assert "Active requests: 4" in captured.out, f"Expected pattern not found in output: {captured.out}" 82 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | 7 | from litserve.__main__ import main 8 | from litserve.cli import _ensure_lightning_installed 9 | from litserve.cli import main as cli_main 10 | 11 | 12 | def test_dockerize_help(monkeypatch, capsys): 13 | monkeypatch.setattr("sys.argv", ["litserve", "dockerize", "--help"]) 14 | # argparse calls sys.exit() after printing help 15 | with pytest.raises(SystemExit): 16 | main() 17 | captured = capsys.readouterr() 18 | assert 
"usage:" in captured.out, "CLI did not print help message" 19 | assert "The path to the server file." in captured.out, "CLI did not print help message" 20 | 21 | 22 | def test_dockerize_command(monkeypatch, capsys): 23 | # Assuming you have a dummy server file for testing 24 | dummy_server_file = "dummy_server.py" 25 | with open(dummy_server_file, "w") as f: 26 | f.write("# Dummy server file for testing\n") 27 | 28 | monkeypatch.setattr("sys.argv", ["litserve", "dockerize", dummy_server_file]) 29 | main() 30 | captured = capsys.readouterr() 31 | os.remove(dummy_server_file) 32 | assert "Dockerfile created successfully" in captured.out, "CLI did not create Dockerfile" 33 | assert os.path.exists("Dockerfile"), "CLI did not create Dockerfile" 34 | 35 | 36 | @patch("importlib.util.find_spec") 37 | @patch("subprocess.check_call") 38 | def test_ensure_lightning_installed(mock_check_call, mock_find_spec): 39 | mock_find_spec.return_value = False 40 | _ensure_lightning_installed() 41 | mock_check_call.assert_called_once_with([sys.executable, "-m", "pip", "install", "-U", "lightning-sdk"]) 42 | 43 | 44 | # TODO: Remove this once we have a fix for Python 3.9 and 3.10 45 | @pytest.mark.skipif(sys.version_info[:2] in [(3, 9), (3, 10)], reason="Test fails on Python 3.9 and 3.10") 46 | @patch("importlib.util.find_spec") 47 | @patch("subprocess.check_call") 48 | @patch("builtins.__import__") 49 | def test_cli_main_lightning_not_installed(mock_import, mock_check_call, mock_find_spec): 50 | # Create a mock for the lightning_sdk module and its components 51 | mock_lightning_sdk = MagicMock() 52 | mock_lightning_sdk.cli.entrypoint.main_cli = MagicMock() 53 | 54 | # Configure __import__ to return our mock when lightning_sdk is imported 55 | def side_effect(name, *args, **kwargs): 56 | if name == "lightning_sdk.cli.entrypoint": 57 | return mock_lightning_sdk 58 | return __import__(name, *args, **kwargs) 59 | 60 | mock_import.side_effect = side_effect 61 | 62 | # Test when lightning_sdk is not installed but gets installed dynamically 63 | mock_find_spec.side_effect = [False, True] # First call returns False, second call returns True 64 | test_args = ["lightning", "run", "app", "app.py"] 65 | 66 | with patch.object(sys, "argv", test_args): 67 | cli_main() 68 | 69 | mock_check_call.assert_called_once_with([sys.executable, "-m", "pip", "install", "-U", "lightning-sdk"]) 70 | 71 | 72 | @pytest.mark.skipif(sys.version_info[:2] in [(3, 9), (3, 10)], reason="Test fails on Python 3.9 and 3.10") 73 | @patch("importlib.util.find_spec") 74 | @patch("builtins.__import__") 75 | def test_cli_main_import_error(mock_import, mock_find_spec, capsys): 76 | # Set up the mock to raise ImportError specifically for lightning_sdk import 77 | def import_mock(name, *args, **kwargs): 78 | if name == "lightning_sdk.cli.entrypoint": 79 | raise ImportError("Module not found") 80 | return __import__(name, *args, **kwargs) 81 | 82 | mock_import.side_effect = import_mock 83 | 84 | # Mock find_spec to return True so we attempt the import 85 | mock_find_spec.return_value = True 86 | test_args = ["lightning", "deploy", "api", "app.py"] 87 | 88 | with patch.object(sys, "argv", test_args): # noqa: SIM117 89 | with pytest.raises(SystemExit) as excinfo: 90 | cli_main() 91 | 92 | assert excinfo.value.code == 1 93 | captured = capsys.readouterr() 94 | assert "Error importing lightning_sdk CLI" in captured.out 95 | -------------------------------------------------------------------------------- /tests/test_compression.py: 
-------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from fastapi import Request, Response 16 | from fastapi.testclient import TestClient 17 | 18 | from litserve import LitAPI, LitServer 19 | from litserve.utils import wrap_litserve_start 20 | 21 | # trivially compressible content 22 | test_output = {"result": "0" * 100000} 23 | 24 | 25 | class LargeOutputLitAPI(LitAPI): 26 | def setup(self, device): 27 | pass 28 | 29 | def decode_request(self, request: Request): 30 | pass 31 | 32 | def predict(self, x): 33 | pass 34 | 35 | def encode_response(self, output) -> Response: 36 | return test_output 37 | 38 | 39 | def test_compression(): 40 | server = LitServer(LargeOutputLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) 41 | 42 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 43 | # compressed 44 | response = client.post("/predict", headers={"Accept-Encoding": "gzip"}, json={}) 45 | assert response.status_code == 200 46 | assert response.headers["Content-Encoding"] == "gzip" 47 | content_length = int(response.headers["Content-Length"]) 48 | assert 0 < content_length < 100000 49 | assert response.json() == test_output 50 | 51 | # uncompressed 52 | response = client.post("/predict", headers={"Accept-Encoding": ""}, json={}) 53 | assert response.status_code == 200 54 | assert "Content-Encoding" not in response.headers 55 | content_length = int(response.headers["Content-Length"]) 56 | assert content_length > 100000 57 | assert response.json() == test_output 58 | -------------------------------------------------------------------------------- /tests/test_connector.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
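# A client-side counterpart to the compression test above (test_compression.py): when the server
# honors Accept-Encoding, requests transparently decompresses the gzip body, so only the response
# headers reveal that compression happened. The URL and empty payload mirror the test and are
# illustrative assumptions.
def _gzip_response_check(url: str = "http://127.0.0.1:8000/predict") -> None:
    import requests

    resp = requests.post(url, json={}, headers={"Accept-Encoding": "gzip"}, timeout=10)
    resp.raise_for_status()
    print(resp.headers.get("Content-Encoding"), len(resp.content))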
14 | from unittest.mock import patch 15 | 16 | import pytest 17 | import torch 18 | 19 | from litserve.connector import _Connector, check_cuda_with_nvidia_smi 20 | 21 | 22 | @pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU") 23 | def test_check_cuda_with_nvidia_smi(): 24 | assert check_cuda_with_nvidia_smi() == torch.cuda.device_count() 25 | 26 | 27 | @pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Non Nvidia GPU only") 28 | @patch( 29 | "litserve.connector.subprocess.check_output", 30 | return_value=b"GPU 0: NVIDIA GeForce RTX 4090 (UUID: GPU-rb438fre-0ar-9702-de35-ref4rjn34omk3 )", 31 | ) 32 | def test_check_cuda_with_nvidia_smi_mock_gpu(mock_subprocess): 33 | check_cuda_with_nvidia_smi.cache_clear() 34 | assert check_cuda_with_nvidia_smi() == 1 35 | check_cuda_with_nvidia_smi.cache_clear() 36 | 37 | 38 | @pytest.mark.parametrize( 39 | ("input_accelerator", "expected_accelerator", "expected_devices"), 40 | [ 41 | ("cpu", "cpu", 1), 42 | pytest.param( 43 | "cuda", 44 | "cuda", 45 | torch.cuda.device_count(), 46 | marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"), 47 | ), 48 | pytest.param( 49 | "gpu", 50 | "cuda", 51 | torch.cuda.device_count(), 52 | marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"), 53 | ), 54 | pytest.param( 55 | None, 56 | "cuda", 57 | torch.cuda.device_count(), 58 | marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"), 59 | ), 60 | pytest.param( 61 | "auto", 62 | "cuda", 63 | torch.cuda.device_count(), 64 | marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"), 65 | ), 66 | pytest.param( 67 | "auto", 68 | "mps", 69 | 1, 70 | marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"), 71 | ), 72 | pytest.param( 73 | "gpu", 74 | "mps", 75 | 1, 76 | marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"), 77 | ), 78 | pytest.param( 79 | "mps", 80 | "mps", 81 | 1, 82 | marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"), 83 | ), 84 | pytest.param( 85 | None, 86 | "mps", 87 | 1, 88 | marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"), 89 | ), 90 | ], 91 | ) 92 | def test_connector(input_accelerator, expected_accelerator, expected_devices): 93 | check_cuda_with_nvidia_smi.cache_clear() 94 | connector = _Connector(accelerator=input_accelerator) 95 | assert connector.accelerator == expected_accelerator, ( 96 | f"accelerator mismatch - expected: {expected_accelerator}, actual: {connector.accelerator}" 97 | ) 98 | 99 | assert connector.devices == expected_devices, ( 100 | f"devices mismatch - expected {expected_devices}, actual: {connector.devices}" 101 | ) 102 | 103 | with pytest.raises(ValueError, match="accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'"): 104 | _Connector(accelerator="SUPER_CHIP") 105 | 106 | 107 | def test__sanitize_accelerator(): 108 | assert _Connector._sanitize_accelerator(None) == "auto" 109 | assert _Connector._sanitize_accelerator("CPU") == "cpu" 110 | with pytest.raises(ValueError, match="accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'"): 111 | _Connector._sanitize_accelerator("SUPER_CHIP") 112 | -------------------------------------------------------------------------------- /tests/test_docker_builder.py: 
-------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import pytest 15 | 16 | import litserve as ls 17 | from litserve import docker_builder 18 | 19 | 20 | def test_color(): 21 | assert docker_builder.color("hi", docker_builder.RED) == f"{docker_builder.RED}hi{docker_builder.RESET}" 22 | 23 | expected = f"{docker_builder.INFO} {docker_builder.RED}hi{docker_builder.RESET}" 24 | assert docker_builder.color("hi", docker_builder.RED, docker_builder.INFO) == expected 25 | 26 | 27 | EXPECTED_CONENT = f"""ARG PYTHON_VERSION=3.12 28 | FROM python:$PYTHON_VERSION-slim 29 | 30 | ####### Add your own installation commands here ####### 31 | # RUN pip install some-package 32 | # RUN wget https://path/to/some/data/or/weights 33 | # RUN apt-get update && apt-get install -y 34 | 35 | WORKDIR /app 36 | COPY . /app 37 | 38 | # Install litserve and requirements 39 | RUN pip install --no-cache-dir litserve=={ls.__version__} -r requirements.txt 40 | EXPOSE 8000 41 | CMD ["python", "/app/app.py"] 42 | """ 43 | 44 | 45 | EXPECTED_GPU_DOCKERFILE = f"""# Change CUDA and cuDNN version here 46 | FROM nvidia/cuda:12.4.1-base-ubuntu22.04 47 | ARG PYTHON_VERSION=3.12 48 | 49 | ENV DEBIAN_FRONTEND=noninteractive 50 | RUN apt-get update && apt-get install -y --no-install-recommends \\ 51 | software-properties-common \\ 52 | wget \\ 53 | && add-apt-repository ppa:deadsnakes/ppa \\ 54 | && apt-get update && apt-get install -y --no-install-recommends \\ 55 | python$PYTHON_VERSION \\ 56 | python$PYTHON_VERSION-dev \\ 57 | python$PYTHON_VERSION-venv \\ 58 | && wget https://bootstrap.pypa.io/get-pip.py -O get-pip.py \\ 59 | && python$PYTHON_VERSION get-pip.py \\ 60 | && rm get-pip.py \\ 61 | && ln -sf /usr/bin/python$PYTHON_VERSION /usr/bin/python \\ 62 | && ln -sf /usr/local/bin/pip$PYTHON_VERSION /usr/local/bin/pip \\ 63 | && python --version \\ 64 | && pip --version \\ 65 | && apt-get purge -y --auto-remove software-properties-common \\ 66 | && apt-get clean \\ 67 | && rm -rf /var/lib/apt/lists/* 68 | 69 | ####### Add your own installation commands here ####### 70 | # RUN pip install some-package 71 | # RUN wget https://path/to/some/data/or/weights 72 | # RUN apt-get update && apt-get install -y 73 | 74 | WORKDIR /app 75 | COPY . 
/app 76 | 77 | # Install litserve and requirements 78 | RUN pip install --no-cache-dir litserve=={ls.__version__} -r requirements.txt 79 | EXPOSE 8000 80 | CMD ["python", "/app/app.py"] 81 | """ 82 | 83 | 84 | def test_dockerize(tmp_path, monkeypatch): 85 | with open(tmp_path / "app.py", "w") as f: 86 | f.write("print('hello')") 87 | 88 | # Temporarily change the current working directory to tmp_path 89 | monkeypatch.chdir(tmp_path) 90 | 91 | with pytest.warns(UserWarning, match="Make sure to install the required packages in the Dockerfile."): 92 | docker_builder.dockerize("app.py", 8000) 93 | 94 | with open(tmp_path / "requirements.txt", "w") as f: 95 | f.write("lightning") 96 | 97 | docker_builder.dockerize("app.py", 8000) 98 | with open("Dockerfile") as f: 99 | content = f.read() 100 | assert content == EXPECTED_CONENT 101 | 102 | docker_builder.dockerize("app.py", 8000, gpu=True) 103 | with open("Dockerfile") as f: 104 | content = f.read() 105 | assert content == EXPECTED_GPU_DOCKERFILE 106 | 107 | with pytest.raises(FileNotFoundError, match="must be in the current directory"): 108 | docker_builder.dockerize("random_file_name.py", 8000) 109 | -------------------------------------------------------------------------------- /tests/test_form.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
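# Typical workflow around the Dockerfile generation exercised above (test_docker_builder.py):
# generate a Dockerfile next to the server script, then build and run it with standard Docker
# commands. The image tag is an arbitrary illustrative name; flags for port/GPU selection vary by
# litserve version, so only the basic CLI form confirmed by test_cli.py is shown.
#   litserve dockerize app.py
#   docker build -t litserve-app .
#   docker run -p 8000:8000 litserve-app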
14 | 15 | from fastapi import Request, Response 16 | from fastapi.testclient import TestClient 17 | 18 | from litserve import LitAPI, LitServer 19 | from litserve.utils import wrap_litserve_start 20 | 21 | 22 | class SimpleFileLitAPI(LitAPI): 23 | def setup(self, device): 24 | self.model = lambda x: x**2 25 | 26 | def decode_request(self, request: Request): 27 | return len(request["input"].file.read().decode("utf-8")) 28 | 29 | def predict(self, x): 30 | return self.model(x) 31 | 32 | def encode_response(self, output) -> Response: 33 | return {"output": output} 34 | 35 | 36 | def test_multipart_form_data(tmp_path): 37 | file_length = 1024 * 1024 * 100 38 | 39 | server = LitServer( 40 | SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length * 2) 41 | ) 42 | 43 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 44 | file_path = f"{tmp_path}/big_file.txt" 45 | with open(file_path, "wb") as f: 46 | f.write(bytearray([1] * file_length)) 47 | with open(file_path, "rb") as f: 48 | file = {"input": f} 49 | response = client.post("/predict", files=file) 50 | assert response.json() == {"output": file_length**2} 51 | 52 | 53 | def test_file_too_big(tmp_path): 54 | file_length = 1024 * 1024 * 100 55 | 56 | server = LitServer( 57 | SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length / 2) 58 | ) 59 | 60 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 61 | file_path = f"{tmp_path}/big_file.txt" 62 | with open(file_path, "wb") as f: 63 | f.write(bytearray([1] * file_length)) 64 | with open(file_path, "rb") as f: 65 | file = {"input": f} 66 | response = client.post("/predict", files=file) 67 | assert response.status_code == 413 68 | 69 | # spoof content-length size 70 | response = client.post("/predict", files=file, headers={"content-length": "1024"}) 71 | assert response.status_code == 413 72 | 73 | 74 | class SimpleFormLitAPI(LitAPI): 75 | def setup(self, device): 76 | self.model = lambda x: x**2 77 | 78 | def decode_request(self, request: Request): 79 | return float(request["input"]) 80 | 81 | def predict(self, x): 82 | return self.model(x) 83 | 84 | def encode_response(self, output) -> Response: 85 | return {"output": output} 86 | 87 | 88 | def test_urlencoded_form_data(): 89 | server = LitServer(SimpleFormLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) 90 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 91 | file = {"input": "4.0"} 92 | response = client.post("/predict", data=file) 93 | assert response.json() == {"output": 16.0} 94 | -------------------------------------------------------------------------------- /tests/test_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import threading 15 | import time 16 | from unittest.mock import MagicMock 17 | 18 | import pytest 19 | from fastapi.testclient import TestClient 20 | 21 | import litserve as ls 22 | from litserve.loggers import Logger, _LoggerConnector 23 | from litserve.utils import wrap_litserve_start 24 | 25 | 26 | class TestLogger(Logger): 27 | def process(self, key, value): 28 | self.processed_data = (key, value) 29 | 30 | 31 | @pytest.fixture 32 | def mock_lit_server(): 33 | mock_server = MagicMock() 34 | mock_server.log_queue.get = MagicMock(return_value=("test_key", "test_value")) 35 | return mock_server 36 | 37 | 38 | @pytest.fixture 39 | def test_logger(): 40 | return TestLogger() 41 | 42 | 43 | @pytest.fixture 44 | def logger_connector(mock_lit_server, test_logger): 45 | return _LoggerConnector(mock_lit_server, [test_logger]) 46 | 47 | 48 | def test_logger_mount(test_logger): 49 | mock_app = MagicMock() 50 | test_logger.mount("/test", mock_app) 51 | assert test_logger._config["mount"]["path"] == "/test" 52 | assert test_logger._config["mount"]["app"] == mock_app 53 | 54 | 55 | def test_connector_add_logger(logger_connector): 56 | new_logger = TestLogger() 57 | logger_connector.add_logger(new_logger) 58 | assert new_logger in logger_connector._loggers 59 | 60 | 61 | def test_connector_mount(mock_lit_server, test_logger, logger_connector): 62 | mock_app = MagicMock() 63 | test_logger.mount("/test", mock_app) 64 | logger_connector.add_logger(test_logger) 65 | mock_lit_server.app.mount.assert_called_with("/test", mock_app) 66 | 67 | 68 | def test_invalid_loggers(): 69 | _LoggerConnector(None, TestLogger()) 70 | with pytest.raises(ValueError, match="Logger must be an instance of litserve.Logger"): 71 | _ = _LoggerConnector(None, [MagicMock()]) 72 | 73 | with pytest.raises(ValueError, match="loggers must be a list or an instance of litserve.Logger"): 74 | _ = _LoggerConnector(None, MagicMock()) 75 | 76 | 77 | class LoggerAPI(ls.test_examples.SimpleLitAPI): 78 | def predict(self, input): 79 | result = super().predict(input) 80 | for i in range(1, 5): 81 | self.log("time", i * 0.1) 82 | return result 83 | 84 | 85 | def test_server_wo_logger(): 86 | api = LoggerAPI() 87 | server = ls.LitServer(api) 88 | 89 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 90 | response = client.post("/predict", json={"input": 4.0}) 91 | assert response.json() == {"output": 16.0} 92 | 93 | 94 | class FileLogger(ls.Logger): 95 | def __init__(self, path="test_logger_temp.txt"): 96 | super().__init__() 97 | self.path = path 98 | 99 | def process(self, key, value): 100 | with open(self.path, "a+") as f: 101 | f.write(f"{key}: {value:.1f}\n") 102 | 103 | 104 | def test_logger_with_api(tmpdir): 105 | path = str(tmpdir / "test_logger_temp.txt") 106 | api = LoggerAPI() 107 | server = ls.LitServer(api, loggers=[FileLogger(path)]) 108 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 109 | response = client.post("/predict", json={"input": 4.0}) 110 | assert response.json() == {"output": 16.0} 111 | # Wait for FileLogger to write to file 112 | time.sleep(0.5) 113 | with open(path) as f: 114 | data = f.readlines() 115 | assert data == [ 116 | "time: 0.1\n", 117 | "time: 0.2\n", 118 | "time: 0.3\n", 119 | "time: 0.4\n", 120 | ], f"Expected metric not found in logger file {data}" 121 | 122 | 123 | class PredictionTimeLogger(ls.Callback): 124 | def on_after_predict(self, lit_api): 125 | for i in range(1, 5): 126 | lit_api.log("time", i * 0.1) 127 | 128 | 129 | def 
test_logger_with_callback(tmp_path): 130 | path = str(tmp_path / "test_logger_temp.txt") 131 | api = ls.test_examples.SimpleLitAPI() 132 | server = ls.LitServer(api, loggers=[FileLogger(path)], callbacks=[PredictionTimeLogger()]) 133 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 134 | response = client.post("/predict", json={"input": 4.0}) 135 | assert response.json() == {"output": 16.0} 136 | # Wait for FileLogger to write to file 137 | time.sleep(0.5) 138 | with open(path) as f: 139 | data = f.readlines() 140 | assert data == [ 141 | "time: 0.1\n", 142 | "time: 0.2\n", 143 | "time: 0.3\n", 144 | "time: 0.4\n", 145 | ], f"Expected metric not found in logger file {data}" 146 | 147 | 148 | class NonPickleableLogger(ls.Logger): 149 | # This is a logger that contains a non-picklable resource 150 | def __init__(self, *args, **kwargs): 151 | super().__init__(*args, **kwargs) 152 | self._lock = threading.Lock() # Non-picklable resource 153 | 154 | def process(self, key, value): 155 | with self._lock: 156 | print(f"Logged {key}: {value}", flush=True) 157 | 158 | 159 | class PickleTestAPI(ls.test_examples.SimpleLitAPI): 160 | def predict(self, x): 161 | self.log("my-key", x) 162 | return super().predict(x) 163 | 164 | 165 | def test_pickle_safety(capfd): 166 | api = PickleTestAPI() 167 | server = ls.LitServer(api, loggers=NonPickleableLogger()) 168 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 169 | response = client.post("/predict", json={"input": 4.0}) 170 | assert response.json() == {"output": 16.0} 171 | time.sleep(0.5) 172 | captured = capfd.readouterr() 173 | assert "Logged my-key: 4.0" in captured.out, f"Expected log not found in captured output {captured}" 174 | -------------------------------------------------------------------------------- /tests/test_logging.py: -------------------------------------------------------------------------------- 1 | import io 2 | import logging 3 | 4 | import pytest 5 | 6 | from litserve.utils import add_log_handler, configure_logging, set_log_level 7 | 8 | 9 | @pytest.fixture 10 | def log_stream(): 11 | return io.StringIO() 12 | 13 | 14 | def test_configure_logging(log_stream): 15 | # Configure logging with test stream 16 | configure_logging(level=logging.DEBUG, stream=log_stream) 17 | 18 | # Get logger and log a test message 19 | logger = logging.getLogger("litserve") 20 | test_message = "Test debug message" 21 | logger.debug(test_message) 22 | 23 | # Verify log output 24 | log_contents = log_stream.getvalue() 25 | assert test_message in log_contents 26 | assert "DEBUG" in log_contents 27 | assert logger.propagate is False 28 | 29 | 30 | def test_set_log_level(): 31 | # Set log level to WARNING 32 | set_log_level(logging.WARNING) 33 | 34 | # Verify logger level 35 | logger = logging.getLogger("litserve") 36 | assert logger.level == logging.WARNING 37 | 38 | 39 | def test_add_log_handler(): 40 | # Create and add a custom handler 41 | stream = io.StringIO() 42 | custom_handler = logging.StreamHandler(stream) 43 | add_log_handler(custom_handler) 44 | 45 | # Verify handler is added 46 | logger = logging.getLogger("litserve") 47 | assert custom_handler in logger.handlers 48 | 49 | # Test the handler works 50 | test_message = "Test handler message" 51 | logger.info(test_message) 52 | assert test_message in stream.getvalue() 53 | 54 | 55 | @pytest.fixture(autouse=True) 56 | def cleanup_logger(): 57 | yield 58 | logger = logging.getLogger("litserve") 59 | logger.handlers.clear() 60 | 
logger.setLevel(logging.INFO) 61 | -------------------------------------------------------------------------------- /tests/test_middlewares.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import pytest 15 | from fastapi.testclient import TestClient 16 | from starlette.middleware.base import BaseHTTPMiddleware 17 | from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware 18 | from starlette.middleware.trustedhost import TrustedHostMiddleware 19 | from starlette.types import ASGIApp 20 | 21 | import litserve as ls 22 | from litserve.utils import wrap_litserve_start 23 | 24 | 25 | class RequestIdMiddleware(BaseHTTPMiddleware): 26 | def __init__(self, app: ASGIApp, length: int) -> None: 27 | self.app = app 28 | self.length = length 29 | super().__init__(app) 30 | 31 | async def dispatch(self, request, call_next): 32 | response = await call_next(request) 33 | response.headers["X-Request-Id"] = "0" * self.length 34 | return response 35 | 36 | 37 | def test_custom_middleware(): 38 | server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=[(RequestIdMiddleware, {"length": 5})]) 39 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 40 | response = client.post("/predict", json={"input": 4.0}) 41 | assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}" 42 | assert response.json() == {"output": 16.0}, "server didn't return expected output" 43 | assert response.headers["X-Request-Id"] == "00000" 44 | 45 | 46 | def test_starlette_middlewares(): 47 | middlewares = [ 48 | ( 49 | TrustedHostMiddleware, 50 | { 51 | "allowed_hosts": ["localhost", "127.0.0.1"], 52 | }, 53 | ), 54 | HTTPSRedirectMiddleware, 55 | ] 56 | server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=middlewares) 57 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 58 | response = client.post("/predict", json={"input": 4.0}, headers={"Host": "localhost"}) 59 | assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}" 60 | assert response.json() == {"output": 16.0}, "server didn't return expected output" 61 | 62 | response = client.post("/predict", json={"input": 4.0}, headers={"Host": "not-trusted-host"}) 63 | assert response.status_code == 400, f"Expected response to be 400 but got {response.status_code}" 64 | 65 | 66 | def test_middlewares_inputs(): 67 | server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=[]) 68 | assert len(server.middlewares) == 1, "Default middleware should be present" 69 | 70 | server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=[], max_payload_size=1000) 71 | assert len(server.middlewares) == 2, "Default middleware should be present" 72 | 73 | server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=None) 74 | assert 
len(server.middlewares) == 1, "Default middleware should be present" 75 | 76 | with pytest.raises(ValueError, match="middlewares must be a list of tuples"): 77 | ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5})) 78 | -------------------------------------------------------------------------------- /tests/test_multiple_endpoints.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from asgi_lifespan import LifespanManager 3 | from httpx import ASGITransport, AsyncClient 4 | 5 | import litserve as ls 6 | from litserve.utils import wrap_litserve_start 7 | 8 | 9 | class InferencePipeline(ls.LitAPI): 10 | def __init__(self, name=None, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.name = name 13 | 14 | def setup(self, device): 15 | self.model = lambda x: x**2 16 | 17 | def decode_request(self, request): 18 | return request["input"] 19 | 20 | def predict(self, x): 21 | return self.model(x) 22 | 23 | def encode_response(self, output): 24 | return {"output": output, "name": self.name} 25 | 26 | 27 | @pytest.mark.asyncio 28 | async def test_multiple_endpoints(): 29 | api1 = InferencePipeline(name="api1", api_path="/api1") 30 | api2 = InferencePipeline(name="api2", api_path="/api2") 31 | server = ls.LitServer([api1, api2]) 32 | 33 | with wrap_litserve_start(server) as server: 34 | async with LifespanManager(server.app) as manager, AsyncClient( 35 | transport=ASGITransport(app=manager.app), base_url="http://test" 36 | ) as ac: 37 | resp = await ac.post("/api1", json={"input": 2.0}, timeout=10) 38 | assert resp.status_code == 200, "Server response should be 200 (OK)" 39 | assert resp.json()["output"] == 4.0, "output from Identity server must be same as input" 40 | assert resp.json()["name"] == "api1", "name from Identity server must be same as input" 41 | 42 | resp = await ac.post("/api2", json={"input": 5.0}, timeout=10) 43 | assert resp.status_code == 200, "Server response should be 200 (OK)" 44 | assert resp.json()["output"] == 25.0, "output from Identity server must be same as input" 45 | assert resp.json()["name"] == "api2", "name from Identity server must be same as input" 46 | 47 | 48 | def test_multiple_endpoints_with_same_path(): 49 | api1 = InferencePipeline(name="api1", api_path="/api1") 50 | api2 = InferencePipeline(name="api2", api_path="/api1") 51 | with pytest.raises(ValueError, match="api_path /api1 is already in use by"): 52 | ls.LitServer([api1, api2]) 53 | 54 | 55 | def test_reserved_paths(): 56 | api1 = InferencePipeline(name="api1", api_path="/health") 57 | api2 = InferencePipeline(name="api2", api_path="/info") 58 | with pytest.raises(ValueError, match="api_path /health is already in use by LitServe healthcheck"): 59 | ls.LitServer([api1, api2]) 60 | -------------------------------------------------------------------------------- /tests/test_pydantic.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from fastapi.testclient import TestClient 15 | from pydantic import BaseModel 16 | 17 | from litserve import LitAPI, LitServer 18 | from litserve.utils import wrap_litserve_start 19 | 20 | 21 | class PredictRequest(BaseModel): 22 | input: float 23 | 24 | 25 | class PredictResponse(BaseModel): 26 | output: float 27 | 28 | 29 | class SimpleLitAPI(LitAPI): 30 | def setup(self, device): 31 | self.model = lambda x: x**2 32 | 33 | def decode_request(self, request: PredictRequest) -> float: 34 | return request.input 35 | 36 | def predict(self, x): 37 | return self.model(x) 38 | 39 | def encode_response(self, output: float) -> PredictResponse: 40 | return PredictResponse(output=output) 41 | 42 | 43 | def test_pydantic(): 44 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=5) 45 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 46 | response = client.post("/predict", json={"input": 4.0}) 47 | assert response.json() == {"output": 16.0} 48 | -------------------------------------------------------------------------------- /tests/test_readme.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # Code extraction adapted from https://github.com/tassaron/get_code_from_markdown 15 | import re 16 | import selectors 17 | import subprocess 18 | import sys 19 | import time 20 | from typing import List 21 | 22 | import pytest 23 | from tqdm import tqdm 24 | 25 | uvicorn_msg = "Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)" 26 | 27 | 28 | def extract_code_blocks(lines: List[str]) -> List[str]: 29 | language = "python" 30 | regex = re.compile( 31 | r"(?P<start>^```(?P<block_language>(\w|-)+)\n)(?P<code>.*?\n)(?P<end>```)", 32 | re.DOTALL | re.MULTILINE, 33 | ) 34 | blocks = [(match.group("block_language"), match.group("code")) for match in regex.finditer("".join(lines))] 35 | return [block for block_language, block in blocks if block_language == language] 36 | 37 | 38 | def get_code_blocks(file: str) -> List[str]: 39 | with open(file) as f: 40 | lines = list(f) 41 | return extract_code_blocks(lines) 42 | 43 | 44 | def get_extra_time(content: str) -> int: 45 | if "torch" in content or "transformers" in content: 46 | return 5 47 | 48 | return 0 49 | 50 | 51 | def run_script_with_timeout(file, timeout, extra_time, killall): 52 | sel = selectors.DefaultSelector() 53 | try: 54 | process = subprocess.Popen( 55 | ["python", str(file)], 56 | stdout=subprocess.PIPE, 57 | stderr=subprocess.PIPE, 58 | bufsize=1, # Line-buffered 59 | universal_newlines=True, # Decode bytes to string 60 | ) 61 | 62 | stdout_lines = [] 63 | stderr_lines = [] 64 | end_time = time.time() + timeout + extra_time 65 | 66 | sel.register(process.stdout, selectors.EVENT_READ) 67 | sel.register(process.stderr, selectors.EVENT_READ) 68 | 69 | while True: 70 | timeout_remaining = end_time - time.time() 71 | if timeout_remaining <= 0: 72 | killall(process) 73 | break 74 | 75 | events = sel.select(timeout=timeout_remaining) 76 | for key, _ in events: 77 | if key.fileobj is process.stdout: 78 | line = process.stdout.readline() 79 | if line: 80 | stdout_lines.append(line) 81 | elif key.fileobj is process.stderr: 82 | line = process.stderr.readline() 83 | if line: 84 | stderr_lines.append(line) 85 | 86 | if process.poll() is not None: 87 | break 88 | 89 | output = "".join(stdout_lines) 90 | errors = "".join(stderr_lines) 91 | 92 | # Get the return code of the process 93 | returncode = process.returncode 94 | 95 | except Exception as e: 96 | output = "" 97 | errors = str(e) 98 | returncode = -1 # Indicate failure in running the process 99 | 100 | return returncode, output, errors 101 | 102 | 103 | @pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows CI is slow and this test is just a sanity check.") 104 | def test_readme(tmp_path, killall): 105 | d = tmp_path / "readme_codes" 106 | d.mkdir(exist_ok=True) 107 | code_blocks = get_code_blocks("README.md") 108 | assert len(code_blocks) > 0, "No code block found in README.md" 109 | 110 | for i, code in enumerate(tqdm(code_blocks)): 111 | file = d / f"{i}.py" 112 | file.write_text(code) 113 | extra_time = get_extra_time(code) 114 | 115 | returncode, stdout, stderr = run_script_with_timeout(file, timeout=5, extra_time=extra_time, killall=killall) 116 | 117 | if "server.run" in code: 118 | assert uvicorn_msg in stderr, f"Expected to run uvicorn server.\nCode:\n {code}\n\nCode output: {stderr}" 119 | elif "requests.post" in code: 120 | assert "ConnectionError" in stderr, ( 121 | f"Client examples should fail with a ConnectionError because there is no server running.\nCode:\n{code}" 122 | ) 123 | else: 124 | assert returncode == 0, ( 125 | f"Code exited with {returncode}.\n" 126 | f"Error: 
{stderr}\n" 127 | f"Please check the code for correctness:\n```\n{code}\n```" 128 | ) 129 | -------------------------------------------------------------------------------- /tests/test_request_handlers.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import json 15 | from queue import Queue 16 | from unittest import mock 17 | from unittest.mock import AsyncMock, MagicMock, patch 18 | 19 | import pytest 20 | from fastapi import Request 21 | 22 | from litserve.server import BaseRequestHandler, RegularRequestHandler 23 | from litserve.test_examples import SimpleLitAPI 24 | from litserve.utils import LitAPIStatus 25 | 26 | 27 | @pytest.fixture 28 | def mock_lit_api(): 29 | return SimpleLitAPI() 30 | 31 | 32 | class MockServer: 33 | def __init__(self, lit_api): 34 | self.lit_api = lit_api 35 | self.response_buffer = {} 36 | self.request_queue = Queue() 37 | self._callback_runner = mock.MagicMock() 38 | self.app = mock.MagicMock() 39 | self.app.response_queue_id = 0 40 | self.active_requests = 0 41 | 42 | def _get_request_queue(self, api_path): 43 | return self.request_queue 44 | 45 | 46 | class MockRequest: 47 | """Mock FastAPI Request object for testing.""" 48 | 49 | def __init__(self, json_data=None, form_data=None, content_type="application/json"): 50 | self._json_data = json_data or {} 51 | self._form_data = form_data or {} 52 | self.headers = {"Content-Type": content_type} 53 | 54 | async def json(self): 55 | if self._json_data is None: 56 | raise json.JSONDecodeError("Invalid JSON", "", 0) 57 | return self._json_data 58 | 59 | async def form(self): 60 | return self._form_data 61 | 62 | 63 | class TestRequestHandler(BaseRequestHandler): 64 | def __init__(self, lit_api, server): 65 | super().__init__(lit_api, server) 66 | self.litapi_request_queues = {"/predict": Queue()} 67 | 68 | async def handle_request(self, request, request_type): 69 | payload = await self._prepare_request(request, request_type) 70 | uid, response_queue_id = await self._submit_request(payload) 71 | return response_queue_id 72 | 73 | 74 | @pytest.mark.asyncio 75 | async def test_request_handler(mock_lit_api): 76 | mock_server = MockServer(mock_lit_api) 77 | handler = TestRequestHandler(mock_lit_api, mock_server) 78 | mock_request = MockRequest() 79 | response_queue_id = await handler.handle_request(mock_request, Request) 80 | assert response_queue_id == 0 81 | 82 | 83 | @pytest.mark.asyncio 84 | @patch("litserve.server.asyncio.Event") 85 | async def test_request_handler_streaming(mock_event, mock_lit_api): 86 | mock_event.return_value = AsyncMock() 87 | mock_server = MockServer(mock_lit_api) 88 | mock_request = MockRequest() 89 | mock_server.response_buffer = MagicMock() 90 | mock_server.response_buffer.pop.return_value = ("test-response", LitAPIStatus.OK) 91 | handler = RegularRequestHandler(mock_lit_api, mock_server) 92 | response = await 
handler.handle_request(mock_request, Request) 93 | assert mock_server.request_queue.qsize() == 1 94 | assert response == "test-response" 95 | -------------------------------------------------------------------------------- /tests/test_schema.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import base64 15 | import io 16 | import os 17 | 18 | import numpy as np 19 | from fastapi.testclient import TestClient 20 | from PIL import Image 21 | 22 | import litserve as ls 23 | from litserve.schema.image import ImageInput, ImageOutput 24 | from litserve.utils import wrap_litserve_start 25 | 26 | 27 | class ImageAPI(ls.LitAPI): 28 | def setup(self, device): 29 | self.model = lambda x: np.array(x) * 2 30 | 31 | def decode_request(self, request: ImageInput): 32 | return request.get_image() 33 | 34 | def predict(self, x): 35 | return self.model(x) 36 | 37 | def encode_response(self, numpy_image) -> ImageOutput: 38 | output = Image.fromarray(np.uint8(numpy_image)).convert("RGB") 39 | return ImageOutput(image=output) 40 | 41 | 42 | def test_image_input_output(tmpdir): 43 | path = os.path.join(tmpdir, "test.png") 44 | server = ls.LitServer(ImageAPI(), accelerator="cpu", devices=1, workers_per_device=1) 45 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 46 | Image.new("RGB", (32, 32)).save(path) 47 | with open(path, "rb") as image_file: 48 | encoded_string = base64.b64encode(image_file.read()).decode("utf-8") 49 | response = client.post("/predict", json={"image_data": encoded_string}) 50 | 51 | assert response.status_code == 200, f"Unexpected status code: {response.status_code}" 52 | image_data = response.json()["image"] 53 | image = Image.open(io.BytesIO(base64.b64decode(image_data))) 54 | assert image.size == (32, 32), f"Unexpected image size: {image.size}" 55 | 56 | 57 | class MultiImageInputModel(ImageInput): 58 | image_0: str 59 | image_1: str 60 | image_2: str 61 | 62 | 63 | class MultiImageInputAPI(ImageAPI): 64 | def decode_request(self, request: MultiImageInputModel): 65 | images = [request.get_image(f"image_{i}") for i in range(3)] 66 | for image in images: 67 | assert isinstance(image, Image.Image) 68 | return images[0] 69 | 70 | 71 | def test_multiple_image_input(tmpdir): 72 | path = os.path.join(tmpdir, "test.png") 73 | server = ls.LitServer(MultiImageInputAPI(), accelerator="cpu", devices=1, workers_per_device=1) 74 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 75 | data = {} 76 | for i in range(3): 77 | Image.new("RGB", (32, 32)).save(path) 78 | with open(path, "rb") as image_file: 79 | encoded_string = base64.b64encode(image_file.read()).decode("utf-8") 80 | data[f"image_{i}"] = encoded_string 81 | response = client.post("/predict", json=data) 82 | 83 | assert response.status_code == 200, f"Unexpected status code: {response.status_code}" 84 | image_data = 
response.json()["image"] 85 | image = Image.open(io.BytesIO(base64.b64decode(image_data))) 86 | assert image.size == (32, 32), f"Unexpected image size: {image.size}" 87 | -------------------------------------------------------------------------------- /tests/test_torch.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import pytest 15 | import torch 16 | import torch.nn as nn 17 | from fastapi import Request, Response 18 | from fastapi.testclient import TestClient 19 | 20 | from litserve import LitAPI, LitServer 21 | from litserve.utils import wrap_litserve_start 22 | 23 | 24 | class Linear(nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | self.linear = nn.Linear(1, 1) 28 | self.linear.weight.data.fill_(2.0) 29 | self.linear.bias.data.fill_(1.0) 30 | 31 | def forward(self, x): 32 | return self.linear(x) 33 | 34 | 35 | class SimpleLitAPI(LitAPI): 36 | def setup(self, device): 37 | self.model = Linear().to(device) 38 | self.device = device 39 | 40 | def decode_request(self, request: Request): 41 | content = request["input"] 42 | return torch.tensor([content], device=self.device) 43 | 44 | def predict(self, x): 45 | return self.model(x[None, :]) 46 | 47 | def encode_response(self, output) -> Response: 48 | return {"output": float(output)} 49 | 50 | 51 | def test_torch(): 52 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=10) 53 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 54 | response = client.post("/predict", json={"input": 4.0}) 55 | assert response.json() == {"output": 9.0} 56 | 57 | 58 | @pytest.mark.skipif(torch.cuda.device_count() == 0, reason="requires CUDA to be available") 59 | def test_torch_gpu(): 60 | server = LitServer(SimpleLitAPI(), accelerator="cuda", devices=1, timeout=10) 61 | with wrap_litserve_start(server) as server, TestClient(server.app) as client: 62 | response = client.post("/predict", json={"input": 4.0}) 63 | assert response.json() == {"output": 9.0} 64 | -------------------------------------------------------------------------------- /tests/test_transport.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import multiprocessing as mp 3 | from queue import Empty 4 | from unittest.mock import MagicMock, patch 5 | 6 | import pytest 7 | 8 | from litserve.transport.factory import TransportConfig, create_transport_from_config 9 | from litserve.transport.process_transport import MPQueueTransport 10 | 11 | 12 | class TestMPQueueTransport: 13 | @pytest.fixture 14 | def manager(self): 15 | manager = mp.Manager() 16 | yield manager 17 | manager.shutdown() 18 | 19 | @pytest.fixture 20 | def queues(self, manager): 21 | return [manager.Queue() for _ in range(2)] 22 | 23 | @pytest.fixture 24 | def transport(self, manager, queues): 25 | return MPQueueTransport(manager, queues) 26 | 27 | 
def test_init(self, transport, queues): 28 | """Test that the transport initializes correctly.""" 29 | assert transport._queues == queues 30 | assert transport._closed is False 31 | 32 | def test_send(self, transport, queues): 33 | test_item = {"test": "data"} 34 | consumer_id = 0 35 | 36 | transport.send(test_item, consumer_id) 37 | 38 | assert queues[consumer_id].get() == test_item 39 | 40 | def test_send_when_closed(self, transport): 41 | transport._closed = True 42 | test_item = {"test": "data"} 43 | consumer_id = 0 44 | 45 | result = transport.send(test_item, consumer_id) 46 | 47 | assert result is None 48 | 49 | @pytest.mark.asyncio 50 | async def test_areceive(self, transport, queues): 51 | test_item = {"test": "data"} 52 | consumer_id = 0 53 | queues[consumer_id].put(test_item) 54 | 55 | result = await transport.areceive(consumer_id) 56 | 57 | assert result == test_item 58 | 59 | @pytest.mark.asyncio 60 | async def test_areceive_when_closed(self, transport): 61 | transport._closed = True 62 | consumer_id = 0 63 | 64 | with pytest.raises(asyncio.CancelledError, match="Transport closed"): 65 | await transport.areceive(consumer_id) 66 | 67 | @pytest.mark.asyncio 68 | async def test_areceive_timeout(self, transport): 69 | consumer_id = 0 70 | timeout = 0.1 71 | 72 | with pytest.raises(Empty): 73 | await transport.areceive(consumer_id, timeout=timeout) 74 | 75 | @pytest.mark.asyncio 76 | async def test_areceive_cancellation(self, transport): 77 | consumer_id = 0 78 | 79 | with patch("asyncio.to_thread", side_effect=asyncio.CancelledError), pytest.raises(asyncio.CancelledError): 80 | await transport.areceive(consumer_id) 81 | 82 | def test_close(self, transport, queues): 83 | transport.close() 84 | 85 | assert transport._closed is True 86 | for queue in queues: 87 | assert queue.get() is None 88 | 89 | def test_reduce(self, transport, queues): 90 | cls, args = transport.__reduce__() 91 | 92 | assert cls == MPQueueTransport 93 | assert args == (None, queues) 94 | 95 | 96 | class TestTransportFactory: 97 | @pytest.fixture 98 | def mock_manager(self): 99 | return MagicMock() 100 | 101 | def test_create_mp_transport(self, mock_manager): 102 | config_dict = {"transport_type": "mp", "num_consumers": 2} 103 | config = TransportConfig(**config_dict) 104 | 105 | with patch("litserve.transport.factory._create_mp_transport") as mock_create: 106 | mock_create.return_value = MPQueueTransport(mock_manager, [MagicMock(), MagicMock()]) 107 | 108 | transport = create_transport_from_config(config) 109 | 110 | assert isinstance(transport, MPQueueTransport) 111 | mock_create.assert_called_once() 112 | 113 | def test_create_transport_invalid_type(self): 114 | with patch("litserve.transport.factory.TransportConfig.model_validate") as mock_validate: 115 | mock_validate.return_value = MagicMock(transport_type="invalid") 116 | 117 | with pytest.raises(ValueError, match="Invalid transport type"): 118 | create_transport_from_config(mock_validate.return_value) 119 | 120 | 121 | @pytest.mark.integration 122 | class TestTransportIntegration: 123 | """Integration tests for the transport system.""" 124 | 125 | @pytest.fixture 126 | def mock_transport(self): 127 | transport = MagicMock() 128 | transport._closed = False 129 | transport._waiting_tasks = [] 130 | 131 | transport.send = MagicMock() 132 | 133 | async def mock_areceive(consumer_id, timeout=None, block=True): 134 | current_task = asyncio.current_task() 135 | transport._waiting_tasks.append(current_task) 136 | 137 | try: 138 | if transport._closed: 139 | raise 
asyncio.CancelledError("Transport closed") 140 | 141 | await asyncio.sleep(10) # Long sleep to ensure we'll be cancelled 142 | 143 | # This should only be reached if not cancelled 144 | return ("test_id", {"test": "data"}) 145 | finally: 146 | # Clean up task reference 147 | if current_task in transport._waiting_tasks: 148 | transport._waiting_tasks.remove(current_task) 149 | 150 | transport.areceive = mock_areceive 151 | 152 | def mock_close(): 153 | transport._closed = True 154 | for task in transport._waiting_tasks: 155 | task.cancel() 156 | 157 | transport.close = mock_close 158 | 159 | return transport 160 | 161 | @pytest.mark.asyncio 162 | async def test_send_receive_cycle(self, mock_transport): 163 | """Test a complete send-receive cycle.""" 164 | # Arrange 165 | test_item = ("test_id", {"test": "data"}) 166 | consumer_id = 0 167 | 168 | # Act - Send 169 | mock_transport.send(test_item, consumer_id) 170 | 171 | # Act - Receive 172 | result = await mock_transport.areceive(consumer_id) 173 | 174 | # Assert 175 | assert result == test_item 176 | mock_transport.send.assert_called_once_with(test_item, consumer_id) 177 | 178 | @pytest.mark.asyncio 179 | async def test_shutdown_sequence(self, mock_transport): 180 | """Test the shutdown sequence works correctly.""" 181 | # Arrange 182 | consumer_id = 0 183 | 184 | async def receive_task(): 185 | try: 186 | await mock_transport.areceive(consumer_id) 187 | return False # Should not reach here if cancelled 188 | except asyncio.CancelledError: 189 | return True # Successfully cancelled 190 | 191 | task = asyncio.create_task(receive_task()) 192 | await asyncio.sleep(0.1) 193 | 194 | # Act 195 | mock_transport.close() 196 | result = await asyncio.wait_for(task, timeout=2.0) 197 | 198 | # Assert 199 | assert result is True 200 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import sys 5 | from unittest import mock 6 | from unittest.mock import MagicMock 7 | 8 | import pytest 9 | from fastapi import HTTPException 10 | 11 | from litserve.utils import ( 12 | call_after_stream, 13 | configure_logging, 14 | dump_exception, 15 | generate_random_zmq_address, 16 | set_trace_if_debug, 17 | ) 18 | 19 | 20 | def test_dump_exception(): 21 | e1 = dump_exception(HTTPException(status_code=404, detail="Not Found")) 22 | assert isinstance(e1, bytes) 23 | 24 | exc = HTTPException(400, "Custom Lit error") 25 | isinstance(pickle.loads(dump_exception(exc)), HTTPException) 26 | assert pickle.loads(dump_exception(exc)).detail == "Custom Lit error" 27 | assert pickle.loads(dump_exception(exc)).status_code == 400 28 | 29 | 30 | async def dummy_streamer(): 31 | for i in range(10): 32 | yield i 33 | 34 | 35 | @pytest.mark.asyncio 36 | async def test_call_after_stream(): 37 | callback = MagicMock() 38 | callback.return_value = None 39 | streamer = dummy_streamer() 40 | async for _ in call_after_stream(streamer, callback, "first_arg", random_arg="second_arg"): 41 | pass 42 | callback.assert_called() 43 | callback.assert_called_with("first_arg", random_arg="second_arg") 44 | 45 | 46 | @pytest.mark.skipif(sys.platform == "win32", reason="This test is for non-Windows platforms only.") 47 | def test_generate_random_zmq_address_non_windows(tmpdir): 48 | """Test generate_random_zmq_address on non-Windows platforms.""" 49 | 50 | temp_dir = str(tmpdir) 51 | address1 = 
generate_random_zmq_address(temp_dir=temp_dir) 52 | address2 = generate_random_zmq_address(temp_dir=temp_dir) 53 | 54 | assert address1.startswith("ipc://"), "Address should start with 'ipc://'" 55 | assert address2.startswith("ipc://"), "Address should start with 'ipc://'" 56 | assert address1 != address2, "Addresses should be unique" 57 | 58 | # Verify the path exists within the specified temp_dir 59 | assert os.path.commonpath([temp_dir, address1[6:]]) == temp_dir 60 | assert os.path.commonpath([temp_dir, address2[6:]]) == temp_dir 61 | 62 | 63 | def test_configure_logging(): 64 | configure_logging(use_rich=False) 65 | assert logging.getLogger("litserve").handlers[0].__class__.__name__ == "StreamHandler" 66 | 67 | 68 | def test_configure_logging_rich_not_installed(): 69 | # patch builtins.__import__ to raise ImportError 70 | with mock.patch("builtins.__import__", side_effect=ImportError): 71 | configure_logging(use_rich=True) 72 | assert logging.getLogger("litserve").handlers[0].__class__.__name__ == "StreamHandler" 73 | 74 | 75 | @mock.patch("litserve.utils.set_trace") 76 | def test_set_trace_if_debug(mock_set_trace): 77 | # mock environ 78 | with mock.patch("litserve.utils.os.environ", {"LITSERVE_DEBUG": "1"}): 79 | set_trace_if_debug() 80 | mock_set_trace.assert_called_once() 81 | 82 | 83 | @mock.patch("litserve.utils.ForkedPdb") 84 | def test_set_trace_if_debug_not_set(mock_forked_pdb): 85 | with mock.patch("litserve.utils.os.environ", {"LITSERVE_DEBUG": "0"}): 86 | set_trace_if_debug() 87 | mock_forked_pdb.assert_not_called() 88 | -------------------------------------------------------------------------------- /tests/test_zmq_queue.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import pickle 3 | from queue import Empty 4 | from unittest.mock import AsyncMock, Mock, patch 5 | 6 | import pytest 7 | import zmq 8 | 9 | from litserve.transport.zmq_queue import AsyncConsumer, Broker, Producer 10 | 11 | 12 | @pytest.fixture 13 | def mock_context(): 14 | with patch("zmq.Context") as mock_ctx: 15 | socket = Mock() 16 | mock_ctx.return_value.socket.return_value = socket 17 | yield mock_ctx, socket 18 | 19 | 20 | @pytest.fixture 21 | def mock_async_context(): 22 | with patch("zmq.asyncio.Context") as mock_ctx: 23 | socket = AsyncMock() # Use AsyncMock for async methods 24 | mock_ctx.return_value.socket.return_value = socket 25 | yield mock_ctx, socket 26 | 27 | 28 | def test_broker_start_stop(mock_context): 29 | _, socket = mock_context 30 | broker = Broker(use_process=False) 31 | 32 | # Start broker 33 | broker.start() 34 | assert socket.bind.call_count == 2 # Should bind both frontend and backend 35 | 36 | # Stop broker 37 | broker.stop() 38 | assert socket.close.call_count == 2 # Should close both sockets 39 | 40 | 41 | def test_broker_error_handling(mock_context): 42 | """Test broker handles ZMQ errors.""" 43 | _, socket = mock_context 44 | socket.bind.side_effect = zmq.ZMQError("Test error") 45 | 46 | broker = Broker(use_process=False) 47 | broker.start() 48 | broker.stop() 49 | 50 | assert socket.close.called # Should clean up even on error 51 | 52 | 53 | def test_producer_send(mock_context): 54 | _, socket = mock_context 55 | producer = Producer(address="test_addr") 56 | 57 | # Test sending simple data 58 | producer.put("test_data", consumer_id=1) 59 | sent_message = socket.send.call_args[0][0] 60 | consumer_id, data = sent_message.split(b"|", 1) 61 | assert consumer_id == b"1" 62 | assert pickle.loads(data) == "test_data" 63 
| 64 | # Test sending complex data 65 | complex_data = {"key": [1, 2, 3]} 66 | producer.put(complex_data, consumer_id=2) 67 | sent_message = socket.send.call_args[0][0] 68 | consumer_id, data = sent_message.split(b"|", 1) 69 | assert consumer_id == b"2" 70 | assert pickle.loads(data) == complex_data 71 | 72 | 73 | def test_producer_error_handling(mock_context): 74 | _, socket = mock_context 75 | producer = Producer(address="test_addr") 76 | 77 | # Test ZMQ error 78 | socket.send.side_effect = zmq.ZMQError("Test error") 79 | with pytest.raises(zmq.ZMQError): 80 | producer.put("data", consumer_id=1) 81 | 82 | # Test unpickleable object 83 | class Unpickleable: 84 | def __reduce__(self): 85 | raise pickle.PickleError("Can't pickle this!") 86 | 87 | with pytest.raises(pickle.PickleError): 88 | producer.put(Unpickleable(), consumer_id=1) 89 | 90 | 91 | def test_producer_wait_for_subscribers(mock_context): 92 | _, socket = mock_context 93 | producer = Producer(address="test_addr") 94 | 95 | # Test successful wait 96 | assert producer.wait_for_subscribers(timeout=0.1) 97 | assert socket.send.called 98 | 99 | # Test timeout 100 | socket.send.side_effect = zmq.ZMQError("Would block") 101 | assert not producer.wait_for_subscribers(timeout=0.1) 102 | 103 | 104 | @pytest.mark.parametrize("timeout", [1.0, None]) 105 | @pytest.mark.asyncio 106 | async def test_async_consumer(mock_async_context, timeout): 107 | _, socket = mock_async_context 108 | consumer = AsyncConsumer(consumer_id=1, address="test_addr") 109 | 110 | # Setup mock received data 111 | test_data = {"test": "data"} 112 | message = b"1|" + pickle.dumps(test_data) 113 | socket.recv.return_value = message 114 | 115 | # Test receiving 116 | received = await consumer.get(timeout=timeout) 117 | assert received == test_data 118 | 119 | # Test timeout 120 | socket.recv.side_effect = asyncio.TimeoutError() 121 | with pytest.raises(Empty): 122 | await consumer.get(timeout=timeout) 123 | 124 | 125 | @pytest.mark.asyncio 126 | async def test_async_consumer_cleanup(): 127 | with patch("zmq.asyncio.Context") as mock_ctx: 128 | socket = AsyncMock() 129 | mock_ctx.return_value.socket.return_value = socket 130 | 131 | consumer = AsyncConsumer(consumer_id=1, address="test_addr") 132 | consumer.close() 133 | 134 | assert socket.close.called 135 | assert mock_ctx.return_value.term.called 136 | 137 | 138 | def test_producer_cleanup(): 139 | with patch("zmq.Context") as mock_ctx: 140 | socket = Mock() 141 | mock_ctx.return_value.socket.return_value = socket 142 | 143 | producer = Producer(address="test_addr") 144 | producer.close() 145 | 146 | assert socket.close.called 147 | assert mock_ctx.return_value.term.called 148 | --------------------------------------------------------------------------------