├── .azure
│   └── gpu-tests.yml
├── .codecov.yml
├── .github
│   ├── CODEOWNERS
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── config.yml
│   │   ├── documentation.md
│   │   └── feature_request.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── stale.yml
│   └── workflows
│       ├── ci-checks.yml
│       ├── ci-minimal-dependency-check.yml
│       ├── ci-parity.yml
│       ├── ci-testing.yml
│       └── release-pypi.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── _requirements
│   ├── perf.txt
│   └── test.txt
├── pyproject.toml
├── pytest.ini
├── requirements.txt
├── setup.py
├── src
│   └── litserve
│       ├── __about__.py
│       ├── __init__.py
│       ├── __main__.py
│       ├── api.py
│       ├── callbacks
│       │   ├── __init__.py
│       │   ├── base.py
│       │   └── defaults
│       │       ├── __init__.py
│       │       └── metric_callback.py
│       ├── cli.py
│       ├── connector.py
│       ├── docker_builder.py
│       ├── loggers.py
│       ├── loops
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── continuous_batching_loop.py
│       │   ├── loops.py
│       │   ├── simple_loops.py
│       │   └── streaming_loops.py
│       ├── middlewares.py
│       ├── python_client.py
│       ├── schema
│       │   ├── __init__.py
│       │   └── image.py
│       ├── server.py
│       ├── specs
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── openai.py
│       │   └── openai_embedding.py
│       ├── test_examples
│       │   ├── __init__.py
│       │   ├── openai_embedding_spec_example.py
│       │   ├── openai_spec_example.py
│       │   └── simple_example.py
│       ├── transport
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── factory.py
│       │   ├── process_transport.py
│       │   ├── zmq_queue.py
│       │   └── zmq_transport.py
│       └── utils.py
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── e2e
    │   ├── default_api.py
    │   ├── default_async_streaming.py
    │   ├── default_batched_streaming.py
    │   ├── default_batching.py
    │   ├── default_openai_embedding_spec.py
    │   ├── default_openai_with_batching.py
    │   ├── default_openaispec.py
    │   ├── default_openaispec_response_format.py
    │   ├── default_openaispec_tools.py
    │   ├── default_single_streaming.py
    │   ├── default_spec.py
    │   ├── openai_embedding_with_batching.py
    │   └── test_e2e.py
    ├── minimal_run.py
    ├── parity_fastapi
    │   ├── benchmark.py
    │   ├── fastapi-server.py
    │   ├── ls-server.py
    │   └── main.py
    ├── perf_test
    │   ├── bert
    │   │   ├── benchmark.py
    │   │   ├── data.py
    │   │   ├── run_test.sh
    │   │   ├── server.py
    │   │   └── utils.py
    │   └── stream
    │       ├── run_test.sh
    │       └── stream_speed
    │           ├── benchmark.py
    │           └── server.py
    ├── simple_server.py
    ├── simple_server_diff_port.py
    ├── test_auth.py
    ├── test_batch.py
    ├── test_callbacks.py
    ├── test_cli.py
    ├── test_compression.py
    ├── test_connector.py
    ├── test_docker_builder.py
    ├── test_examples.py
    ├── test_form.py
    ├── test_lit_server.py
    ├── test_litapi.py
    ├── test_logger.py
    ├── test_logging.py
    ├── test_loops.py
    ├── test_middlewares.py
    ├── test_multiple_endpoints.py
    ├── test_openai_embedding.py
    ├── test_pydantic.py
    ├── test_readme.py
    ├── test_request_handlers.py
    ├── test_schema.py
    ├── test_simple.py
    ├── test_specs.py
    ├── test_torch.py
    ├── test_transport.py
    ├── test_utils.py
    └── test_zmq_queue.py
/.azure/gpu-tests.yml:
--------------------------------------------------------------------------------
1 | # Create and test a Python package on multiple PyTorch versions.
2 |
3 | trigger:
4 | tags:
5 | include:
6 | - "*"
7 | branches:
8 | include:
9 | - main
10 | - release/*
11 | - refs/tags/*
12 | pr:
13 | - main
14 | - release/*
15 |
16 | jobs:
17 | - job: serve_GPU
18 | # how long to run the job before automatically cancelling
19 | timeoutInMinutes: "20"
20 | # how much time to give 'run always even if cancelled tasks' before stopping them
21 | cancelTimeoutInMinutes: "2"
22 |
23 | pool: "lit-rtx-3090"
24 |
25 | variables:
26 | DEVICES: $( python -c 'name = "$(Agent.Name)" ; gpus = name.split("_")[-1] if "_" in name else "0,1"; print(gpus)' )
27 | # these two caches assume the jobs run repeatedly on the same set of machines
28 | # see: https://github.com/microsoft/azure-pipelines-agent/issues/4113#issuecomment-1439241481
29 | TORCH_HOME: "/var/tmp/torch"
30 | PIP_CACHE_DIR: "/var/tmp/pip"
31 |
32 | container:
33 | image: "pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime"
34 | options: "--gpus=all --shm-size=8g -v /var/tmp:/var/tmp"
35 |
36 | workspace:
37 | clean: all
38 |
39 | steps:
40 | - bash: |
41 | echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)"
42 | CUDA_version=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p')
43 | CUDA_version_mm="${CUDA_version//'.'/''}"
44 | echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$CUDA_version_mm"
45 | echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${CUDA_version_mm}/torch_stable.html"
46 | displayName: "set Env. vars"
47 |
48 | - bash: |
49 | whoami && id
50 | lspci | egrep 'VGA|3D'
51 | whereis nvidia
52 | nvidia-smi
53 | echo $CUDA_VISIBLE_DEVICES
54 | echo $TORCH_URL
55 | python --version
56 | pip --version
57 | pip cache dir
58 | pip list
59 | displayName: "Image info & NVIDIA"
60 |
61 | - bash: |
62 | pip install . -U --prefer-binary \
63 | -r ./_requirements/test.txt --find-links=${TORCH_URL}
64 | displayName: "Install environment"
65 |
66 | - bash: |
67 | pip list
68 | python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'found GPUs: {mgpu}'"
69 | displayName: "Sanity check"
70 |
71 | - bash: |
72 | pip install -q py-tree
73 | py-tree /var/tmp/torch
74 | displayName: "Show caches"
75 |
76 | - bash: |
77 | coverage run --source litserve -m pytest src tests -v
78 | displayName: "Testing"
79 |
80 | - bash: |
81 | python -m coverage report
82 | python -m coverage xml
83 | python -m codecov --token=$(CODECOV_TOKEN) --name="GPU-coverage" \
84 | --commit=$(Build.SourceVersion) --flags=gpu,unittest --env=linux,azure
85 | ls -l
86 | displayName: "Statistics"
87 |
88 | - bash: |
89 | pip install torch torchvision -U -q --find-links=${TORCH_URL} -r _requirements/perf.txt
90 | export PYTHONPATH=$PWD && python tests/parity_fastapi/main.py
91 | displayName: "Run FastAPI parity tests"
92 |
93 | - bash: |
94 | pip install gpustat wget -U -q
95 | bash tests/perf_test/bert/run_test.sh
96 | displayName: "Run GPU perf test"
97 |
--------------------------------------------------------------------------------
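The `DEVICES` variable above maps the Azure agent name to a set of GPU indices with an inline Python one-liner. A minimal standalone sketch of that expression (the agent name below is hypothetical):

```python
# Derive GPU indices from an Azure agent name, mirroring the DEVICES expression
# in gpu-tests.yml: take the suffix after the last "_" if present, else "0,1".
name = "lit-rtx-3090_2,3"  # hypothetical $(Agent.Name) value
gpus = name.split("_")[-1] if "_" in name else "0,1"
print(gpus)  # -> "2,3", later exported as CUDA_VISIBLE_DEVICES in the pipeline
```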
/.codecov.yml:
--------------------------------------------------------------------------------
1 | # see https://docs.codecov.io/docs/codecov-yaml
2 | # Validation check:
3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate
4 |
5 | # https://docs.codecov.io/docs/codecovyml-reference
6 | codecov:
7 | bot: "codecov-io"
8 | strict_yaml_branch: "yaml-config"
9 | require_ci_to_pass: yes
10 | notify:
11 | # after_n_builds: 2
12 | wait_for_ci: yes
13 |
14 | coverage:
15 | precision: 0 # 2 = xx.xx%, 0 = xx%
16 | round: nearest # how coverage is rounded: down/up/nearest
17 | range: 40...100 # custom range of coverage colors from red -> yellow -> green
18 | status:
19 | # https://codecov.readme.io/v1.0/docs/commit-status
20 | project:
21 | default:
22 | informational: true
23 | target: 99% # specify the target coverage for each commit status
24 | threshold: 30% # allow this little decrease on project
25 | # https://github.com/codecov/support/wiki/Filtering-Branches
26 | # branches: main
27 | if_ci_failed: error
28 | # https://github.com/codecov/support/wiki/Patch-Status
29 | patch:
30 | default:
31 | informational: true
32 | target: 50% # specify the target "X%" coverage to hit
33 | # threshold: 50% # allow this much decrease on patch
34 | changes: false
35 |
36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations
37 | github_checks:
38 | annotations: false
39 |
40 | parsers:
41 | gcov:
42 | branch_detection:
43 | conditional: true
44 | loop: true
45 | macro: false
46 | method: false
47 | javascript:
48 | enable_partials: false
49 |
50 | comment:
51 | layout: header, diff
52 | require_changes: false
53 | behavior: default # update if exists else create new
54 | # branches: *
55 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Each line is a file pattern followed by one or more owners.
2 |
3 | # These owners will be the default owners for everything in the repo. Unless a later match takes precedence,
4 | # @global-owner1 and @global-owner2 will be requested for review when someone opens a pull request.
5 | * @lantiga @aniketmaurya @ethanwharris @borda @justusschock @tchaton @ali-alshaar7 @k223kim @KaelanDt
6 |
7 | # CI/CD and configs
8 | /.github/ @borda
9 | *.yaml @borda
10 | *.yml @borda
11 |
12 | # devel requirements
13 | /_requirements/ @borda @lantiga
14 |
15 | /README.md @williamfalcon @lantiga
16 | /requirements.txt @williamfalcon
17 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug, help wanted
6 | assignees: ''
7 | ---
8 |
9 | ## 🐛 Bug
10 |
11 |
12 |
13 | ### To Reproduce
14 |
15 | Attach a [Lightning Studio](https://lightning.ai/studios) that is fully reproducible (code, dependencies, environment, etc.) so we can reproduce the issue:
16 |
17 | 1. Create a [Studio](https://lightning.ai/studios).
18 | 2. Reproduce the issue in the Studio.
19 | 3. [Publish the Studio](https://lightning.ai/docs/overview/studios/publishing#how-to-publish).
20 | 4. Paste the Studio link here.
21 |
22 |
23 |
24 | #### Code sample
25 |
26 |
28 |
29 | ### Expected behavior
30 |
31 |
32 |
33 | ### Environment
34 | If you published a Studio with your bug report, we can automatically get this information. Otherwise, please describe:
35 |
36 | - PyTorch/Jax/Tensorflow Version (e.g., 1.0):
37 | - OS (e.g., Linux):
38 | - How you installed PyTorch (`conda`, `pip`, source):
39 | - Build command you used (if compiling from source):
40 | - Python version:
41 | - CUDA/cuDNN version:
42 | - GPU models and configuration:
43 | - Any other relevant information:
44 |
45 | ### Additional context
46 |
47 |
48 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: ❓ Ask a Question
4 | url: https://www.reddit.com/r/lightningAI/
5 | about: Ask and answer Lightning related questions.
6 | - name: 💬 Chat with us
7 | url: https://discord.gg/VptPCZkGNa
8 | about: Live chat with experts, engineers, and users in our Discord community.
9 | - name: 📖 Read the documentation
10 | url: https://lightning.ai/litserve/
11 | about: Please consult the documentation before opening any issues!
12 | - name: 🙋 Get help with Lightning Studio
13 | url: https://lightning.ai
14 | about: Contact the Lightning.ai sales team.
15 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Typos and doc fixes
3 | about: Typos and doc fixes
4 | title: ''
5 | labels: documentation
6 | assignees: ''
7 | ---
8 |
9 | ## 📚 Documentation
10 |
11 | For typos and doc fixes, please go ahead and:
12 |
13 | - For a simple typo or fix, send a PR directly (no need to create an issue)
14 | - If you are not sure about the proper solution, describe your findings here...
15 |
16 | Thanks!
17 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 | ---
8 |
9 |
22 |
23 | ----
24 |
25 | ## 🚀 Feature
26 |
27 |
28 |
29 | ### Motivation
30 |
31 |
36 |
37 | ### Pitch
38 |
39 |
40 |
41 | ### Alternatives
42 |
43 |
44 |
45 | ### Additional context
46 |
47 |
48 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## What does this PR do?
2 |
3 | Fixes # (issue).
4 |
5 |
6 |
7 | Before submitting
8 |
9 | - [ ] Was this discussed/agreed via a GitHub issue? (not needed for typos and docs improvements)
10 | - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section?
11 | - [ ] Did you make sure to update the docs?
12 | - [ ] Did you write any new necessary tests?
13 |
14 |
15 |
16 |
29 |
30 |
31 |
32 | ## PR review
33 |
34 | Anyone in the community is free to review the PR once the tests have passed.
35 | If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
36 |
37 | ## Did you have fun?
38 |
39 | Make sure you had fun coding 🙃
40 |
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # https://github.com/marketplace/stale
2 |
3 | # Number of days of inactivity before an issue becomes stale
4 | daysUntilStale: 60
5 | # Number of days of inactivity before a stale issue is closed
6 | daysUntilClose: 14
7 | # Issues with these labels will never be considered stale
8 | exemptLabels:
9 | - pinned
10 | - security
11 | # Label to use when marking an issue as stale
12 | staleLabel: won't fix
13 | # Comment to post when marking an issue as stale. Set to `false` to disable
14 | markComment: >
15 | This issue has been automatically marked as stale because it has not had
16 | recent activity. It will be closed if no further activity occurs. Thank you
17 | for your contributions.
18 | # Comment to post when closing a stale issue. Set to `false` to disable
19 | closeComment: false
20 |
21 | # Set to true to ignore issues in a project (defaults to false)
22 | exemptProjects: true
23 | # Set to true to ignore issues in a milestone (defaults to false)
24 | exemptMilestones: true
25 | # Set to true to ignore issues with an assignee (defaults to false)
26 | exemptAssignees: true
27 |
--------------------------------------------------------------------------------
/.github/workflows/ci-checks.yml:
--------------------------------------------------------------------------------
1 | name: General checks
2 |
3 | on:
4 | push:
5 | branches: [main, "release/*"]
6 | pull_request:
7 | branches: [main, "release/*"]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
11 | cancel-in-progress: ${{ ! (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
12 |
13 | jobs:
14 | # check-typing:
15 | # uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@main
16 | # with:
17 | # actions-ref: main
18 |
19 | check-schema:
20 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@main
21 | with:
22 | azure-dir: ""
23 |
24 | check-package:
25 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main
26 | with:
27 | actions-ref: main
28 | import-name: "litserve"
29 | artifact-name: dist-packages-${{ github.sha }}
30 | testing-matrix: |
31 | {
32 | "os": ["ubuntu-latest", "macos-latest", "windows-latest"],
33 | "python-version": ["3.10"],
34 | }
35 |
36 | # check-docs:
37 | # uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@main
38 | # with:
39 | # requirements-file: "_requirements/docs.txt"
40 |
--------------------------------------------------------------------------------
/.github/workflows/ci-minimal-dependency-check.yml:
--------------------------------------------------------------------------------
1 | name: Minimal dependency check
2 |
3 | on:
4 | push:
5 | branches: [main, "release/*"]
6 | pull_request:
7 | branches: [main, "release/*"]
8 |
9 | defaults:
10 | run:
11 | shell: bash
12 |
13 | jobs:
14 | minimal-test:
15 | runs-on: ubuntu-latest
16 |
17 | timeout-minutes: 30
18 |
19 | steps:
20 | - uses: actions/checkout@v4
21 | - name: Set up Python
22 | uses: actions/setup-python@v5
23 | with:
24 | python-version: "3.11"
25 |
26 | - name: Install LitServe
27 | run: |
28 | pip --version
29 | pip install . psutil -U -q
30 | pip list
31 |
32 | - name: Tests
33 | run: python tests/minimal_run.py
34 |
--------------------------------------------------------------------------------
/.github/workflows/ci-parity.yml:
--------------------------------------------------------------------------------
1 | name: Run parity tests
2 |
3 | on:
4 | push:
5 | branches: [main, "release/*"]
6 | pull_request:
7 | branches: [main, "release/*"]
8 |
9 | defaults:
10 | run:
11 | shell: bash
12 |
13 | jobs:
14 | pytester:
15 | runs-on: ubuntu-latest
16 |
17 | timeout-minutes: 30
18 |
19 | steps:
20 | - uses: actions/checkout@v4
21 | - name: Set up Python
22 | uses: actions/setup-python@v5
23 | with:
24 | python-version: "3.11"
25 |
26 | - name: Install LitServe
27 | run: |
28 | pip --version
29 | pip install . torchvision jsonargparse uvloop tenacity -r _requirements/test.txt -U -q
30 | pip list
31 |
32 | - name: Parity test
33 | run: export PYTHONPATH=$PWD && python tests/parity_fastapi/main.py
34 |
35 | - name: Streaming speed test
36 | run: bash tests/perf_test/stream/run_test.sh
37 |
--------------------------------------------------------------------------------
/.github/workflows/ci-testing.yml:
--------------------------------------------------------------------------------
1 | name: CI testing
2 |
3 | on:
4 | push:
5 | branches: [main, "release/*"]
6 | pull_request:
7 | branches: [main, "release/*"]
8 |
9 | defaults:
10 | run:
11 | shell: bash
12 |
13 | jobs:
14 | pytester:
15 | runs-on: ${{ matrix.os }}
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | os: ["ubuntu-latest"]
20 | python-version: ["3.9", "3.10", "3.11"] # todo, "3.12"
21 | include:
22 | - { os: "macos-latest", python-version: "3.12" }
23 | - { os: "windows-latest", python-version: "3.11" }
24 | - { os: "ubuntu-22.04", python-version: "3.9", requires: "oldest" }
25 |
26 | timeout-minutes: 35
27 | env:
28 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
29 |
30 | steps:
31 | - uses: actions/checkout@v4
32 | - name: Set up Python ${{ matrix.python-version }}
33 | uses: actions/setup-python@v5
34 | with:
35 | python-version: ${{ matrix.python-version }}
36 | cache: "pip"
37 |
38 | - name: Set min. dependencies
39 | if: matrix.requires == 'oldest'
40 | run: |
41 | pip install 'lightning-utilities[cli]'
42 | python -m lightning_utilities.cli requirements set-oldest --req_files='["requirements.txt", "_requirements/test.txt"]'
43 |
44 | - name: Install package & dependencies
45 | run: |
46 | pip --version
47 | pip install -e '.[test]' -U -q --find-links $TORCH_URL
48 | pip list
49 |
50 | - name: Tests
51 | timeout-minutes: 10
52 | run: |
53 | python -m pytest --cov=litserve src/ tests/ -v -s
54 |
55 | - name: Statistics
56 | run: |
57 | coverage report
58 | coverage xml
59 |
60 | - name: Upload coverage to Codecov
61 | uses: codecov/codecov-action@v4
62 | with:
63 | token: ${{ secrets.CODECOV_TOKEN }}
64 | file: ./coverage.xml
65 | flags: unittests
66 | env_vars: OS,PYTHON
67 | name: codecov-umbrella
68 | fail_ci_if_error: false
69 |
70 | tests-guardian:
71 | runs-on: ubuntu-latest
72 | needs: pytester
73 | if: always()
74 | steps:
75 | - run: echo "${{ needs.pytester.result }}"
76 | - name: failing...
77 | if: needs.pytester.result == 'failure'
78 | run: exit 1
79 | - name: cancelled or skipped...
80 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result)
81 | timeout-minutes: 1
82 | run: sleep 90
83 |
--------------------------------------------------------------------------------
/.github/workflows/release-pypi.yml:
--------------------------------------------------------------------------------
1 | name: PyPI Release
2 |
3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows
4 | on: # Trigger the workflow on push or pull request, but only for the main branch
5 | push:
6 | branches: [main]
7 | release:
8 | types: [published]
9 | pull_request:
10 | branches: [main]
11 | paths:
12 | - ".github/workflows/release-pypi.yml"
13 |
14 | # based on https://github.com/pypa/gh-action-pypi-publish
15 |
16 | jobs:
17 | build-package:
18 | runs-on: ubuntu-latest
19 | timeout-minutes: 10
20 | steps:
21 | - uses: actions/checkout@v4
22 | - uses: actions/setup-python@v5
23 | with:
24 | python-version: "3.x"
25 | - name: Pull reusable 🤖 actions️
26 | uses: actions/checkout@v4
27 | with:
28 | ref: main
29 | path: .cicd
30 | repository: Lightning-AI/utilities
31 |
32 | - name: Prepare build env.
33 | run: pip install -r ./.cicd/requirements/gha-package.txt
34 | - name: Create 📦 package
35 | uses: ./.cicd/.github/actions/pkg-create
36 | - uses: actions/upload-artifact@v4
37 | with:
38 | name: pypi-packages-${{ github.sha }}
39 | path: dist
40 |
41 | publish-package:
42 | needs: build-package
43 | runs-on: ubuntu-latest
44 | timeout-minutes: 5
45 | steps:
46 | - uses: actions/download-artifact@v4
47 | with:
48 | name: pypi-packages-${{ github.sha }}
49 | path: dist
50 | - run: ls -lh dist/
51 |
52 | - name: Publish distribution 📦 to PyPI
53 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
54 | uses: pypa/gh-action-pypi-publish@v1.12.3
55 | with:
56 | user: __token__
57 | password: ${{ secrets.pypi_password }}
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | *.egg-info/
29 | src.egg-info/
30 |
31 | # Lightning /research
32 | test_tube_exp/
33 | tests/tests_tt_dir/
34 | tests/save_dir
35 | default/
36 | data/
37 | test_tube_logs/
38 | test_tube_data/
39 | datasets/
40 | model_weights/
41 | tests/save_dir
42 | tests/tests_tt_dir/
43 | processed/
44 | raw/
45 |
46 | # PyInstaller
47 | # Usually these files are written by a python script from a template
48 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
49 | *.manifest
50 | *.spec
51 |
52 | # Installer logs
53 | pip-log.txt
54 | pip-delete-this-directory.txt
55 |
56 | # Unit test / coverage reports
57 | htmlcov/
58 | .tox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # celery beat schedule file
97 | celerybeat-schedule
98 |
99 | # SageMath parsed files
100 | *.sage.py
101 |
102 | # Environments
103 | .env
104 | .venv
105 | env/
106 | venv/
107 | ENV/
108 | env.bak/
109 | venv.bak/
110 |
111 | # Spyder project settings
112 | .spyderproject
113 | .spyproject
114 |
115 | # Rope project settings
116 | .ropeproject
117 |
118 | # mkdocs documentation
119 | /site
120 |
121 |
122 | # mypy
123 | .mypy_cache/
124 |
125 | # IDEs
126 | .idea
127 | .vscode
128 |
129 | # seed project
130 | lightning_logs/
131 | MNIST
132 | .DS_Store
133 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | ci:
5 | autofix_prs: true
6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions"
7 | autoupdate_schedule: "monthly"
8 | # submodules: true
9 |
10 | repos:
11 | - repo: https://github.com/pre-commit/pre-commit-hooks
12 | rev: v5.0.0
13 | hooks:
14 | - id: end-of-file-fixer
15 | - id: trailing-whitespace
16 | exclude: '.*\.md$'
17 | - id: check-case-conflict
18 | - id: check-yaml
19 | - id: check-toml
20 | - id: check-json
21 | - id: check-added-large-files
22 | - id: check-docstring-first
23 | - id: detect-private-key
24 |
25 | - repo: https://github.com/codespell-project/codespell
26 | rev: v2.4.1
27 | hooks:
28 | - id: codespell
29 | additional_dependencies: [tomli]
30 | #args: ["--write-changes"]
31 |
32 | - repo: https://github.com/pre-commit/mirrors-prettier
33 | rev: v3.1.0
34 | hooks:
35 | - id: prettier
36 | files: \.(json|yml|yaml|toml)
37 | # https://prettier.io/docs/en/options.html#print-width
38 | args: ["--print-width=120"]
39 |
40 | - repo: https://github.com/PyCQA/docformatter
41 | rev: v1.7.7
42 | hooks:
43 | - id: docformatter
44 | additional_dependencies: [tomli]
45 | args: ["--in-place"]
46 |
47 | - repo: https://github.com/astral-sh/ruff-pre-commit
48 | rev: v0.11.12
49 | hooks:
50 | - id: ruff-format
51 | args: ["--preview"]
52 | - id: ruff
53 | args: ["--fix"]
54 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # Manifest syntax https://docs.python.org/2/distutils/sourcedist.html
2 | graft wheelhouse
3 |
4 | recursive-exclude __pycache__ *.py[cod] *.orig
5 |
6 | # Include the README and CHANGELOG
7 | include *.md
8 | recursive-include src *.md
9 |
10 | # Include the license file
11 | include LICENSE
12 |
13 | # Exclude build configs
14 | exclude *.sh
15 | exclude *.toml
16 | exclude *.svg
17 | exclude *.yml
18 | exclude *.yaml
19 |
20 | # exclude tests from package
21 | recursive-exclude tests *
22 | recursive-exclude site *
23 | exclude tests
24 |
25 | # Exclude the documentation files
26 | recursive-exclude docs *
27 | exclude docs
28 |
29 | # Include the Requirements
30 | include requirements.txt
31 | recursive-include _requirements *.txt
32 |
33 | # Exclude Makefile
34 | exclude Makefile
35 |
36 | prune .git
37 | prune .github
38 | prune temp*
39 | prune test*
40 |
--------------------------------------------------------------------------------
/_requirements/perf.txt:
--------------------------------------------------------------------------------
1 | uvloop
2 | tenacity
3 | jsonargparse
4 |
--------------------------------------------------------------------------------
/_requirements/test.txt:
--------------------------------------------------------------------------------
1 | httpx>=0.27.0
2 | coverage[toml] >=7.5.3
3 | pytest >=8.0
4 | pytest-cov
5 | mypy ==1.11.2
6 | pytest-asyncio
7 | asgi-lifespan
8 | python-multipart
9 | psutil
10 | requests
11 | lightning >2.0.0
12 | torch >2.0.0
13 | transformers
14 | openai>=1.12.0
15 | pillow
16 | numpy <2.0
17 | pytest-retry>=1.6.3
18 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [metadata]
2 | license_file = "LICENSE"
3 | description-file = "README.md"
4 |
5 | [build-system]
6 | requires = [
7 | "setuptools",
8 | "wheel",
9 | ]
10 |
11 |
12 | [tool.check-manifest]
13 | ignore = [
14 | "*.yml",
15 | ".github",
16 | ".github/*"
17 | ]
18 |
19 |
20 | [tool.pytest.ini_options]
21 | norecursedirs = [
22 | ".git",
23 | ".github",
24 | "dist",
25 | "build",
26 | "docs",
27 | ]
28 | addopts = [
29 | "--strict-markers",
30 | "--doctest-modules",
31 | "--color=yes",
32 | "--disable-pytest-warnings",
33 | ]
34 | filterwarnings = [
35 | "error::FutureWarning",
36 | ]
37 | xfail_strict = true
38 | junit_duration_report = "call"
39 |
40 | [tool.coverage.report]
41 | exclude_lines = [
42 | "pragma: no cover",
43 | "pass",
44 | ]
45 |
46 | [tool.codespell]
47 | #skip = '*.py'
48 | quiet-level = 3
49 | # comma separated list of words; waiting for:
50 | # https://github.com/codespell-project/codespell/issues/2839#issuecomment-1731601603
52 | # also adding links until they are ignored by their nature:
52 | # https://github.com/codespell-project/codespell/issues/2243#issuecomment-1732019960
53 | #ignore-words-list = ""
54 |
55 | [tool.docformatter]
56 | recursive = true
57 | wrap-summaries = 120
58 | wrap-descriptions = 120
59 | blank = true
60 |
61 |
62 | #[tool.mypy]
63 | #files = [
64 | # "src",
65 | #]
66 | #install_types = true
67 | #non_interactive = true
68 | #disallow_untyped_defs = true
69 | #ignore_missing_imports = true
70 | #show_error_codes = true
71 | #warn_redundant_casts = true
72 | #warn_unused_configs = true
73 | #warn_unused_ignores = true
74 | #allow_redefinition = true
75 | ## disable this rule as the Trainer attributes are defined in the connectors, not in its __init__
76 | #disable_error_code = "attr-defined"
77 | ## style choices
78 | #warn_no_return = false
79 |
80 |
81 | [tool.ruff]
82 | line-length = 120
83 | target-version = "py38"
84 | # Enable Pyflakes `E` and `F` codes by default.
85 | lint.select = [
86 | "E", "W", # see: https://pypi.org/project/pycodestyle
87 | "F", # see: https://pypi.org/project/pyflakes
88 | "N", # see: https://pypi.org/project/pep8-naming
89 | # "D", # see: https://pypi.org/project/pydocstyle
90 | "I", # implementation for isort
91 | "UP", # implementation for pyupgrade
92 | "RUF100", # implementation for yesqa
93 | ]
94 | lint.extend-select = [
95 | "C4", # see: https://pypi.org/project/flake8-comprehensions
96 | "PT", # see: https://pypi.org/project/flake8-pytest-style
97 | "RET", # see: https://pypi.org/project/flake8-return
98 | "SIM", # see: https://pypi.org/project/flake8-simplify
99 | ]
100 | lint.ignore = [
101 | "E731", # Do not assign a lambda expression, use a def
102 | ]
103 | # Exclude a variety of commonly ignored directories.
104 | exclude = [
105 | ".eggs",
106 | ".git",
107 | ".mypy_cache",
108 | ".ruff_cache",
109 | "__pypackages__",
110 | "_build",
111 | "build",
112 | "dist",
113 | "docs"
114 | ]
115 |
116 | [tool.ruff.lint.per-file-ignores]
117 | "setup.py" = ["D100", "SIM115"]
118 | "__about__.py" = ["D100"]
119 | "__init__.py" = ["D100"]
120 |
121 | [tool.ruff.lint.pydocstyle]
122 | # Use Google-style docstrings.
123 | convention = "google"
124 |
125 | #[tool.ruff.pycodestyle]
126 | #ignore-overlong-task-comments = true
127 |
128 | [tool.ruff.lint.mccabe]
129 | # Unlike Flake8, default to a complexity level of 10.
130 | max-complexity = 10
131 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | python_files = test_*.py
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi >=0.100
2 | uvicorn[standard] >=0.29.0
3 | pyzmq >=22.0.0
4 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright The Lightning AI team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import glob
16 | import os
17 | from importlib.util import module_from_spec, spec_from_file_location
18 | from pathlib import Path
19 |
20 | from pkg_resources import parse_requirements
21 | from setuptools import find_packages, setup
22 |
23 | _PATH_ROOT = os.path.dirname(__file__)
24 | _PATH_SOURCE = os.path.join(_PATH_ROOT, "src")
25 | _PATH_REQUIRES = os.path.join(_PATH_ROOT, "_requirements")
26 |
27 |
28 | def _load_py_module(fname, pkg="litserve"):
29 | spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_SOURCE, pkg, fname))
30 | py = module_from_spec(spec)
31 | spec.loader.exec_module(py)
32 | return py
33 |
34 |
35 | def _load_requirements(path_dir: str = _PATH_ROOT, file_name: str = "requirements.txt") -> list:
36 | reqs = parse_requirements(open(os.path.join(path_dir, file_name)).readlines())
37 | return list(map(str, reqs))
38 |
39 |
40 | about = _load_py_module("__about__.py")
41 | with open(os.path.join(_PATH_ROOT, "README.md"), encoding="utf-8") as fopen:
42 | readme = fopen.read()
43 |
44 |
45 | def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple = ("devel.txt", "docs.txt")) -> dict:
46 | # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras
47 | # Define package extras. These are only installed if you specify them.
48 | # From remote, use like `pip install litserve[test]`
49 | # From a local copy of the repo, use like `pip install ".[test]"`
50 | req_files = [Path(p) for p in glob.glob(os.path.join(requirements_dir, "*.txt"))]
51 | extras = {
52 | p.stem: _load_requirements(file_name=p.name, path_dir=str(p.parent))
53 | for p in req_files
54 | if p.name not in skip_files
55 | }
56 | # todo: eventually add some custom aggregations such as `develop`
57 | extras = {name: sorted(set(reqs)) for name, reqs in extras.items()}
58 | print("The extras are: ", extras)
59 | return extras
60 |
61 |
62 | # https://packaging.python.org/discussions/install-requires-vs-requirements/
63 | # keep the metadata here for simplicity when reading this file... it's not obvious
64 | # what happens otherwise, and non-engineers won't know to look in __init__ ...
65 | # the goal of the project is simplicity for researchers; we don't want to add too many
66 | # engineer-specific practices
67 | setup(
68 | name="litserve",
69 | version=about.__version__,
70 | description=about.__docs__,
71 | author=about.__author__,
72 | author_email=about.__author_email__,
73 | url=about.__homepage__,
74 | download_url="https://github.com/Lightning-AI/litserve",
75 | license=about.__license__,
76 | packages=find_packages(where="src"),
77 | package_dir={"": "src"},
78 | long_description=readme,
79 | long_description_content_type="text/markdown",
80 | include_package_data=True,
81 | zip_safe=False,
82 | keywords=["deep learning", "pytorch", "AI"],
83 | python_requires=">=3.8",
84 | setup_requires=["wheel"],
85 | install_requires=_load_requirements(),
86 | extras_require=_prepare_extras(),
87 | project_urls={
88 | "Bug Tracker": "https://github.com/Lightning-AI/litserve/issues",
89 | "Documentation": "https://lightning-ai.github.io/litserve/",
90 | "Source Code": "https://github.com/Lightning-AI/litserve",
91 | },
92 | classifiers=[
93 | "Environment :: Console",
94 | "Natural Language :: English",
95 | # How mature is this project? Common values are
96 | # 3 - Alpha, 4 - Beta, 5 - Production/Stable
97 | "Development Status :: 3 - Alpha",
98 | # Indicate who your project is intended for
99 | "Intended Audience :: Developers",
100 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
101 | "Topic :: Scientific/Engineering :: Information Analysis",
102 | # Pick your license as you wish
103 | "License :: OSI Approved :: Apache Software License",
104 | "Operating System :: OS Independent",
105 | # Specify the Python versions you support here. In particular, ensure
106 | # that you indicate whether you support Python 2, Python 3 or both.
107 | "Programming Language :: Python :: 3",
108 | "Programming Language :: Python :: 3.8",
109 | "Programming Language :: Python :: 3.9",
110 | "Programming Language :: Python :: 3.10",
111 | "Programming Language :: Python :: 3.11",
112 | ],
113 | entry_points={
114 | "console_scripts": ["litserve=litserve.__main__:main", "lightning=litserve.cli:main"],
115 | },
116 | )
117 |
--------------------------------------------------------------------------------
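As a rough illustration of `_prepare_extras()` above: every `_requirements/*.txt` file not listed in `skip_files` becomes a pip extra named after the file stem. Assuming only the `perf.txt` and `test.txt` files shown in this dump are present, the returned dict would look roughly like this sketch:

```python
# Approximate shape of the dict returned by _prepare_extras() for this repo.
extras = {
    "perf": ["jsonargparse", "tenacity", "uvloop"],
    "test": ["coverage[toml] >=7.5.3", "httpx>=0.27.0", "pytest >=8.0", "torch >2.0.0"],  # abridged
}
# which enables installs such as:  pip install -e '.[test]'
```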
/src/litserve/__about__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | __version__ = "0.2.11"
15 | __author__ = "Lightning-AI et al."
16 | __author_email__ = "community@lightning.ai"
17 | __license__ = "Apache-2.0"
18 | __copyright__ = f"Copyright (c) 2024, {__author__}."
19 | __homepage__ = "https://github.com/Lightning-AI/litserve"
20 | __docs__ = "Lightweight AI server."
21 |
22 | __all__ = [
23 | "__author__",
24 | "__author_email__",
25 | "__copyright__",
26 | "__docs__",
27 | "__homepage__",
28 | "__license__",
29 | "__version__",
30 | ]
31 |
--------------------------------------------------------------------------------
/src/litserve/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from litserve import test_examples
15 | from litserve.__about__ import * # noqa: F403
16 | from litserve.api import LitAPI
17 | from litserve.callbacks import Callback
18 | from litserve.loggers import Logger
19 | from litserve.server import LitServer, Request, Response
20 | from litserve.specs import OpenAIEmbeddingSpec, OpenAISpec
21 | from litserve.utils import configure_logging, set_trace, set_trace_if_debug
22 |
23 | configure_logging()
24 |
25 | __all__ = [
26 | "LitAPI",
27 | "LitServer",
28 | "Request",
29 | "Response",
30 | "OpenAISpec",
31 | "OpenAIEmbeddingSpec",
32 | "test_examples",
33 | "Callback",
34 | "Logger",
35 | "set_trace",
36 | "set_trace_if_debug",
37 | ]
38 |
--------------------------------------------------------------------------------
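The package's public surface is re-exported here; the usual pattern is to subclass `LitAPI` and hand an instance to `LitServer`. A minimal sketch, using a stand-in lambda as the "model" (exact `LitServer`/`run` arguments may vary between versions):

```python
import litserve as ls


class SquareAPI(ls.LitAPI):
    def setup(self, device):
        # load models/weights once per worker; here a trivial stand-in
        self.model = lambda x: x ** 2

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output):
        return {"output": output}


if __name__ == "__main__":
    server = ls.LitServer(SquareAPI(), accelerator="auto")
    server.run(port=8000)
```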
/src/litserve/__main__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, RawTextHelpFormatter
15 |
16 | from litserve.docker_builder import dockerize
17 |
18 |
19 | class LitFormatter(ArgumentDefaultsHelpFormatter, RawTextHelpFormatter): ...
20 |
21 |
22 | def main():
23 | parser = ArgumentParser(description="CLI for LitServe", formatter_class=LitFormatter)
24 | subparsers = parser.add_subparsers(
25 | dest="command",
26 | title="Commands",
27 | )
28 |
29 | # dockerize sub-command
30 | dockerize_parser = subparsers.add_parser(
31 | "dockerize",
32 | help="Generate a Dockerfile for the given server code.",
33 | description="Generate a Dockerfile for the given server code.\nExample usage:"
34 | " litserve dockerize server.py --port 8000 --gpu",
35 | formatter_class=LitFormatter,
36 | )
37 | dockerize_parser.add_argument(
38 | "server_filename",
39 | type=str,
40 | help="The path to the server file. Example: server.py or app.py.",
41 | )
42 | dockerize_parser.add_argument(
43 | "--port",
44 | type=int,
45 | default=8000,
46 | help="The port to expose in the Docker container.",
47 | )
48 | dockerize_parser.add_argument(
49 | "--gpu",
50 | default=False,
51 | action="store_true",
52 | help="Whether to use a GPU-enabled Docker image.",
53 | )
54 | dockerize_parser.set_defaults(func=lambda args: dockerize(args.server_filename, args.port, args.gpu))
55 | args = parser.parse_args()
56 |
57 | if hasattr(args, "func"):
58 | args.func(args)
59 | else:
60 | parser.print_help()
61 |
62 |
63 | if __name__ == "__main__":
64 | main()
65 |
--------------------------------------------------------------------------------
/src/litserve/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | from litserve.callbacks.base import Callback, CallbackRunner, EventTypes, NoopCallback
2 |
3 | __all__ = ["Callback", "CallbackRunner", "EventTypes", "NoopCallback"]
4 |
--------------------------------------------------------------------------------
/src/litserve/callbacks/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from abc import ABC
3 | from enum import Enum
4 | from typing import List, Union
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class EventTypes(Enum):
10 | BEFORE_SETUP = "on_before_setup"
11 | AFTER_SETUP = "on_after_setup"
12 | BEFORE_DECODE_REQUEST = "on_before_decode_request"
13 | AFTER_DECODE_REQUEST = "on_after_decode_request"
14 | BEFORE_ENCODE_RESPONSE = "on_before_encode_response"
15 | AFTER_ENCODE_RESPONSE = "on_after_encode_response"
16 | BEFORE_PREDICT = "on_before_predict"
17 | AFTER_PREDICT = "on_after_predict"
18 | ON_SERVER_START = "on_server_start"
19 | ON_SERVER_END = "on_server_end"
20 | ON_REQUEST = "on_request"
21 | ON_RESPONSE = "on_response"
22 |
23 |
24 | class Callback(ABC):
25 | def on_before_setup(self, *args, **kwargs):
26 | """Called before setup is started."""
27 |
28 | def on_after_setup(self, *args, **kwargs):
29 | """Called after setup is completed."""
30 |
31 | def on_before_decode_request(self, *args, **kwargs):
32 | """Called before request decoding is started."""
33 |
34 | def on_after_decode_request(self, *args, **kwargs):
35 | """Called after request decoding is completed."""
36 |
37 | def on_before_encode_response(self, *args, **kwargs):
38 | """Called before response encoding is started."""
39 |
40 | def on_after_encode_response(self, *args, **kwargs):
41 | """Called after response encoding is completed."""
42 |
43 | def on_before_predict(self, *args, **kwargs):
44 | """Called before prediction is started."""
45 |
46 | def on_after_predict(self, *args, **kwargs):
47 | """Called after prediction is completed."""
48 |
49 | def on_server_start(self, *args, **kwargs):
50 | """Called before server starts."""
51 |
52 | def on_server_end(self, *args, **kwargs):
53 | """Called when server terminates."""
54 |
55 | def on_request(self, *args, **kwargs):
56 | """Called when request enters the endpoint function."""
57 |
58 | def on_response(self, *args, **kwargs):
59 | """Called when response is generated from the worker and ready to return to the client."""
60 |
61 | # Adding a new hook? Register it in the EventTypes enum too.
62 |
63 |
64 | class CallbackRunner:
65 | def __init__(self, callbacks: Union[Callback, List[Callback]] = None):
66 | self._callbacks = []
67 | if callbacks:
68 | self._add_callbacks(callbacks)
69 |
70 | def _add_callbacks(self, callbacks: Union[Callback, List[Callback]]):
71 | if not isinstance(callbacks, list):
72 | callbacks = [callbacks]
73 | for callback in callbacks:
74 | if not isinstance(callback, Callback):
75 | raise ValueError(f"Invalid callback type: {callback}")
76 | self._callbacks.extend(callbacks)
77 |
78 | def trigger_event(self, event_name, *args, **kwargs):
79 | """Triggers an event, invoking all registered callbacks for that event."""
80 | for callback in self._callbacks:
81 | try:
82 | getattr(callback, event_name)(*args, **kwargs)
83 | except Exception:
84 | # Handle exceptions to prevent one callback from disrupting others
85 | logger.exception(f"Error in callback '{callback}' during event '{event_name}'")
86 |
87 |
88 | class NoopCallback(Callback):
89 | """This callback does nothing."""
90 |
--------------------------------------------------------------------------------
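A minimal sketch of how these pieces fit together: a custom `Callback` overrides only the hooks it cares about, and `CallbackRunner.trigger_event` dispatches by hook name. Internally the server triggers these events itself; the direct calls below are only for illustration.

```python
from litserve.callbacks import Callback, CallbackRunner, EventTypes


class PrintingCallback(Callback):
    def on_server_start(self, *args, **kwargs):
        print("server is starting")

    def on_before_predict(self, *args, **kwargs):
        print("about to run predict")


runner = CallbackRunner(PrintingCallback())
# trigger_event looks up the method named `event_name` on every registered callback;
# exceptions inside a callback are logged and do not disrupt the others.
runner.trigger_event(EventTypes.ON_SERVER_START.value)
runner.trigger_event(EventTypes.BEFORE_PREDICT.value)
```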
/src/litserve/callbacks/defaults/__init__.py:
--------------------------------------------------------------------------------
1 | from litserve.callbacks.defaults.metric_callback import PredictionTimeLogger
2 |
3 | __all__ = ["PredictionTimeLogger"]
4 |
--------------------------------------------------------------------------------
/src/litserve/callbacks/defaults/metric_callback.py:
--------------------------------------------------------------------------------
1 | import time
2 | import typing
3 |
4 | from litserve.callbacks.base import Callback
5 |
6 | if typing.TYPE_CHECKING:
7 | from litserve import LitAPI
8 |
9 |
10 | class PredictionTimeLogger(Callback):
11 | def on_before_predict(self, lit_api: "LitAPI"):
12 | self._start_time = time.perf_counter()
13 |
14 | def on_after_predict(self, lit_api: "LitAPI"):
15 | elapsed = time.perf_counter() - self._start_time
16 | print(f"Prediction took {elapsed:.2f} seconds", flush=True)
17 |
18 |
19 | class RequestTracker(Callback):
20 | def on_request(self, active_requests: int, **kwargs):
21 | print(f"Active requests: {active_requests}", flush=True)
22 |
--------------------------------------------------------------------------------
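In practice these default callbacks are passed to the server rather than invoked directly. A sketch, assuming `LitServer` accepts a `callbacks` keyword argument and that `litserve.test_examples` exports `SimpleLitAPI` (as the tests in this repo suggest):

```python
import litserve as ls
from litserve.callbacks.defaults import PredictionTimeLogger

# Attach the timing callback so every predict call prints its duration.
api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api, callbacks=[PredictionTimeLogger()])

if __name__ == "__main__":
    server.run(port=8000)
```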
/src/litserve/cli.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | import subprocess
3 | import sys
4 |
5 |
6 | def _ensure_lightning_installed():
7 | if not importlib.util.find_spec("lightning_sdk"):
8 | print("Lightning CLI not found. Installing...")
9 | subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "lightning-sdk"])
10 |
11 |
12 | def main():
13 | _ensure_lightning_installed()
14 |
15 | try:
16 | # Import the correct entry point for lightning_sdk
17 | from lightning_sdk.cli.entrypoint import main_cli
18 |
19 | # Call the lightning CLI's main function directly with our arguments
20 | # This bypasses the command-line entry point completely
21 | sys.argv[0] = "lightning" # Make it think it was called as "lightning"
22 | main_cli()
23 | except ImportError as e:
24 | # If there's an issue importing or finding the right module
25 | print(f"Error importing lightning_sdk CLI: {e}")
26 | print("Please ensure `lightning-sdk` is installed correctly.")
27 | sys.exit(1)
28 |
--------------------------------------------------------------------------------
/src/litserve/connector.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import platform
16 | import subprocess
17 | import sys
18 | from functools import lru_cache
19 | from typing import List, Optional, Union
20 |
21 |
22 | class _Connector:
23 | def __init__(self, accelerator: str = "auto", devices: Union[List[int], int, str] = "auto"):
24 | accelerator = self._sanitize_accelerator(accelerator)
25 | if accelerator in ("cpu", "cuda", "mps"):
26 | self._accelerator = accelerator
27 | elif accelerator == "auto":
28 | self._accelerator = self._choose_auto_accelerator()
29 | elif accelerator == "gpu":
30 | self._accelerator = self._choose_gpu_accelerator_backend()
31 |
32 | if devices == "auto":
33 | self._devices = self._accelerator_device_count()
34 | else:
35 | self._devices = devices
36 |
37 | self.check_devices_and_accelerators()
38 |
39 | def check_devices_and_accelerators(self):
40 | """Check if the devices are in a valid fomra and raise an error if they are not."""
41 | if self._accelerator in ("cuda", "mps"):
42 | if not isinstance(self._devices, int) and not (
43 | isinstance(self._devices, list) and all(isinstance(device, int) for device in self._devices)
44 | ):
45 | raise ValueError(
46 | "devices must be an integer or a list of integers when using 'cuda' or 'mps', "
47 | f"instead got {self._devices}"
48 | )
49 | elif self._accelerator != "cpu":
50 | raise ValueError(f"accelerator must be one of (cuda, mps, cpu), instead got {self._accelerator}")
51 |
52 | @property
53 | def accelerator(self):
54 | return self._accelerator
55 |
56 | @property
57 | def devices(self):
58 | return self._devices
59 |
60 | @staticmethod
61 | def _sanitize_accelerator(accelerator: Optional[str]):
62 | if isinstance(accelerator, str):
63 | accelerator = accelerator.lower()
64 |
65 | if accelerator not in ["auto", "cpu", "mps", "cuda", "gpu", None]:
66 | raise ValueError(f"accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'. Found: {accelerator}")
67 |
68 | if accelerator is None:
69 | return "auto"
70 | return accelerator
71 |
72 | def _choose_auto_accelerator(self):
73 | gpu_backend = self._choose_gpu_accelerator_backend()
74 | if "torch" in sys.modules and gpu_backend:
75 | return gpu_backend
76 | return "cpu"
77 |
78 | def _accelerator_device_count(self) -> int:
79 | if self._accelerator == "cuda":
80 | return check_cuda_with_nvidia_smi()
81 | return 1
82 |
83 | @staticmethod
84 | def _choose_gpu_accelerator_backend():
85 | if check_cuda_with_nvidia_smi() > 0:
86 | return "cuda"
87 |
88 | try:
89 | import torch
90 |
91 | if torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
92 | return "mps"
93 | except ImportError:
94 | return None
95 |
96 | return None
97 |
98 |
99 | @lru_cache(maxsize=1)
100 | def check_cuda_with_nvidia_smi() -> int:
101 | """Checks if CUDA is installed using the `nvidia-smi` command-line tool.
102 |
103 | Returns count of visible devices.
104 |
105 | """
106 | try:
107 | nvidia_smi_output = subprocess.check_output(["nvidia-smi", "-L"]).decode("utf-8").strip()
108 | devices = [el for el in nvidia_smi_output.split("\n") if el.startswith("GPU")]
109 | devices = [el.split(":")[0].split()[1] for el in devices]
110 | visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
111 | if visible_devices:
112 | # we need to check the intersection of devices and visible devices, since
113 | # using CUDA_VISIBLE_DEVICES=0,25 on a 4-GPU machine will yield
114 | # torch.cuda.device_count() == 1
115 | devices = [el for el in devices if el in visible_devices.split(",")]
116 | return len(devices)
117 | except (subprocess.CalledProcessError, FileNotFoundError):
118 | return 0
119 |
--------------------------------------------------------------------------------
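A short sketch of how the connector resolves hardware, e.g. when a server is created with `accelerator="auto"` (the printed output depends on the machine it runs on):

```python
from litserve.connector import _Connector, check_cuda_with_nvidia_smi

# "auto" picks cuda/mps only if torch is already imported and a GPU backend is found,
# otherwise cpu; "auto" devices resolve to the visible CUDA device count, or 1.
conn = _Connector(accelerator="auto", devices="auto")
print(conn.accelerator, conn.devices)  # e.g. "cuda" 2  or  "cpu" 1

# the CUDA probe shells out to `nvidia-smi -L` and honours CUDA_VISIBLE_DEVICES
print(check_cuda_with_nvidia_smi())
```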
/src/litserve/docker_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import logging
15 | import os
16 | import warnings
17 | from pathlib import Path
18 |
19 | import litserve as ls
20 |
21 | logger = logging.getLogger(__name__)
22 | logger.setLevel(logging.INFO)
23 | logger.propagate = False
24 | console_handler = logging.StreamHandler()
25 | console_handler.setLevel(logging.INFO)
26 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
27 | console_handler.setFormatter(formatter)
28 | logger.addHandler(console_handler)
29 |
30 | # COLOR CODES
31 | RESET = "\u001b[0m"
32 | RED = "\u001b[31m"
33 | GREEN = "\u001b[32m"
34 | BLUE = "\u001b[34m"
35 | MAGENTA = "\u001b[35m"
36 | BG_MAGENTA = "\u001b[45m"
37 |
38 | # ACTION CODES
39 | BOLD = "\u001b[1m"
40 | UNDERLINE = "\u001b[4m"
41 | INFO = f"{BOLD}{BLUE}[INFO]"
42 | WARNING = f"{BOLD}{RED}[WARNING]"
43 |
44 |
45 | def color(text, color_code, action_code=None):
46 | if action_code:
47 | return f"{action_code} {color_code}{text}{RESET}"
48 | return f"{color_code}{text}{RESET}"
49 |
50 |
51 | REQUIREMENTS_FILE = "requirements.txt"
52 | DOCKERFILE_TEMPLATE = """ARG PYTHON_VERSION=3.12
53 | FROM python:$PYTHON_VERSION-slim
54 |
55 | ####### Add your own installation commands here #######
56 | # RUN pip install some-package
57 | # RUN wget https://path/to/some/data/or/weights
58 | # RUN apt-get update && apt-get install -y
59 |
60 | WORKDIR /app
61 | COPY . /app
62 |
63 | # Install litserve and requirements
64 | RUN pip install --no-cache-dir litserve=={version} {requirements}
65 | EXPOSE {port}
66 | CMD ["python", "/app/{server_filename}"]
67 | """
68 |
69 | CUDA_DOCKER_TEMPLATE = """# Change CUDA and cuDNN version here
70 | FROM nvidia/cuda:12.4.1-base-ubuntu22.04
71 | ARG PYTHON_VERSION=3.12
72 |
73 | ENV DEBIAN_FRONTEND=noninteractive
74 | RUN apt-get update && apt-get install -y --no-install-recommends \\
75 | software-properties-common \\
76 | wget \\
77 | && add-apt-repository ppa:deadsnakes/ppa \\
78 | && apt-get update && apt-get install -y --no-install-recommends \\
79 | python$PYTHON_VERSION \\
80 | python$PYTHON_VERSION-dev \\
81 | python$PYTHON_VERSION-venv \\
82 | && wget https://bootstrap.pypa.io/get-pip.py -O get-pip.py \\
83 | && python$PYTHON_VERSION get-pip.py \\
84 | && rm get-pip.py \\
85 | && ln -sf /usr/bin/python$PYTHON_VERSION /usr/bin/python \\
86 | && ln -sf /usr/local/bin/pip$PYTHON_VERSION /usr/local/bin/pip \\
87 | && python --version \\
88 | && pip --version \\
89 | && apt-get purge -y --auto-remove software-properties-common \\
90 | && apt-get clean \\
91 | && rm -rf /var/lib/apt/lists/*
92 |
93 | ####### Add your own installation commands here #######
94 | # RUN pip install some-package
95 | # RUN wget https://path/to/some/data/or/weights
96 | # RUN apt-get update && apt-get install -y
97 |
98 | WORKDIR /app
99 | COPY . /app
100 |
101 | # Install litserve and requirements
102 | RUN pip install --no-cache-dir litserve=={version} {requirements}
103 | EXPOSE {port}
104 | CMD ["python", "/app/{server_filename}"]
105 | """
106 |
107 | # Link our documentation at the bottom of this msg
108 | SUCCESS_MSG = """{BOLD}{MAGENTA}Dockerfile created successfully{RESET}
109 | Update {UNDERLINE}{dockerfile_path}{RESET} to add any additional dependencies or commands.{RESET}
110 |
111 | {BOLD}Build the container with:{RESET}
112 | > {UNDERLINE}docker build -t litserve-model .{RESET}
113 |
114 | {BOLD}To run the Docker container on the machine:{RESET}
115 | > {UNDERLINE}{RUN_CMD}{RESET}
116 |
117 | {BOLD}To push the container to a registry:{RESET}
118 | > {UNDERLINE}docker push litserve-model{RESET}
119 | """
120 |
121 |
122 | def dockerize(server_filename: str, port: int = 8000, gpu: bool = False):
123 | """Generate a Dockerfile for the given server code.
124 |
125 | Example usage:
126 | litserve dockerize server.py --port 8000 --gpu
127 |
128 | Args:
129 | server_filename (str): The path to the server file. Example: server.py or app.py.
130 | port (int, optional): The port to expose in the Docker container.
131 | gpu (bool, optional): Whether to use a GPU-enabled Docker image.
132 |
133 | """
134 | requirements = ""
135 | if os.path.exists(REQUIREMENTS_FILE):
136 | requirements = f"-r {REQUIREMENTS_FILE}"
137 | else:
138 | warnings.warn(
139 | f"requirements.txt not found at {os.getcwd()}. "
140 | f"Make sure to install the required packages in the Dockerfile.",
141 | UserWarning,
142 | )
143 |
144 | current_dir = Path.cwd()
145 | if not (current_dir / server_filename).is_file():
146 | raise FileNotFoundError(f"Server file `{server_filename}` must be in the current directory: {os.getcwd()}")
147 |
148 | version = ls.__version__
149 | if gpu:
150 | run_cmd = f"docker run --gpus all -p {port}:{port} litserve-model:latest"
151 | docker_template = CUDA_DOCKER_TEMPLATE
152 | else:
153 | run_cmd = f"docker run -p {port}:{port} litserve-model:latest"
154 | docker_template = DOCKERFILE_TEMPLATE
155 | dockerfile_content = docker_template.format(
156 | server_filename=server_filename,
157 | port=port,
158 | version=version,
159 | requirements=requirements,
160 | )
161 | with open("Dockerfile", "w") as f:
162 | f.write(dockerfile_content)
163 | success_msg = SUCCESS_MSG.format(
164 | dockerfile_path=os.path.abspath("Dockerfile"),
165 | RUN_CMD=run_cmd,
166 | BOLD=BOLD,
167 | MAGENTA=MAGENTA,
168 | GREEN=GREEN,
169 | BLUE=BLUE,
170 | UNDERLINE=UNDERLINE,
171 | BG_MAGENTA=BG_MAGENTA,
172 | RESET=RESET,
173 | )
174 | print(success_msg)
175 |
--------------------------------------------------------------------------------
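Note: a minimal sketch of calling `dockerize` programmatically instead of through the `litserve dockerize` CLI, assuming `server.py` (and optionally `requirements.txt`) live in the current working directory:
```
from litserve.docker_builder import dockerize

# Writes a CPU Dockerfile to ./Dockerfile and prints the build/run instructions.
dockerize("server.py", port=8000, gpu=False)
```
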
/src/litserve/loggers.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import functools
15 | import logging
16 | import multiprocessing as mp
17 | import pickle
18 | from abc import ABC, abstractmethod
19 | from typing import TYPE_CHECKING, List, Optional, Union
20 |
21 | from starlette.types import ASGIApp
22 |
23 | module_logger = logging.getLogger(__name__)
24 |
25 | if TYPE_CHECKING: # pragma: no cover
26 | from litserve import LitServer
27 |
28 |
29 | class Logger(ABC):
30 | def __init__(self):
31 | self._config = {}
32 |
33 | def mount(self, path: str, app: ASGIApp) -> None:
34 | """Mount an ASGI app endpoint to LitServer. Use this method when you want to add an additional endpoint to the
35 | server, such as a /metrics endpoint for Prometheus metrics.
36 |
37 | Args:
38 | path (str): The path to mount the app to.
39 | app (ASGIApp): The ASGI app to mount.
40 |
41 | """
42 | self._config.update({"mount": {"path": path, "app": app}})
43 |
44 | @abstractmethod
45 | def process(self, key, value):
46 | """Process a log entry from the log queue.
47 |
48 | This method should be implemented to define the specific logic for processing
49 | log entries.
50 |
51 | Args:
52 | key (str): The key associated with the log entry, typically indicating the type or category of the log.
53 | value (Any): The value associated with the log entry, containing the actual log data.
54 |
55 | Raises:
56 | NotImplementedError: This method must be overridden by subclasses. If not, calling this method will raise
57 | a NotImplementedError.
58 |
59 | Example:
60 | Here is an example of a Logger that logs monitoring metrics using Prometheus:
61 |
62 | from prometheus_client import Counter
63 |
64 | class PrometheusLogger(Logger):
65 | def __init__(self):
66 | super().__init__()
67 | self._metric_counter = Counter('log_entries', 'Count of log entries')
68 |
69 | def process(self, key, value):
70 | # Increment the Prometheus counter for each log entry
71 | self._metric_counter.inc()
72 | print(f"Logged {key}: {value}")
73 |
74 | """
75 | raise NotImplementedError # pragma: no cover
76 |
77 |
78 | class _LoggerProxy:
79 | def __init__(self, logger_class):
80 | self.logger_class = logger_class
81 |
82 | def create_logger(self):
83 | return self.logger_class()
84 |
85 |
86 | class _LoggerConnector:
87 | """_LoggerConnector is responsible for connecting Logger instances with the LitServer and managing their lifecycle.
88 |
89 | This class handles the following tasks:
90 | - Manages a queue (multiprocessing.Queue) where log data is placed using the LitAPI.log method.
91 | - Initiates a separate process to consume the log queue and process the log data using the associated
92 | Logger instances.
93 |
94 | """
95 |
96 | def __init__(self, lit_server: "LitServer", loggers: Optional[Union[List[Logger], Logger]] = None):
97 | self._loggers = []
98 | self._lit_server = lit_server
99 | if loggers is None:
100 | return # No loggers to add
101 | if isinstance(loggers, list):
102 | for logger in loggers:
103 | if not isinstance(logger, Logger):
104 | raise ValueError("Logger must be an instance of litserve.Logger")
105 | self.add_logger(logger)
106 | elif isinstance(loggers, Logger):
107 | self.add_logger(loggers)
108 | else:
109 | raise ValueError("loggers must be a list or an instance of litserve.Logger")
110 |
111 | def _mount(self, path: str, app: ASGIApp) -> None:
112 | self._lit_server.app.mount(path, app)
113 |
114 | def add_logger(self, logger: Logger):
115 | self._loggers.append(logger)
116 | if "mount" in logger._config:
117 | self._mount(logger._config["mount"]["path"], logger._config["mount"]["app"])
118 |
119 | @staticmethod
120 | def _is_picklable(obj):
121 | try:
122 | pickle.dumps(obj)
123 | return True
124 | except (pickle.PicklingError, TypeError, AttributeError):
125 | module_logger.warning(f"Logger {obj.__class__.__name__} is not pickleable and might not work properly.")
126 | return False
127 |
128 | @staticmethod
129 | def _process_logger_queue(logger_proxies: List[_LoggerProxy], queue):
130 | loggers = [proxy if isinstance(proxy, Logger) else proxy.create_logger() for proxy in logger_proxies]
131 | while True:
132 | key, value = queue.get()
133 | for logger in loggers:
134 | try:
135 | logger.process(key, value)
136 | except Exception as e:
137 | module_logger.error(
138 | f"{logger.__class__.__name__} ran into an error while processing log for entry "
139 | f"with key {key} and value {value}: {e}"
140 | )
141 |
142 | @functools.cache # Run once per LitServer instance
143 | def run(self, lit_server: "LitServer"):
144 | queue = lit_server.logger_queue
145 | lit_server.litapi_connector.set_logger_queue(queue)
146 |
147 | # Disconnect the logger connector from the LitServer to avoid pickling issues
148 | self._lit_server = None
149 |
150 | if not self._loggers:
151 | return
152 |
153 | # Create proxies for loggers
154 | logger_proxies = []
155 | for logger in self._loggers:
156 | if self._is_picklable(logger):
157 | logger_proxies.append(logger)
158 | else:
159 | logger_proxies.append(_LoggerProxy(logger.__class__))
160 |
161 | module_logger.debug(f"Starting logger process with {len(logger_proxies)} loggers")
162 | ctx = mp.get_context("spawn")
163 | process = ctx.Process(
164 | target=_LoggerConnector._process_logger_queue,
165 | args=(
166 | logger_proxies,
167 | queue,
168 | ),
169 | )
170 | process.start()
171 |
--------------------------------------------------------------------------------
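Note: a minimal sketch of a custom `Logger`, assuming `LitServer` exposes a `loggers=` argument that is forwarded to `_LoggerConnector` and that `LitAPI.log(key, value)` places entries on the logger queue (as the connector docstring above describes):
```
import litserve as ls


class FileMetricLogger(ls.Logger):  # hypothetical example logger
    def process(self, key, value):
        # Runs in the dedicated logger process for every LitAPI.log(key, value) call.
        with open("metrics.log", "a") as f:
            f.write(f"{key}={value}\n")


if __name__ == "__main__":
    api = ls.test_examples.SimpleLitAPI()
    server = ls.LitServer(api, loggers=FileMetricLogger())
    server.run(port=8000)
```
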
/src/litserve/loops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import multiprocessing as mp
15 |
16 | from litserve.loops.base import LitLoop, _BaseLoop
17 | from litserve.loops.continuous_batching_loop import ContinuousBatchingLoop, Output
18 | from litserve.loops.loops import (
19 | get_default_loop,
20 | inference_worker,
21 | )
22 | from litserve.loops.simple_loops import BatchedLoop, SingleLoop
23 | from litserve.loops.streaming_loops import BatchedStreamingLoop, StreamingLoop
24 |
25 | mp.allow_connection_pickling()
26 |
27 | __all__ = [
28 | "_BaseLoop",
29 | "LitLoop",
30 | "get_default_loop",
31 | "inference_worker",
32 | "Output",
33 | "SingleLoop",
34 | "BatchedLoop",
35 | "StreamingLoop",
36 | "BatchedStreamingLoop",
37 | "ContinuousBatchingLoop",
38 | ]
39 |
--------------------------------------------------------------------------------
/src/litserve/loops/loops.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import logging
15 | from queue import Queue
16 | from typing import Dict
17 |
18 | from litserve import LitAPI
19 | from litserve.callbacks import CallbackRunner, EventTypes
20 | from litserve.loops.base import LitLoop, _BaseLoop
21 | from litserve.loops.simple_loops import BatchedLoop, SingleLoop
22 | from litserve.loops.streaming_loops import BatchedStreamingLoop, StreamingLoop
23 | from litserve.transport.base import MessageTransport
24 | from litserve.utils import WorkerSetupStatus
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | def get_default_loop(stream: bool, max_batch_size: int, enable_async: bool = False) -> _BaseLoop:
30 | """Get the default loop based on the stream flag, batch size, and async support.
31 |
32 | Args:
33 | stream: Whether streaming is enabled
34 | max_batch_size: Maximum batch size
35 | enable_async: Whether async support is enabled (supports both coroutines and async generators)
36 |
37 | Returns:
38 | The appropriate loop implementation
39 |
40 | Raises:
41 | ValueError: If async and batching are enabled together (not supported)
42 |
43 | """
44 | if enable_async:
45 | if max_batch_size > 1:
46 | raise ValueError("Async batching is not supported. Please use enable_async=False with batching.")
47 | if stream:
48 | return StreamingLoop() # StreamingLoop now supports async
49 | return SingleLoop() # Only SingleLoop supports async currently
50 |
51 | if stream:
52 | if max_batch_size > 1:
53 | return BatchedStreamingLoop()
54 | return StreamingLoop()
55 |
56 | if max_batch_size > 1:
57 | return BatchedLoop()
58 | return SingleLoop()
59 |
60 |
61 | def inference_worker(
62 | lit_api: LitAPI,
63 | device: str,
64 | worker_id: int,
65 | request_queue: Queue,
66 | transport: MessageTransport,
67 | workers_setup_status: Dict[int, str],
68 | callback_runner: CallbackRunner,
69 | ):
70 | logger.debug("workers_setup_status %s", workers_setup_status)
71 | lit_spec = lit_api.spec
72 | loop: LitLoop = lit_api.loop
73 | stream = lit_api.stream
74 |
75 | endpoint = lit_api.api_path.split("/")[-1]
76 |
77 | callback_runner.trigger_event(EventTypes.BEFORE_SETUP.value, lit_api=lit_api)
78 | try:
79 | lit_api.setup(device)
80 | except Exception:
81 | logger.exception(f"Error setting up worker {worker_id}.")
82 | workers_setup_status[f"{endpoint}_{worker_id}"] = WorkerSetupStatus.ERROR
83 | return
84 | lit_api.device = device
85 | callback_runner.trigger_event(EventTypes.AFTER_SETUP.value, lit_api=lit_api)
86 |
87 | print(f"Setup complete for worker {f'{endpoint}_{worker_id}'}.")
88 |
89 | if workers_setup_status:
90 | workers_setup_status[f"{endpoint}_{worker_id}"] = WorkerSetupStatus.READY
91 |
92 | if lit_spec:
93 | logger.info(f"LitServe will use {lit_spec.__class__.__name__} spec")
94 |
95 | if loop == "auto":
96 | loop = get_default_loop(stream, lit_api.max_batch_size, lit_api.enable_async)
97 |
98 | loop(
99 | lit_api,
100 | device,
101 | worker_id,
102 | request_queue,
103 | transport,
104 | workers_setup_status,
105 | callback_runner,
106 | )
107 |
--------------------------------------------------------------------------------
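Note: the selection rules in `get_default_loop` map directly onto the four loop classes; a quick sketch of the expected mapping:
```
from litserve.loops import get_default_loop

# Non-streaming
assert type(get_default_loop(stream=False, max_batch_size=1)).__name__ == "SingleLoop"
assert type(get_default_loop(stream=False, max_batch_size=8)).__name__ == "BatchedLoop"
# Streaming
assert type(get_default_loop(stream=True, max_batch_size=1)).__name__ == "StreamingLoop"
assert type(get_default_loop(stream=True, max_batch_size=8)).__name__ == "BatchedStreamingLoop"
# Async: only the non-batched loops are supported; combining async with batching raises ValueError.
assert type(get_default_loop(stream=True, max_batch_size=1, enable_async=True)).__name__ == "StreamingLoop"
```
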
/src/litserve/middlewares.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import logging
15 | import multiprocessing
16 | from typing import Optional
17 |
18 | from fastapi import HTTPException
19 | from starlette.middleware.base import BaseHTTPMiddleware
20 | from starlette.types import ASGIApp, Message, Receive, Scope, Send
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | class MaxSizeMiddleware(BaseHTTPMiddleware):
26 | """Rejects requests with a payload that is too large."""
27 |
28 | def __init__(
29 | self,
30 | app: ASGIApp,
31 | *,
32 | max_size: Optional[int] = None,
33 | ) -> None:
34 | self.app = app
35 | self.max_size = max_size
36 |
37 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
38 | if scope["type"] != "http":
39 | await self.app(scope, receive, send)
40 | return
41 |
42 | total_size = 0
43 |
44 | async def rcv() -> Message:
45 | nonlocal total_size
46 | message = await receive()
47 | chunk_size = len(message.get("body", b""))
48 | total_size += chunk_size
49 | if self.max_size is not None and total_size > self.max_size:
50 | raise HTTPException(413, "Payload too large")
51 | return message
52 |
53 | await self.app(scope, rcv, send)
54 |
55 |
56 | class RequestCountMiddleware(BaseHTTPMiddleware):
57 | """Adds a header to the response with the number of active requests."""
58 |
59 | def __init__(self, app: ASGIApp, active_counter: multiprocessing.Value) -> None:
60 | self.app = app
61 | self.active_counter = active_counter
62 |
63 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
64 | if scope["type"] != "http" or (scope["type"] == "http" and scope["path"] in ["/", "/health", "/metrics"]):
65 | await self.app(scope, receive, send)
66 | return
67 |
68 | self.active_counter.value += 1
69 | await self.app(scope, receive, send)
70 | self.active_counter.value -= 1
71 |
--------------------------------------------------------------------------------
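Note: a sketch of wiring these middlewares into a server, assuming `LitServer` accepts a `middlewares=` argument of `(middleware_class, kwargs)` tuples; if that assumption does not hold, the same classes can be attached to any ASGI app via `app.add_middleware`:
```
import litserve as ls
from litserve.middlewares import MaxSizeMiddleware

if __name__ == "__main__":
    api = ls.test_examples.SimpleLitAPI()
    # Reject request bodies larger than ~10 MB with HTTP 413.
    server = ls.LitServer(api, middlewares=[(MaxSizeMiddleware, {"max_size": 10_000_000})])
    server.run(port=8000)
```
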
/src/litserve/python_client.py:
--------------------------------------------------------------------------------
1 | client_template = """# This file is auto-generated by LitServe.
2 | # Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`.
3 |
4 | import requests
5 |
6 | response = requests.post("http://127.0.0.1:{PORT}/predict", json={{"input": 4.0}})
7 | print(f"Status: {{response.status_code}}\\nResponse:\\n {{response.text}}")
8 | """
9 |
--------------------------------------------------------------------------------
/src/litserve/schema/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-AI/LitServe/ea6aca44640d4d5691ba44fb27e76796d83482de/src/litserve/schema/__init__.py
--------------------------------------------------------------------------------
/src/litserve/schema/image.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from io import BytesIO
3 | from typing import TYPE_CHECKING, Any, Optional
4 |
5 | from pydantic import BaseModel, field_serializer, model_validator
6 |
7 | if TYPE_CHECKING:
8 | from PIL import Image
9 |
10 |
11 | class ImageInput(BaseModel):
12 | image_data: Optional[str] = None
13 |
14 | @model_validator(mode="after")
15 | def validate_base64(self) -> "ImageInput":
16 | """Ensure the string is a valid Base64."""
17 | model_dump = self.model_dump()
18 | for key, value in model_dump.items():
19 | if value:
20 | try:
21 | base64.b64decode(value)
22 | except base64.binascii.Error:
23 | raise ValueError("Invalid Base64 string.")
24 | return self
25 |
26 | def get_image(self, key: Optional[str] = None) -> "Image.Image":
27 | """Decode the Base64 string and return a PIL Image object."""
28 | if key is None:
29 | key = "image_data"
30 | image_data = self.model_dump().get(key)
31 | if not image_data:
32 | raise ValueError(f"Missing image data for key '{key}'")
33 | try:
34 | from PIL import Image, UnidentifiedImageError
35 | except ImportError:
36 | raise ImportError("Pillow is required to use the ImageInput schema. Install it with `pip install Pillow`.")
37 | try:
38 | decoded_data = base64.b64decode(image_data)
39 | return Image.open(BytesIO(decoded_data))
40 | except UnidentifiedImageError as e:
41 | raise ValueError(f"Error loading image from decoded data: {e}")
42 |
43 |
44 | class ImageOutput(BaseModel):
45 | image: Any
46 |
47 | @field_serializer("image")
48 | def serialize_image(self, image: Any, _info):
49 | """
50 | Serialize a PIL Image into a base64 string.
51 | Args:
52 | image (Any): The image object to serialize.
53 | _info: Metadata passed during serialization (not used here).
54 |
55 | Returns:
56 | str: Base64-encoded image string.
57 | """
58 | try:
59 | from PIL import Image
60 | except ImportError:
61 | raise ImportError("Pillow is required to use the ImageOutput schema. Install it with `pip install Pillow`.")
62 |
63 | if not isinstance(image, Image.Image):
64 | raise TypeError(f"Expected a PIL Image, got {type(image)}")
65 |
66 | # Save the image to a BytesIO buffer
67 | buffer = BytesIO()
68 | image.save(buffer, format="PNG") # Default format is PNG
69 | buffer.seek(0)
70 |
71 | # Encode the buffer content to base64
72 | base64_bytes = base64.b64encode(buffer.read())
73 |
74 | # Decode to string for JSON serialization
75 | return base64_bytes.decode("utf-8")
76 |
--------------------------------------------------------------------------------
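Note: a minimal sketch of using the image schemas inside a `LitAPI` (requires Pillow; the model here is a placeholder identity transform, and it assumes LitServe resolves the `ImageInput` annotation on `decode_request`):
```
import litserve as ls
from litserve.schema.image import ImageInput, ImageOutput


class ImageEchoAPI(ls.LitAPI):
    def setup(self, device):
        self.model = lambda image: image  # placeholder model

    def decode_request(self, request: ImageInput):
        # ImageInput validates the base64 payload; get_image() decodes it to a PIL image.
        return request.get_image()

    def predict(self, image):
        return self.model(image)

    def encode_response(self, output) -> ImageOutput:
        # ImageOutput serializes the PIL image back to a base64-encoded PNG string.
        return ImageOutput(image=output)
```
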
/src/litserve/specs/__init__.py:
--------------------------------------------------------------------------------
1 | from litserve.specs.openai import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, OpenAISpec
2 | from litserve.specs.openai_embedding import EmbeddingRequest, EmbeddingResponse, OpenAIEmbeddingSpec
3 |
4 | __all__ = [
5 | "OpenAISpec",
6 | "OpenAIEmbeddingSpec",
7 | "EmbeddingRequest",
8 | "EmbeddingResponse",
9 | "ChatCompletionRequest",
10 | "ChatCompletionResponse",
11 | "ChatCompletionChunk",
12 | ]
13 |
--------------------------------------------------------------------------------
/src/litserve/specs/base.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from abc import abstractmethod
15 | from typing import TYPE_CHECKING, Callable, List
16 |
17 | if TYPE_CHECKING:
18 | from litserve import LitAPI, LitServer
19 |
20 |
21 | class LitSpec:
22 | """Spec will have its own encode, and decode."""
23 |
24 | def __init__(self):
25 | self._endpoints = []
26 | self.api_path = None
27 | self._server: LitServer = None
28 | self._max_batch_size = 1
29 | self.response_buffer = None
30 | self.request_queue = None
31 | self.response_queue_id = None
32 |
33 | @property
34 | def stream(self):
35 | return False
36 |
37 | def pre_setup(self, lit_api: "LitAPI"):
38 | pass
39 |
40 | def setup(self, server: "LitServer"):
41 | """This method is called by the server to connect the spec to the server."""
42 | self.response_buffer = server.response_buffer
43 | self.request_queue = server._get_request_queue(self.api_path)
44 | self.data_streamer = server.data_streamer
45 |
46 | def add_endpoint(self, path: str, endpoint: Callable, methods: List[str]):
47 | """Register an endpoint in the spec."""
48 | self._endpoints.append((path, endpoint, methods))
49 |
50 | @property
51 | def endpoints(self):
52 | return self._endpoints.copy()
53 |
54 | @abstractmethod
55 | def decode_request(self, request, meta_kwargs):
56 | """Convert the request payload to your model input."""
57 | pass
58 |
59 | @abstractmethod
60 | def encode_response(self, output, meta_kwargs):
61 | """Convert the model output to a response payload.
62 |
63 | To enable streaming, it should yield the output.
64 |
65 | """
66 | pass
67 |
--------------------------------------------------------------------------------
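Note: a toy `LitSpec` subclass sketch showing the hooks a spec implements; how a registered endpoint hands requests to the inference workers is server-internal and omitted here, so the handler below is purely hypothetical:
```
from litserve.specs.base import LitSpec


class EchoSpec(LitSpec):
    def __init__(self):
        super().__init__()
        # Endpoints registered here are mounted by the server during setup().
        self.add_endpoint("/echo", self.handle_echo, ["POST"])

    def decode_request(self, request, meta_kwargs):
        return request["input"]

    def encode_response(self, output, meta_kwargs):
        return {"output": output}

    async def handle_echo(self, request: dict):  # hypothetical handler
        return {"output": request.get("input")}
```
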
/src/litserve/test_examples/__init__.py:
--------------------------------------------------------------------------------
1 | from litserve.test_examples.openai_spec_example import (
2 | OpenAIBatchContext,
3 | TestAPI,
4 | TestAPIWithCustomEncode,
5 | TestAPIWithStructuredOutput,
6 | TestAPIWithToolCalls,
7 | )
8 | from litserve.test_examples.simple_example import SimpleBatchedAPI, SimpleLitAPI, SimpleStreamAPI, SimpleTorchAPI
9 |
10 | __all__ = [
11 | "SimpleLitAPI",
12 | "SimpleBatchedAPI",
13 | "SimpleTorchAPI",
14 | "TestAPI",
15 | "TestAPIWithCustomEncode",
16 | "TestAPIWithStructuredOutput",
17 | "TestAPIWithToolCalls",
18 | "OpenAIBatchContext",
19 | "SimpleStreamAPI",
20 | ]
21 |
--------------------------------------------------------------------------------
/src/litserve/test_examples/openai_embedding_spec_example.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 |
5 | from litserve.api import LitAPI
6 |
7 |
8 | class TestEmbedAPI(LitAPI):
9 | def setup(self, device):
10 | self.model = None
11 |
12 | def predict(self, x) -> List[List[float]]:
13 | n = len(x) if isinstance(x, list) else 1
14 | return np.random.rand(n, 768).tolist()
15 |
16 | def encode_response(self, output) -> dict:
17 | return {"embeddings": output}
18 |
19 |
20 | class TestEmbedBatchedAPI(TestEmbedAPI):
21 | def predict(self, batch) -> List[List[List[float]]]:
22 | return [np.random.rand(len(x), 768).tolist() for x in batch]
23 |
24 |
25 | class TestEmbedAPIWithUsage(TestEmbedAPI):
26 | def encode_response(self, output) -> dict:
27 | return {"embeddings": output, "prompt_tokens": 10, "total_tokens": 10}
28 |
29 |
30 | class TestEmbedAPIWithYieldPredict(TestEmbedAPI):
31 | def predict(self, x):
32 | yield from np.random.rand(768).tolist()
33 |
34 |
35 | class TestEmbedAPIWithYieldEncodeResponse(TestEmbedAPI):
36 | def encode_response(self, output):
37 | yield {"embeddings": output}
38 |
39 |
40 | class TestEmbedAPIWithNonDictOutput(TestEmbedAPI):
41 | def encode_response(self, output):
42 | return output
43 |
44 |
45 | class TestEmbedAPIWithMissingEmbeddings(TestEmbedAPI):
46 | def encode_response(self, output):
47 | return {"output": output}
48 |
--------------------------------------------------------------------------------
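Note: these test APIs are meant to be served with `OpenAIEmbeddingSpec` (see `tests/e2e/default_openai_embedding_spec.py` below); a client-side sketch, assuming the spec exposes the OpenAI-style `/v1/embeddings` route and returns `data[].embedding` entries:
```
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/embeddings",
    json={"model": "test-embed", "input": ["hello", "world"]},
)
print(resp.json()["data"][0]["embedding"][:5])
```
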
/src/litserve/test_examples/openai_spec_example.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import time
15 |
16 | from litserve.api import LitAPI
17 | from litserve.specs.openai import ChatMessage
18 |
19 |
20 | class TestAPI(LitAPI):
21 | def setup(self, device):
22 | self.model = None
23 |
24 | def predict(self, x):
25 | yield "This is a generated output"
26 |
27 |
28 | class TestAPIWithCustomEncode(TestAPI):
29 | def encode_response(self, output):
30 | yield ChatMessage(role="assistant", content="This is a custom encoded output")
31 |
32 |
33 | class TestAPIWithToolCalls(TestAPI):
34 | def encode_response(self, output):
35 | yield ChatMessage(
36 | role="assistant",
37 | content="",
38 | tool_calls=[
39 | {
40 | "id": "call_1",
41 | "type": "function",
42 | "function": {"name": "function_1", "arguments": '{"arg_1": "arg_1_value"}'},
43 | }
44 | ],
45 | )
46 |
47 |
48 | class TestAPIWithStructuredOutput(TestAPI):
49 | def encode_response(self, output):
50 | yield ChatMessage(
51 | role="assistant",
52 | content='{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}',
53 | )
54 |
55 |
56 | class OpenAIBatchContext(LitAPI):
57 | def setup(self, device: str) -> None:
58 | self.model = None
59 |
60 | def batch(self, inputs):
61 | return inputs
62 |
63 | def predict(self, inputs, context):
64 | n = len(inputs)
65 | assert isinstance(context, list)
66 | for ctx in context:
67 | ctx["temperature"] = 1.0
68 | output = [
69 | "Hi!",
70 | "It's",
71 | "nice",
72 | "to",
73 | "meet",
74 | "you.",
75 | "Is",
76 | "there",
77 | "something",
78 | "I",
79 | "can",
80 | "help",
81 | "you",
82 | "with",
83 | "or",
84 | "would",
85 | "you",
86 | "like",
87 | "to",
88 | "chat?",
89 | ]
90 | for out in output:
91 | time.sleep(0.01) # fake delay
92 | yield [out + " "] * n
93 |
94 | def unbatch(self, output):
95 | return output
96 |
97 | def encode_response(self, output_stream, context):
98 | for outputs in output_stream:
99 | yield [{"role": "assistant", "content": output} for output in outputs]
100 | for ctx in context:
101 | assert ctx["temperature"] == 1.0, f"context {ctx} is not 1.0"
102 |
103 |
104 | class OpenAIWithUsage(LitAPI):
105 | def setup(self, device):
106 | self.model = None
107 |
108 | def predict(self, x):
109 | yield {
110 | "role": "assistant",
111 | "content": "10 + 6 is equal to 16.",
112 | "prompt_tokens": 25,
113 | "completion_tokens": 10,
114 | "total_tokens": 35,
115 | }
116 |
117 |
118 | class OpenAIWithUsageEncodeResponse(LitAPI):
119 | def setup(self, device):
120 | self.model = None
121 |
122 | def predict(self, x):
123 | # streaming tokens
124 | yield from ["10", " +", " ", "6", " is", " equal", " to", " ", "16", "."]
125 |
126 | def encode_response(self, output):
127 | for out in output:
128 | yield {"role": "assistant", "content": out}
129 |
130 | yield {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35}
131 |
132 |
133 | class OpenAIBatchingWithUsage(OpenAIWithUsage):
134 | def batch(self, inputs):
135 | return inputs
136 |
137 | def predict(self, x):
138 | n = len(x)
139 | yield ["10 + 6 is equal to 16."] * n
140 |
141 | def encode_response(self, output_stream_batch, context):
142 | n = len(context)
143 | for output_batch in output_stream_batch:
144 | yield [{"role": "assistant", "content": out} for out in output_batch]
145 |
146 | yield [
147 | {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35}
148 | ] * n
149 |
150 | def unbatch(self, output):
151 | return output
152 |
--------------------------------------------------------------------------------
/src/litserve/test_examples/simple_example.py:
--------------------------------------------------------------------------------
1 | from litserve.api import LitAPI
2 |
3 |
4 | class SimpleLitAPI(LitAPI):
5 | def setup(self, device):
6 | # Set up the model, so it can be called in `predict`.
7 | self.model = lambda x: x**2
8 |
9 | def decode_request(self, request):
10 | # Convert the request payload to your model input.
11 | return request["input"]
12 |
13 | def predict(self, x):
14 | # Run the model on the input and return the output.
15 | return self.model(x)
16 |
17 | def encode_response(self, output):
18 | # Convert the model output to a response payload.
19 | return {"output": output}
20 |
21 |
22 | class SimpleBatchedAPI(LitAPI):
23 | def setup(self, device) -> None:
24 | self.model = lambda x: x**2
25 |
26 | def decode_request(self, request):
27 | import numpy as np
28 |
29 | return np.asarray(request["input"])
30 |
31 | def predict(self, x):
32 | return self.model(x)
33 |
34 | def encode_response(self, output):
35 | return {"output": output}
36 |
37 |
38 | class SimpleTorchAPI(LitAPI):
39 | def setup(self, device):
40 | # move the model to the correct device
41 | # keep track of the device for moving data accordingly
42 | import torch.nn as nn
43 |
44 | class Linear(nn.Module):
45 | def __init__(self):
46 | super().__init__()
47 | self.linear = nn.Linear(1, 1)
48 | self.linear.weight.data.fill_(2.0)
49 | self.linear.bias.data.fill_(1.0)
50 |
51 | def forward(self, x):
52 | return self.linear(x)
53 |
54 | self.model = Linear().to(device)
55 |
56 | def decode_request(self, request):
57 | import torch
58 |
59 | # get the input and create a 1D tensor on the correct device
60 | content = request["input"]
61 | return torch.tensor([content], device=self.device)
62 |
63 | def predict(self, x):
64 | # the model expects a batch dimension, so create it
65 | return self.model(x[None, :])
66 |
67 | def encode_response(self, output):
68 | # float will take the output value directly onto CPU memory
69 | return {"output": float(output)}
70 |
71 |
72 | class SimpleStreamAPI(LitAPI):
73 | """
74 | Run as:
75 | ```
76 | server = ls.LitServer(SimpleStreamAPI(), stream=True)
77 | server.run(port=8000)
78 | ```
79 | Then, in a new Python session, retrieve the responses as follows:
80 | ```
81 | import requests
82 | url = "http://127.0.0.1:8000/predict"
83 | resp = requests.post(url, json={"input": "Hello world"}, headers=None, stream=True)
84 | for line in resp.iter_content(5000):
85 | if line:
86 | print(line.decode("utf-8"))
87 | ```
88 | """
89 |
90 | def setup(self, device) -> None:
91 | self.model = lambda x, y: f"{x}: {y}"
92 |
93 | def decode_request(self, request):
94 | return request["input"]
95 |
96 | def predict(self, x):
97 | for i in range(3):
98 | yield self.model(i, x)
99 |
100 | def encode_response(self, output_stream):
101 | for output in output_stream:
102 | yield {"output": output}
103 |
--------------------------------------------------------------------------------
/src/litserve/transport/__init__.py:
--------------------------------------------------------------------------------
1 | from .process_transport import MPQueueTransport
2 | from .zmq_transport import ZMQTransport
3 |
4 | __all__ = ["ZMQTransport", "MPQueueTransport"]
5 |
--------------------------------------------------------------------------------
/src/litserve/transport/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Optional
3 |
4 |
5 | class MessageTransport(ABC):
6 | @abstractmethod
7 | def send(self, item: Any, consumer_id: int) -> None:
8 | """Send a message to a consumer in the main process."""
9 | pass
10 |
11 | @abstractmethod
12 | async def areceive(self, timeout: Optional[int] = None, consumer_id: Optional[int] = None) -> dict:
13 | """Receive a message from model workers or any publisher."""
14 | pass
15 |
16 | def close(self) -> None:
17 | """Clean up resources if needed (e.g., sockets, processes)."""
18 | pass
19 |
--------------------------------------------------------------------------------
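Note: a minimal in-memory `MessageTransport` sketch built on asyncio queues, just to illustrate the `send`/`areceive` contract (LitServe ships `MPQueueTransport` and `ZMQTransport` as the real implementations):
```
import asyncio
from typing import Any, List, Optional

from litserve.transport.base import MessageTransport


class InMemoryTransport(MessageTransport):
    def __init__(self, num_consumers: int = 1):
        self._queues: List[asyncio.Queue] = [asyncio.Queue() for _ in range(num_consumers)]

    def send(self, item: Any, consumer_id: int) -> None:
        # Non-blocking put; workers push results tagged with the consumer id.
        self._queues[consumer_id].put_nowait(item)

    async def areceive(self, timeout: Optional[int] = None, consumer_id: Optional[int] = None) -> dict:
        # The server awaits results for its consumer id, optionally with a timeout.
        return await asyncio.wait_for(self._queues[consumer_id].get(), timeout)
```
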
/src/litserve/transport/factory.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Manager
2 | from typing import Literal, Optional
3 |
4 | from pydantic import BaseModel, Field
5 |
6 | from litserve.transport.process_transport import MPQueueTransport
7 | from litserve.transport.zmq_queue import Broker
8 | from litserve.transport.zmq_transport import ZMQTransport
9 |
10 |
11 | class TransportConfig(BaseModel):
12 | transport_type: Literal["mp", "zmq"] = "mp"
13 | num_consumers: int = Field(1, ge=1)
14 | manager: Optional[Manager] = None
15 | consumer_id: Optional[int] = None
16 | frontend_address: Optional[str] = None
17 | backend_address: Optional[str] = None
18 |
19 |
20 | def _create_zmq_transport(config: TransportConfig):
21 | broker = Broker()
22 | broker.start()
23 | config.frontend_address = broker.frontend_address
24 | config.backend_address = broker.backend_address
25 | return ZMQTransport(config.frontend_address, config.backend_address)
26 |
27 |
28 | def _create_mp_transport(config: TransportConfig):
29 | queues = [config.manager.Queue() for _ in range(config.num_consumers)]
30 | return MPQueueTransport(config.manager, queues)
31 |
32 |
33 | def create_transport_from_config(config: TransportConfig):
34 | if config.transport_type == "mp":
35 | return _create_mp_transport(config)
36 | if config.transport_type == "zmq":
37 | return _create_zmq_transport(config)
38 | raise ValueError(f"Invalid transport type: {config.transport_type}")
39 |
--------------------------------------------------------------------------------
/src/litserve/transport/process_transport.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from contextlib import suppress
3 | from multiprocessing import Manager, Queue
4 | from typing import Any, List, Optional
5 |
6 | from litserve.transport.base import MessageTransport
7 |
8 |
9 | class MPQueueTransport(MessageTransport):
10 | def __init__(self, manager: Manager, queues: List[Queue]):
11 | self._queues = queues
12 | self._closed = False
13 |
14 | def send(self, item: Any, consumer_id: int) -> None:
15 | return self._queues[consumer_id].put(item)
16 |
17 | async def areceive(self, consumer_id: int, timeout: Optional[float] = None, block: bool = True) -> dict:
18 | if self._closed:
19 | raise asyncio.CancelledError("Transport closed")
20 |
21 | actual_timeout = 1 if timeout is None else min(timeout, 1)
22 |
23 | try:
24 | return await asyncio.to_thread(self._queues[consumer_id].get, timeout=actual_timeout, block=True)
25 | except asyncio.CancelledError:
26 | raise
27 | except Exception:
28 | if self._closed:
29 | raise asyncio.CancelledError("Transport closed")
30 | if timeout is not None and timeout <= actual_timeout:
31 | raise
32 | return None
33 |
34 | def close(self) -> None:
35 | # Mark the transport as closed
36 | self._closed = True
37 |
38 | # Put sentinel values in the queues as a backup mechanism
39 | for queue in self._queues:
40 | with suppress(Exception):
41 | queue.put(None)
42 |
43 | def __reduce__(self):
44 | return (MPQueueTransport, (None, self._queues))
45 |
--------------------------------------------------------------------------------
/src/litserve/transport/zmq_queue.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import asyncio
15 | import logging
16 | import multiprocessing
17 | import pickle
18 | import threading
19 | import time
20 | from queue import Empty
21 | from typing import Any, Optional
22 |
23 | import zmq
24 | import zmq.asyncio
25 |
26 | from litserve.utils import generate_random_zmq_address
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | class Broker:
32 | """Message broker that routes messages between producers and consumers."""
33 |
34 | def __init__(self, use_process: bool = False):
35 | self.frontend_address = generate_random_zmq_address()
36 | self.backend_address = generate_random_zmq_address()
37 | self._running = False
38 | self._use_process = use_process
39 | self._worker = None
40 |
41 | def start(self):
42 | """Start the broker in a background thread or process."""
43 | self._running = True
44 |
45 | if self._use_process:
46 | self._worker = multiprocessing.Process(target=self._run)
47 | else:
48 | self._worker = threading.Thread(target=self._run)
49 |
50 | self._worker.daemon = True
51 | self._worker.start()
52 | logger.info(
53 | f"Broker started in {'process' if self._use_process else 'thread'} "
54 | f"on {self.frontend_address} (frontend) and {self.backend_address} (backend)"
55 | )
56 | time.sleep(0.1) # Give the broker time to start
57 |
58 | def _run(self):
59 | """Main broker loop."""
60 | context = zmq.Context()
61 | try:
62 | frontend = context.socket(zmq.XPUB)
63 | frontend.bind(self.frontend_address)
64 |
65 | backend = context.socket(zmq.XSUB)
66 | backend.bind(self.backend_address)
67 |
68 | zmq.proxy(frontend, backend)
69 | except zmq.ZMQError as e:
70 | logger.error(f"Broker error: {e}")
71 | finally:
72 | frontend.close(linger=0)
73 | backend.close(linger=0)
74 | context.term()
75 |
76 | def stop(self):
77 | """Stop the broker."""
78 | self._running = False
79 | if self._worker:
80 | self._worker.join()
81 |
82 |
83 | class Producer:
84 | """Producer class for sending messages to consumers."""
85 |
86 | def __init__(self, address: str = None):
87 | self._context = zmq.Context()
88 | self._socket = self._context.socket(zmq.PUB)
89 | self._socket.connect(address)
90 |
91 | def wait_for_subscribers(self, timeout: float = 1.0) -> bool:
92 | """Wait for at least one subscriber to be ready.
93 |
94 | Args:
95 | timeout: Maximum time to wait in seconds
96 |
97 | Returns:
98 | bool: True if subscribers are ready, False if timeout occurred
99 |
100 | """
101 | start_time = time.time()
102 | while time.time() - start_time < timeout:
103 | # Send a ping message to consumer 0 (special system messages)
104 | try:
105 | self._socket.send(b"0|__ping__", zmq.NOBLOCK)
106 | time.sleep(0.1) # Give time for subscription to propagate
107 | return True
108 | except zmq.ZMQError:
109 | continue
110 | return False
111 |
112 | def put(self, item: Any, consumer_id: int) -> None:
113 | """Send an item to a specific consumer."""
114 | try:
115 | pickled_item = pickle.dumps(item)
116 | message = f"{consumer_id}|".encode() + pickled_item
117 | self._socket.send(message)
118 | except zmq.ZMQError as e:
119 | logger.error(f"Error sending item: {e}")
120 | raise
121 | except pickle.PickleError as e:
122 | logger.error(f"Error serializing item: {e}")
123 | raise
124 |
125 | def close(self) -> None:
126 | """Clean up resources."""
127 | if self._socket:
128 | self._socket.close(linger=0)
129 | if self._context:
130 | self._context.term()
131 |
132 |
133 | class BaseConsumer:
134 | """Base class for consumers."""
135 |
136 | def __init__(self, consumer_id: int, address: str):
137 | self.consumer_id = consumer_id
138 | self.address = address
139 | self._context = None
140 | self._socket = None
141 | self._setup_socket()
142 |
143 | def _setup_socket(self):
144 | """Setup ZMQ socket - to be implemented by subclasses"""
145 | raise NotImplementedError
146 |
147 | def _parse_message(self, message: bytes) -> Any:
148 | """Parse a message received from ZMQ."""
149 | try:
150 | consumer_id, pickled_data = message.split(b"|", 1)
151 | return pickle.loads(pickled_data)
152 | except pickle.PickleError as e:
153 | logger.error(f"Error deserializing message: {e}")
154 | raise
155 |
156 | def close(self) -> None:
157 | """Clean up resources."""
158 | if self._socket:
159 | self._socket.close(linger=0)
160 | if self._context:
161 | self._context.term()
162 |
163 |
164 | class AsyncConsumer(BaseConsumer):
165 | """Async consumer class for receiving messages using asyncio."""
166 |
167 | def _setup_socket(self):
168 | self._context = zmq.asyncio.Context()
169 | self._socket = self._context.socket(zmq.SUB)
170 | self._socket.connect(self.address)
171 | self._socket.setsockopt_string(zmq.SUBSCRIBE, str(self.consumer_id))
172 |
173 | async def get(self, timeout: Optional[float] = None) -> Any:
174 | """Get an item from the queue asynchronously."""
175 | try:
176 | if timeout is not None:
177 | message = await asyncio.wait_for(self._socket.recv(), timeout)
178 | else:
179 | message = await self._socket.recv()
180 |
181 | return self._parse_message(message)
182 | except asyncio.TimeoutError:
183 | raise Empty
184 | except zmq.ZMQError:
185 | raise Empty
186 |
187 | def close(self) -> None:
188 | """Clean up resources asynchronously."""
189 | if self._socket:
190 | self._socket.close(linger=0)
191 | if self._context:
192 | self._context.term()
193 |
--------------------------------------------------------------------------------
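Note: a rough end-to-end sketch of the broker/producer/consumer pieces above; PUB/SUB joins are racy, so `wait_for_subscribers` plus its short sleep is doing real work here and a retry loop may still be needed in practice:
```
import asyncio

from litserve.transport.zmq_queue import AsyncConsumer, Broker, Producer


async def main():
    broker = Broker()
    broker.start()  # binds XPUB/XSUB sockets on random ipc:// addresses

    producer = Producer(address=broker.backend_address)
    consumer = AsyncConsumer(consumer_id=7, address=broker.frontend_address)

    producer.wait_for_subscribers()
    producer.put({"status": "ok"}, consumer_id=7)  # only consumer 7 receives this

    print(await consumer.get(timeout=1.0))

    consumer.close()
    producer.close()


asyncio.run(main())
```
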
/src/litserve/transport/zmq_transport.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Literal, Optional, Union
2 |
3 | import zmq
4 |
5 | from litserve.transport.base import MessageTransport
6 | from litserve.transport.zmq_queue import AsyncConsumer, Producer
7 |
8 |
9 | class ZMQTransport(MessageTransport):
10 | def __init__(self, backend_address: str, frontend_address):
11 | self.backend_address = backend_address
12 | self.frontend_address = frontend_address
13 | self._zmq: Union[Producer, AsyncConsumer, None] = None
14 |
15 | def setup(self, operation: Literal[zmq.SUB, zmq.PUB], consumer_id: Optional[int] = None) -> None:
16 | """Must be called in the subprocess to setup the ZMQ transport."""
17 | if operation == zmq.PUB:
18 | self._zmq = Producer(address=self.backend_address)
19 | self._zmq.wait_for_subscribers()
20 | elif operation == zmq.SUB:
21 | self._zmq = AsyncConsumer(consumer_id=consumer_id, address=self.frontend_address)
22 | else:
23 | raise ValueError(f"Invalid operation {operation}")
24 |
25 | def send(self, item: Any, consumer_id: int) -> None:
26 | if self._zmq is None:
27 | self.setup(zmq.PUB)
28 | return self._zmq.put(item, consumer_id)
29 |
30 | async def areceive(self, consumer_id: Optional[int] = None, timeout=None) -> dict:
31 | if self._zmq is None:
32 | self.setup(zmq.SUB, consumer_id)
33 | return await self._zmq.get(timeout=timeout)
34 |
35 | def close(self) -> None:
36 | if self._zmq:
37 | self._zmq.close()
38 | else:
39 | raise ValueError("ZMQ not initialized, make sure ZMQTransport.setup() is called.")
40 |
41 | def __reduce__(self):
42 | return ZMQTransport, (self.backend_address, self.frontend_address)
43 |
--------------------------------------------------------------------------------
/src/litserve/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import asyncio
15 | import dataclasses
16 | import logging
17 | import os
18 | import pdb
19 | import pickle
20 | import sys
21 | import uuid
22 | from contextlib import contextmanager
23 | from enum import Enum
24 | from typing import TYPE_CHECKING, Any, AsyncIterator, TextIO, Union
25 |
26 | from fastapi import HTTPException
27 |
28 | if TYPE_CHECKING:
29 | from litserve.server import LitServer
30 |
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 | class LitAPIStatus:
35 | OK = "OK"
36 | ERROR = "ERROR"
37 | FINISH_STREAMING = "FINISH_STREAMING"
38 |
39 |
40 | class LoopResponseType(Enum):
41 | STREAMING = "STREAMING"
42 | REGULAR = "REGULAR"
43 |
44 |
45 | class PickleableHTTPException(HTTPException):
46 | @staticmethod
47 | def from_exception(exc: HTTPException):
48 | status_code = exc.status_code
49 | detail = exc.detail
50 | return PickleableHTTPException(status_code, detail)
51 |
52 | def __reduce__(self):
53 | return (HTTPException, (self.status_code, self.detail))
54 |
55 |
56 | def dump_exception(exception):
57 | if isinstance(exception, HTTPException):
58 | exception = PickleableHTTPException.from_exception(exception)
59 | return pickle.dumps(exception)
60 |
61 |
62 | async def azip(*async_iterables):
63 | iterators = [ait.__aiter__() for ait in async_iterables]
64 | while True:
65 | results = await asyncio.gather(*(ait.__anext__() for ait in iterators), return_exceptions=True)
66 | if any(isinstance(result, StopAsyncIteration) for result in results):
67 | break
68 | yield tuple(results)
69 |
70 |
71 | @contextmanager
72 | def wrap_litserve_start(server: "LitServer"):
73 | """Pytest utility to start the server in a context manager."""
74 | server.app.response_queue_id = 0
75 | for lit_api in server.litapi_connector:
76 | if lit_api.spec:
77 | lit_api.spec.response_queue_id = 0
78 |
79 | manager = server._init_manager(1)
80 | processes = []
81 | for lit_api in server.litapi_connector:
82 | processes.extend(server.launch_inference_worker(lit_api))
83 | server._prepare_app_run(server.app)
84 | try:
85 | yield server
86 | finally:
87 | # First close the transport to signal to the response_queue_to_buffer task that it should stop
88 | server._transport.close()
89 | for p in processes:
90 | p.terminate()
91 | p.join()
92 | manager.shutdown()
93 |
94 |
95 | async def call_after_stream(streamer: AsyncIterator, callback, *args, **kwargs):
96 | try:
97 | async for item in streamer:
98 | yield item
99 | except Exception as e:
100 | logger.exception(f"Error in streamer: {e}")
101 | finally:
102 | callback(*args, **kwargs)
103 |
104 |
105 | @dataclasses.dataclass
106 | class WorkerSetupStatus:
107 | STARTING: str = "starting"
108 | READY: str = "ready"
109 | ERROR: str = "error"
110 | FINISHED: str = "finished"
111 |
112 |
113 | def _get_default_handler(stream, format):
114 | handler = logging.StreamHandler(stream)
115 | formatter = logging.Formatter(format)
116 | handler.setFormatter(formatter)
117 | return handler
118 |
119 |
120 | def configure_logging(
121 | level: Union[str, int] = logging.INFO,
122 | format: str = "%(asctime)s - %(processName)s[%(process)d] - %(name)s - %(levelname)s - %(message)s",
123 | stream: TextIO = sys.stdout,
124 | use_rich: bool = False,
125 | ):
126 | """Configure logging for the entire library with sensible defaults.
127 |
128 | Args:
129 | level (Union[str, int]): Logging level name or number (default: logging.INFO)
130 | format (str): Log message format string
131 | stream (file-like): Output stream for logs
132 | use_rich (bool): Makes the logs more readable by using rich, useful for debugging. Defaults to False.
133 |
134 | """
135 | if isinstance(level, str):
136 | level = level.upper()
137 | level = getattr(logging, level)
138 |
139 | # Clear any existing handlers to prevent duplicates
140 | library_logger = logging.getLogger("litserve")
141 | for handler in library_logger.handlers[:]:
142 | library_logger.removeHandler(handler)
143 |
144 | if use_rich:
145 | try:
146 | from rich.logging import RichHandler
147 | from rich.traceback import install
148 |
149 | install(show_locals=True)
150 | handler = RichHandler(rich_tracebacks=True, show_time=True, show_path=True)
151 | except ImportError:
152 | logger.warning("Rich is not installed, using default logging")
153 | handler = _get_default_handler(stream, format)
154 | else:
155 | handler = _get_default_handler(stream, format)
156 |
157 | # Configure library logger
158 | library_logger.setLevel(level)
159 | library_logger.addHandler(handler)
160 | library_logger.propagate = False
161 |
162 |
163 | def set_log_level(level):
164 | """Allow users to set the global logging level for the library."""
165 | logging.getLogger("litserve").setLevel(level)
166 |
167 |
168 | def add_log_handler(handler):
169 | """Allow users to add custom log handlers.
170 |
171 | Example usage:
172 | file_handler = logging.FileHandler('library_logs.log')
173 | add_log_handler(file_handler)
174 |
175 | """
176 | logging.getLogger("litserve").addHandler(handler)
177 |
178 |
179 | def generate_random_zmq_address(temp_dir="/tmp"):
180 | """Generate a random IPC address in the /tmp directory.
181 |
182 | Ensures the address is unique.
183 | Returns:
184 | str: A random IPC address suitable for ZeroMQ.
185 |
186 | """
187 | unique_name = f"zmq-{uuid.uuid4().hex}.ipc"
188 | ipc_path = os.path.join(temp_dir, unique_name)
189 | return f"ipc://{ipc_path}"
190 |
191 |
192 | class ForkedPdb(pdb.Pdb):
193 | # Borrowed from - https://github.com/Lightning-AI/forked-pdb
194 | """
195 | PDB Subclass for debugging multi-processed code
196 | Suggested in: https://stackoverflow.com/questions/4716533/how-to-attach-debugger-to-a-python-subproccess
197 | """
198 |
199 | def interaction(self, *args: Any, **kwargs: Any) -> None:
200 | _stdin = sys.stdin
201 | try:
202 | sys.stdin = open("/dev/stdin") # noqa: SIM115
203 | pdb.Pdb.interaction(self, *args, **kwargs)
204 | finally:
205 | sys.stdin = _stdin
206 |
207 |
208 | def set_trace():
209 | """Set a tracepoint in the code."""
210 | ForkedPdb().set_trace()
211 |
212 |
213 | def set_trace_if_debug(debug_env_var="LITSERVE_DEBUG", debug_env_var_value="1"):
214 | """Set a tracepoint in the code if the environment variable LITSERVE_DEBUG is set."""
215 | if os.environ.get(debug_env_var) == debug_env_var_value:
216 | set_trace()
217 |
--------------------------------------------------------------------------------
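Note: a quick sketch of the logging helpers defined above, imported directly from `litserve.utils`:
```
import logging

from litserve.utils import add_log_handler, configure_logging, set_log_level

# Route litserve logs to stdout at DEBUG with the default format.
configure_logging(level="DEBUG")

# Additionally write litserve logs to a file, then raise the level back to INFO.
add_log_handler(logging.FileHandler("litserve.log"))
set_log_level(logging.INFO)
```
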
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/e2e/default_api.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import litserve as ls
15 |
16 | if __name__ == "__main__":
17 | api = ls.test_examples.SimpleLitAPI()
18 | server = ls.LitServer(api)
19 | server.run(port=8000)
20 |
--------------------------------------------------------------------------------
/tests/e2e/default_async_streaming.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import litserve as ls
15 |
16 |
17 | class AsyncAPI(ls.LitAPI):
18 | def setup(self, device) -> None:
19 | self.model = lambda x: x
20 |
21 | async def decode_request(self, request):
22 | return request["input"]
23 |
24 | async def predict(self, x):
25 | for i in range(10):
26 | yield self.model(i)
27 |
28 | async def encode_response(self, output):
29 | for out in output:
30 | yield {"output": out}
31 |
32 |
33 | if __name__ == "__main__":
34 | api = AsyncAPI(enable_async=True)
35 | server = ls.LitServer(api, stream=True)
36 | server.run(port=8000)
37 |
--------------------------------------------------------------------------------
/tests/e2e/default_batched_streaming.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import numpy as np
15 |
16 | import litserve as ls
17 |
18 |
19 | class SimpleStreamAPI(ls.LitAPI):
20 | def setup(self, device) -> None:
21 | self.model = lambda x, y: x * y
22 |
23 | def decode_request(self, request):
24 | return np.asarray(request["input"])
25 |
26 | def predict(self, x):
27 | for i in range(10):
28 | yield self.model(x, i)
29 |
30 | def encode_response(self, output_stream):
31 | for outputs in output_stream:
32 | yield [{"output": output} for output in outputs]
33 |
34 |
35 | if __name__ == "__main__":
36 | server = ls.LitServer(SimpleStreamAPI(), stream=True, max_batch_size=4, batch_timeout=0.2, fast_queue=True)
37 | server.run(port=8000)
38 |
--------------------------------------------------------------------------------
/tests/e2e/default_batching.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import litserve as ls
15 |
16 | if __name__ == "__main__":
17 | api = ls.test_examples.SimpleBatchedAPI()
18 | server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.05)
19 | server.run(port=8000)
20 |
--------------------------------------------------------------------------------
/tests/e2e/default_openai_embedding_spec.py:
--------------------------------------------------------------------------------
1 | import litserve as ls
2 | from litserve import OpenAIEmbeddingSpec
3 | from litserve.test_examples.openai_embedding_spec_example import TestEmbedAPI
4 |
5 | if __name__ == "__main__":
6 | server = ls.LitServer(TestEmbedAPI(), spec=OpenAIEmbeddingSpec(), fast_queue=True)
7 | server.run()
8 |
--------------------------------------------------------------------------------
/tests/e2e/default_openai_with_batching.py:
--------------------------------------------------------------------------------
1 | import litserve as ls
2 | from litserve.test_examples.openai_spec_example import OpenAIBatchContext
3 |
4 | if __name__ == "__main__":
5 | api = OpenAIBatchContext()
6 | server = ls.LitServer(api, spec=ls.OpenAISpec(), max_batch_size=2, batch_timeout=0.5, fast_queue=True)
7 | server.run(port=8000)
8 |
--------------------------------------------------------------------------------
/tests/e2e/default_openaispec.py:
--------------------------------------------------------------------------------
1 | import litserve as ls
2 | from litserve import OpenAISpec
3 | from litserve.test_examples.openai_spec_example import TestAPI
4 |
5 | if __name__ == "__main__":
6 | server = ls.LitServer(TestAPI(), spec=OpenAISpec())
7 | server.run()
8 |
--------------------------------------------------------------------------------
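A client sketch for the OpenAI-compatible server above (an illustration, not a repository file; it assumes OpenAISpec mounts the usual OpenAI-style /v1/chat/completions route and that the server is running locally on the default port 8000):

import requests

# Hypothetical chat request mirroring the OpenAI API request/response shape.
payload = {
    "model": "test-model",  # placeholder name, assumed to be accepted by TestAPI
    "messages": [{"role": "user", "content": "Hello!"}],
}
resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

--------------------------------------------------------------------------------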
/tests/e2e/default_openaispec_response_format.py:
--------------------------------------------------------------------------------
1 | import litserve as ls
2 | from litserve import OpenAISpec
3 | from litserve.test_examples.openai_spec_example import TestAPIWithStructuredOutput
4 |
5 | if __name__ == "__main__":
6 | server = ls.LitServer(TestAPIWithStructuredOutput(), spec=OpenAISpec(), fast_queue=True)
7 | server.run()
8 |
--------------------------------------------------------------------------------
/tests/e2e/default_openaispec_tools.py:
--------------------------------------------------------------------------------
1 | import litserve as ls
2 | from litserve import OpenAISpec
3 | from litserve.specs.openai import ChatMessage
4 | from litserve.test_examples.openai_spec_example import TestAPI
5 |
6 |
7 | class TestAPIWithToolCalls(TestAPI):
8 | def encode_response(self, output):
9 | yield ChatMessage(
10 | role="assistant",
11 | content="",
12 | tool_calls=[
13 | {
14 | "id": "call_abc123",
15 | "type": "function",
16 | "function": {"name": "get_current_weather", "arguments": '{\n"location": "Boston, MA"\n}'},
17 | }
18 | ],
19 | )
20 |
21 |
22 | if __name__ == "__main__":
23 | server = ls.LitServer(TestAPIWithToolCalls(), spec=OpenAISpec())
24 | server.run()
25 |
--------------------------------------------------------------------------------
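Extending the chat sketch shown after default_openaispec.py, a hedged example of passing a tools list to the tool-calling server above (illustration only; the get_current_weather schema below is hypothetical and simply matches the name hard-coded in TestAPIWithToolCalls, and the same /v1/chat/completions route and local port 8000 are assumed):

import requests

# The server above always replies with a fixed get_current_weather tool call,
# so the request below mainly illustrates the OpenAI-style tools payload.
payload = {
    "model": "test-model",
    "messages": [{"role": "user", "content": "What is the weather in Boston?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a location.",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
}
resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["tool_calls"])

--------------------------------------------------------------------------------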
/tests/e2e/default_single_streaming.py:
--------------------------------------------------------------------------------
1 | from litserve import LitAPI, LitServer
2 |
3 |
4 | class SimpleStreamingAPI(LitAPI):
5 | def setup(self, device) -> None:
6 | self.model = lambda x, y: x * y
7 |
8 | def decode_request(self, request):
9 | return request["input"]
10 |
11 | def predict(self, x):
12 | for i in range(1, 4):
13 | yield self.model(i, x)
14 |
15 | def encode_response(self, output_stream):
16 | for output in output_stream:
17 | yield {"output": output}
18 |
19 |
20 | if __name__ == "__main__":
21 | api = SimpleStreamingAPI()
22 | server = LitServer(api, stream=True)
23 | server.run(port=8000)
24 |
--------------------------------------------------------------------------------
/tests/e2e/default_spec.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import litserve as ls
15 | from litserve.specs.openai import OpenAISpec
16 | from litserve.test_examples.openai_spec_example import TestAPI
17 |
18 | if __name__ == "__main__":
19 | server = ls.LitServer(TestAPI(), spec=OpenAISpec())
20 | server.run(port=8000)
21 |
--------------------------------------------------------------------------------
/tests/e2e/openai_embedding_with_batching.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import litserve as ls
4 |
5 |
6 | class EmbeddingsAPI(ls.LitAPI):
7 | def setup(self, device):
8 | def model(x):
9 | return np.random.rand(len(x), 768)
10 |
11 | self.model = model
12 |
13 | def predict(self, inputs):
14 | return self.model(inputs)
15 |
16 |
17 | if __name__ == "__main__":
18 | api = EmbeddingsAPI(max_batch_size=10, batch_timeout=2)
19 | server = ls.LitServer(api, spec=ls.OpenAIEmbeddingSpec())
20 | server.run(port=8000)
21 |
--------------------------------------------------------------------------------
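A client sketch for the embeddings server above (an illustration, not a repository file; it assumes OpenAIEmbeddingSpec exposes the OpenAI-style /v1/embeddings route and that the server is running locally on port 8000):

import requests

# Hypothetical embeddings request mirroring the OpenAI API shape; the
# EmbeddingsAPI above returns random 768-dimensional vectors.
payload = {"model": "test-embed", "input": ["hello world", "litserve"]}
resp = requests.post("http://127.0.0.1:8000/v1/embeddings", json=payload)
resp.raise_for_status()
data = resp.json()["data"]
print(len(data), len(data[0]["embedding"]))  # expected: 2 768

--------------------------------------------------------------------------------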
/tests/minimal_run.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import json
15 | import subprocess
16 | import time
17 | import urllib.request
18 |
19 | import psutil
20 |
21 |
22 | def main():
23 | process = subprocess.Popen(
24 | ["python", "tests/simple_server.py"],
25 | )
26 | print("Waiting for server to start...")
27 | time.sleep(10)
28 | try:
29 | url = "http://127.0.0.1:8000/predict"
30 | data = json.dumps({"input": 4.0}).encode("utf-8")
31 | headers = {"Content-Type": "application/json"}
32 | request = urllib.request.Request(url, data=data, headers=headers, method="POST")
33 | response = urllib.request.urlopen(request)
34 | status_code = response.getcode()
35 | assert status_code == 200
36 | except Exception:
37 | raise
38 |
39 | finally:
40 | parent = psutil.Process(process.pid)
41 | for child in parent.children(recursive=True):
42 | child.kill()
43 | process.kill()
44 |
45 |
46 | if __name__ == "__main__":
47 | main()
48 |
--------------------------------------------------------------------------------
/tests/parity_fastapi/benchmark.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import concurrent.futures
3 | import random
4 | import time
5 |
6 | import numpy as np
7 | import requests
8 | import torch
9 | from PIL import Image
10 |
11 | device = "cuda" if torch.cuda.is_available() else "cpu"
12 | device = "mps" if torch.backends.mps.is_available() else device
13 |
14 | rand_mat = np.random.rand(2, 224, 224, 3) * 255
15 | Image.fromarray(rand_mat[0].astype("uint8")).convert("RGB").save("image1.jpg")
16 | Image.fromarray(rand_mat[1].astype("uint8")).convert("RGB").save("image2.jpg")
17 |
18 | SERVER_URL = "http://127.0.0.1:{}/predict"
19 |
20 | payloads = []
21 | for file in ["image1.jpg", "image2.jpg"]:
22 | with open(file, "rb") as image_file:
23 | encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
24 | payloads.append(encoded_string)
25 |
26 |
27 | def send_request(port):
28 | """Function to send a single request and measure the response time."""
29 | url = SERVER_URL.format(port)
30 | payload = {"image_data": random.choice(payloads)}
31 | start_time = time.time()
32 | response = requests.post(url, json=payload)
33 | end_time = time.time()
34 | return end_time - start_time, response.status_code
35 |
36 |
37 | def benchmark(num_requests=100, concurrency_level=100, port=8000):
38 | """Benchmark the ML server."""
39 |
40 | start_benchmark_time = time.time() # Start benchmark timing
41 | with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency_level) as executor:
42 | futures = [executor.submit(send_request, port) for _ in range(num_requests)]
43 | response_times = []
44 | status_codes = []
45 |
46 | for future in concurrent.futures.as_completed(futures):
47 | response_time, status_code = future.result()
48 | response_times.append(response_time)
49 | status_codes.append(status_code)
50 |
51 | end_benchmark_time = time.time() # End benchmark timing
52 | total_benchmark_time = end_benchmark_time - start_benchmark_time # Time in seconds
53 |
54 | # Analysis
55 | total_time = sum(response_times) # Time in seconds
56 | avg_time = total_time / num_requests # Time in seconds
57 | success_rate = status_codes.count(200) / num_requests * 100
58 | rps = num_requests / total_benchmark_time # Requests per second
59 |
60 | # Create a dictionary with the metrics
61 | metrics = {
62 | "Total Requests": num_requests,
63 | "Concurrency Level": concurrency_level,
64 | "Total Benchmark Time (seconds)": total_benchmark_time,
65 | "Average Response Time (ms)": avg_time * 1000,
66 | "Success Rate (%)": success_rate,
67 | "Requests Per Second (RPS)": rps,
68 | }
69 |
70 | # Print the metrics
71 | for key, value in metrics.items():
72 | print(f"{key}: {value}")
73 | print("-" * 50)
74 |
75 | return metrics
76 |
77 |
78 | def run_bench(conf: dict, num_samples: int, port: int):
79 | num_requests = conf[device]["num_requests"]
80 |
81 | results = []
82 | for _ in range(num_samples):
83 | metric = benchmark(num_requests=num_requests, concurrency_level=num_requests, port=port)
84 | results.append(metric)
85 |     return results[2:]  # drop the first two runs as warmup
86 |
--------------------------------------------------------------------------------
/tests/parity_fastapi/fastapi-server.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import io
3 |
4 | import PIL
5 | import torch
6 | import torchvision
7 | from fastapi import FastAPI, HTTPException
8 | from jsonargparse import CLI
9 | from pydantic import BaseModel
10 |
11 | # Set float32 matrix multiplication precision if GPU is available and capable
12 | if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0):
13 | torch.set_float32_matmul_precision("high")
14 |
15 | app = FastAPI()
16 |
17 |
18 | class ImageData(BaseModel):
19 | image_data: str
20 |
21 |
22 | class ImageClassifierAPI:
23 | def __init__(self, device):
24 | self.device = device
25 | weights = torchvision.models.ResNet18_Weights.DEFAULT
26 | self.image_processing = weights.transforms()
27 | self.model = torchvision.models.resnet18(weights=None).eval().to(device)
28 |
29 | def process_image(self, image_data):
30 | image = base64.b64decode(image_data)
31 | pil_image = PIL.Image.open(io.BytesIO(image)).convert("RGB")
32 | processed_image = self.image_processing(pil_image)
33 | return processed_image.unsqueeze(0).to(self.device) # Add batch dimension
34 |
35 | def predict(self, x):
36 | with torch.inference_mode():
37 | outputs = self.model(x)
38 | _, predictions = torch.max(outputs, 1)
39 | return predictions.item()
40 |
41 |
42 | device = "cuda" if torch.cuda.is_available() else "cpu"
43 | if torch.backends.mps.is_available():
44 | device = "mps"
45 | api = ImageClassifierAPI(device)
46 |
47 |
48 | @app.get("/health")
49 | async def health():
50 | return {"status": "ok"}
51 |
52 |
53 | @app.post("/predict")
54 | async def predict(image_data: ImageData):
55 | try:
56 | processed_image = api.process_image(image_data.image_data)
57 | prediction = api.predict(processed_image)
58 | return {"output": prediction}
59 | except Exception as e:
60 | raise HTTPException(status_code=500, detail=str(e))
61 |
62 |
63 | def main():
64 | import uvicorn
65 |
66 | uvicorn.run(app, host="0.0.0.0", port=8001, log_level="warning")
67 |
68 |
69 | if __name__ == "__main__":
70 | CLI(main)
71 |
--------------------------------------------------------------------------------
/tests/parity_fastapi/ls-server.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import io
3 | import os
4 | from concurrent.futures import ThreadPoolExecutor
5 |
6 | import PIL
7 | import torch
8 | import torchvision
9 |
10 | import litserve as ls
11 |
12 | device = "cuda" if torch.cuda.is_available() else "cpu"
13 | device = "mps" if torch.backends.mps.is_available() else device
14 | conf = {
15 | "cuda": {"batch_size": 16, "workers_per_device": 2},
16 | "cpu": {"batch_size": 8, "workers_per_device": 2},
17 | "mps": {"batch_size": 8, "workers_per_device": 2},
18 | }
19 |
20 | # Set float32 matrix multiplication precision if GPU is available and capable
21 | if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0):
22 | torch.set_float32_matmul_precision("high")
23 |
24 |
25 | class ImageClassifierAPI(ls.LitAPI):
26 | def setup(self, device):
27 | print(device)
28 | weights = torchvision.models.ResNet18_Weights.DEFAULT
29 | self.image_processing = weights.transforms()
30 | self.model = torchvision.models.resnet18(weights=None).eval().to(device)
31 | self.pool = ThreadPoolExecutor(os.cpu_count())
32 |
33 | def decode_request(self, request):
34 | return request["image_data"]
35 |
36 | def batch(self, image_data_list):
37 | def process_image(image_data):
38 | image = base64.b64decode(image_data)
39 | pil_image = PIL.Image.open(io.BytesIO(image)).convert("RGB")
40 | return self.image_processing(pil_image)
41 |
42 | inputs = list(self.pool.map(process_image, image_data_list))
43 | return torch.stack(inputs).to(self.device)
44 |
45 | def predict(self, x):
46 | with torch.inference_mode():
47 | outputs = self.model(x)
48 | _, predictions = torch.max(outputs, 1)
49 | return predictions
50 |
51 | def unbatch(self, outputs):
52 | return outputs.tolist()
53 |
54 | def encode_response(self, output):
55 | return {"output": output}
56 |
57 |
58 | def main(batch_size: int, workers_per_device: int):
59 | print(locals())
60 | api = ImageClassifierAPI()
61 | server = ls.LitServer(
62 | api,
63 | max_batch_size=batch_size,
64 | batch_timeout=0.01,
65 | timeout=10,
66 | workers_per_device=workers_per_device,
67 | fast_queue=True,
68 | )
69 | server.run(port=8000, log_level="warning")
70 |
71 |
72 | if __name__ == "__main__":
73 | main(**conf[device])
74 |
--------------------------------------------------------------------------------
/tests/parity_fastapi/main.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import time
3 | from functools import wraps
4 |
5 | import psutil
6 | import requests
7 | import torch
8 | from benchmark import run_bench
9 |
10 | CONF = {
11 | "cpu": {"num_requests": 50},
12 | "mps": {"num_requests": 50},
13 | "cuda": {"num_requests": 100},
14 | }
15 |
16 | device = "cuda" if torch.cuda.is_available() else "cpu"
17 | device = "mps" if torch.backends.mps.is_available() else device
18 |
19 | DIFF_FACTOR = {
20 | "cpu": 1,
21 | "cuda": 1.2,
22 | "mps": 1,
23 | }
24 |
25 |
26 | def run_python_script(filename):
27 | def decorator(test_fn):
28 | @wraps(test_fn)
29 | def wrapper(*args, **kwargs):
30 | process = subprocess.Popen(
31 | ["python", filename],
32 | )
33 | print("Waiting for server to start...")
34 | time.sleep(10)
35 |
36 | try:
37 | return test_fn(*args, **kwargs)
38 | except Exception:
39 | raise
40 | finally:
41 | print("Killing the server")
42 | parent = psutil.Process(process.pid)
43 | for child in parent.children(recursive=True):
44 | child.kill()
45 | process.kill()
46 |
47 | return wrapper
48 |
49 | return decorator
50 |
51 |
52 | def try_health(port):
53 | for i in range(10):
54 | try:
55 | response = requests.get(f"http://127.0.0.1:{port}/health")
56 | if response.status_code == 200:
57 | return
58 | except Exception:
59 |             time.sleep(1)  # brief backoff before the next health-check attempt
60 |
61 |
62 | @run_python_script("tests/parity_fastapi/fastapi-server.py")
63 | def run_fastapi_benchmark(num_samples):
64 | port = 8001
65 | try_health(port)
66 | return run_bench(CONF, num_samples, port)
67 |
68 |
69 | @run_python_script("tests/parity_fastapi/ls-server.py")
70 | def run_litserve_benchmark(num_samples):
71 | port = 8000
72 | try_health(port)
73 | return run_bench(CONF, num_samples, port)
74 |
75 |
76 | def mean(lst):
77 | return sum(lst) / len(lst)
78 |
79 |
80 | def main():
81 | key = "Requests Per Second (RPS)"
82 | num_samples = 12
83 | print("Running FastAPI benchmark")
84 | fastapi_metrics = run_fastapi_benchmark(num_samples=num_samples)
85 | print("\n\n" + "=" * 50 + "\n\n")
86 | print("Running LitServe benchmark")
87 | ls_metrics = run_litserve_benchmark(num_samples=num_samples)
88 | fastapi_throughput = mean([e[key] for e in fastapi_metrics])
89 | ls_throughput = mean([e[key] for e in ls_metrics])
90 | factor = DIFF_FACTOR[device]
91 | msg = f"LitServe should have higher throughput than FastAPI on {device}. {ls_throughput} vs {fastapi_throughput}"
92 | assert ls_throughput > fastapi_throughput * factor, msg
93 | print(f"{ls_throughput} vs {fastapi_throughput}")
94 |
95 |
96 | if __name__ == "__main__":
97 | main()
98 |
--------------------------------------------------------------------------------
/tests/perf_test/bert/benchmark.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 |
4 | import requests
5 | from tenacity import retry, stop_after_attempt
6 | from utils import benchmark
7 |
8 | # Configuration
9 | SERVER_URL = "http://0.0.0.0:8000/predict"
10 | MAX_SPEED = 390  # minimum acceptable RPS (measured on an Nvidia 3090)
11 |
12 | session = requests.Session()
13 |
14 |
15 | def get_average_throughput(num_requests=100, num_samples=10):
16 | key = "Requests Per Second (RPS)"
17 | latency_key = "Latency per Request (ms)"
18 | metric = 0
19 | latency = 0
20 |
21 | # warmup
22 | benchmark(num_requests=50, concurrency_level=10)
23 | for i in range(num_samples):
24 | bnmk = benchmark(num_requests=num_requests, concurrency_level=num_requests)
25 | metric += bnmk[key]
26 | latency += bnmk[latency_key]
27 | avg = metric / num_samples
28 | print("avg RPS:", avg)
29 | print("avg latency:", latency / num_samples)
30 | return avg
31 |
32 |
33 | @retry(stop=stop_after_attempt(10))
34 | def main():
35 | for i in range(10):
36 | try:
37 | resp = requests.get("http://localhost:8000/health")
38 | if resp.status_code == 200:
39 | break
40 | except requests.exceptions.ConnectionError as e:
41 | logging.error(f"Error connecting to server: {e}")
42 | time.sleep(10)
43 |
44 | rps = get_average_throughput(100, num_samples=10)
45 |     assert rps >= MAX_SPEED, f"Expected RPS to be at least {MAX_SPEED}, got {rps}"
46 |
47 |
48 | if __name__ == "__main__":
49 | main()
50 |
--------------------------------------------------------------------------------
/tests/perf_test/bert/data.py:
--------------------------------------------------------------------------------
1 | phrases = [
2 | "In the midst of a bustling city, amidst the constant hum of traffic and the chatter of countless conversations, "
3 | "there exists a serene park where people come to escape the chaos. Children play on the swings, their laughter "
4 | "echoing through the air, while adults stroll along the winding paths, lost in thought. The trees, tall and"
5 | " majestic, provide a canopy of shade, and the flowers bloom in a riot of colors, adding to the park's charm."
6 | " It's a place where time seems to slow down, offering a moment of peace and reflection in an otherwise hectic"
7 | " world.",
8 | "As the sun sets over the horizon, painting the sky in hues of orange, pink, and purple, a sense of calm descends"
9 | " over the landscape. The day has been long and filled with activity, but now, in this magical hour, everything "
10 | "feels different. The birds return to their nests, their evening songs a lullaby to the world. The gentle breeze "
11 | "carries the scent of blooming jasmine, and the stars begin to twinkle in the darkening sky. It's a time for "
12 | "quiet contemplation, for appreciating the beauty of nature, and for feeling a deep connection to the universe.",
13 | "On a remote island, far away from the noise and pollution of modern life, there is a hidden cove where "
14 | "crystal-clear waters lap gently against the shore. The beach, covered in soft, white sand, is a paradise for "
15 | "those seeking solitude and tranquility. Palm trees sway in the breeze, their fronds rustling softly, while "
16 | "the sun casts a warm, golden glow over everything. Here, one can forget the worries of the world and simply "
17 | "exist in the moment, surrounded by the natural beauty of the island and the soothing sounds of the ocean.",
18 | "In an ancient forest, where the trees have stood for centuries, there is a sense of timelessness that"
19 | " envelops everything. The air is cool and crisp, filled with the earthy scent of moss and fallen leaves. Sunlight"
20 | " filters through the dense canopy, creating dappled patterns on the forest floor. Birds call to one another, and"
21 | " small animals scurry through the underbrush. It's a place where one can feel the weight of history, where the "
22 | "presence of the past is almost palpable. Walking through this forest is like stepping back in time, to a world "
23 | "untouched by human hands.",
24 | "At the edge of a vast desert, where the dunes stretch out as far as the eye can see, there is a small oasis "
25 | "that offers a respite from the harsh conditions. A cluster of palm trees provides shade, and a clear, cool spring"
26 | " bubbles up from the ground, a source of life in an otherwise barren landscape. Travelers who come across this "
27 | "oasis are greeted with the sight of lush greenery and the sound of birdsong. It's a place of refuge and renewal,"
28 | " where one can rest and recharge before continuing on their journey through the endless sands.",
29 | "High in the mountains, where the air is thin and the landscape is rugged, there is a hidden valley that remains"
30 | " largely untouched by human activity. The valley is a haven for wildlife, with streams that flow with clear,"
31 | " cold water and meadows filled with wildflowers. The surrounding peaks, covered in snow even in the summer, "
32 | "stand as silent sentinels. It's a place where one can feel a profound sense of solitude and connection to nature."
33 | " The beauty of the valley, with its pristine environment and abundant life, is a reminder of the importance of"
34 | " preserving wild places.",
35 | "On a quiet country road, far from the bustling cities and noisy highways, there is a small farmhouse surrounded"
36 | " by fields of golden wheat. The farmhouse, with its weathered wooden walls and cozy interior, is a place of warmth"
37 | " and hospitality. The fields, swaying gently in the breeze, are a testament to the hard work and dedication of "
38 | "the farmers who tend them. In the evenings, the sky is filled with stars, and the only sounds are the chirping of"
39 | " crickets and the distant hoot of an owl. It's a place where one can find peace and simplicity.",
40 | "In a quaint village, nestled in the rolling hills of the countryside, life moves at a slower pace. The cobblestone"
41 | " streets are lined with charming cottages, each with its own garden bursting with flowers. The village "
42 | "square is the heart of the community, where residents gather to catch up on news and enjoy each other's company."
43 | " There's a timeless quality to the village, where traditions are upheld, and everyone knows their neighbors. "
44 | "It's a place where one can experience the joys of small-town living, with its close-knit community and strong "
45 | "sense of belonging.",
46 | "By the side of a tranquil lake, surrounded by dense forests and towering mountains, there is a small cabin that "
47 | "offers a perfect retreat from the hustle and bustle of everyday life. The cabin, with its rustic charm and cozy "
48 | "interior, is a place to unwind and relax. The lake, calm and mirror-like, reflects the beauty of the surrounding "
49 | "landscape, creating a sense of peace and serenity. It's a place where one can reconnect with nature, spend quiet"
50 | " moments fishing or kayaking, and enjoy the simple pleasures of life in a beautiful, natural setting.",
51 | ]
52 |
--------------------------------------------------------------------------------
/tests/perf_test/bert/run_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Function to clean up server process
4 | cleanup() {
5 | pkill -f "python tests/perf_test/bert/server.py"
6 | }
7 |
8 | # Trap script exit to run cleanup
9 | trap cleanup EXIT
10 |
11 | # Start the server in the background and capture its PID
12 | python tests/perf_test/bert/server.py &
13 | SERVER_PID=$!
14 |
15 | echo "Server started with PID $SERVER_PID"
16 |
17 | # Run your benchmark script
18 | echo "Preparing to run benchmark.py..."
19 |
20 | export PYTHONPATH=$PWD && python tests/perf_test/bert/benchmark.py
21 |
22 | # Check if benchmark.py exited successfully
23 | if [ $? -ne 0 ]; then
24 | echo "benchmark.py failed to run successfully."
25 | exit 1
26 | else
27 | echo "benchmark.py ran successfully."
28 | fi
29 |
--------------------------------------------------------------------------------
/tests/perf_test/bert/server.py:
--------------------------------------------------------------------------------
1 | """A BERT-Large text classification server with batching to be used for benchmarking."""
2 |
3 | import torch
4 | from jsonargparse import CLI
5 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, BertConfig
6 |
7 | import litserve as ls
8 |
9 | # Set float32 matrix multiplication precision if GPU is available and capable
10 | if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0):
11 | torch.set_float32_matmul_precision("high")
12 |
13 | # set dtype to bfloat16 if CUDA is available
14 | dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
15 |
16 |
17 | class HuggingFaceLitAPI(ls.LitAPI):
18 | def setup(self, device):
19 | print(device)
20 | model_name = "google-bert/bert-large-uncased"
21 | config = BertConfig.from_pretrained(pretrained_model_name_or_path=model_name)
22 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23 | self.model = AutoModelForSequenceClassification.from_config(config, torch_dtype=dtype)
24 | self.model.to(device)
25 |
26 | def decode_request(self, request: dict):
27 | return request["text"]
28 |
29 | def batch(self, inputs):
30 | return self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True)
31 |
32 | def predict(self, inputs):
33 | inputs = {key: value.to(self.device) for key, value in inputs.items()}
34 | with torch.inference_mode():
35 | outputs = self.model(**inputs)
36 | logits = outputs.logits
37 | return torch.argmax(logits, dim=1)
38 |
39 | def unbatch(self, outputs):
40 | return outputs.tolist()
41 |
42 | def encode_response(self, output):
43 | return {"label_idx": output}
44 |
45 |
46 | def main(
47 | batch_size: int = 10,
48 | batch_timeout: float = 0.01,
49 | devices: int = 2,
50 | workers_per_device=2,
51 | ):
52 | print(locals())
53 | api = HuggingFaceLitAPI()
54 | server = ls.LitServer(
55 | api,
56 | max_batch_size=batch_size,
57 | batch_timeout=batch_timeout,
58 | workers_per_device=workers_per_device,
59 | accelerator="auto",
60 | devices=devices,
61 | timeout=200,
62 | fast_queue=True,
63 | )
64 | server.run(log_level="warning", num_api_servers=4, generate_client_file=False)
65 |
66 |
67 | if __name__ == "__main__":
68 | CLI(main)
69 |
--------------------------------------------------------------------------------
/tests/perf_test/bert/utils.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import random
3 | import time
4 |
5 | import gpustat
6 | import psutil
7 | import requests
8 |
9 | from tests.perf_test.bert.data import phrases
10 |
11 |
12 | def create_random_batch(size: int):
13 | result = []
14 | for i in range(size):
15 | result.append(random.choice(phrases))
16 |
17 | return result
18 |
19 |
20 | # Configuration
21 | SERVER_URL = "http://0.0.0.0:8000/predict"
22 |
23 | session = requests.Session()
24 |
25 | executor = None
26 |
27 |
28 | def send_request():
29 | """Function to send a single request and measure the response time."""
30 | payload = {"text": random.choice(phrases)}
31 | start_time = time.time()
32 | response = session.post(SERVER_URL, json=payload)
33 | end_time = time.time()
34 | return end_time - start_time, response.status_code
35 |
36 |
37 | def benchmark(num_requests=1000, concurrency_level=50, print_metrics=True):
38 | """Benchmark the ML server."""
39 | global executor
40 | if executor is None:
41 | print("creating executor")
42 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=concurrency_level)
43 |
44 | if executor._max_workers < concurrency_level:
45 | print("updating executor")
46 | executor = concurrent.futures.ThreadPoolExecutor(max_workers=concurrency_level)
47 |
48 | start_benchmark_time = time.time() # Start benchmark timing
49 | futures = [executor.submit(send_request) for _ in range(num_requests)]
50 | response_times = []
51 | status_codes = []
52 |
53 | for future in concurrent.futures.as_completed(futures):
54 | response_time, status_code = future.result()
55 | response_times.append(response_time)
56 | status_codes.append(status_code)
57 |
58 | end_benchmark_time = time.time() # End benchmark timing
59 | total_benchmark_time = end_benchmark_time - start_benchmark_time # Time in seconds
60 |
61 | # Analysis
62 | total_time = sum(response_times) # Time in seconds
63 | avg_time = total_time / num_requests # Time in seconds
64 | avg_latency_per_request = (total_time / num_requests) * 1000 # Convert to milliseconds
65 | success_rate = status_codes.count(200) / num_requests * 100
66 | rps = num_requests / total_benchmark_time # Requests per second
67 |
68 | # Calculate throughput per concurrent user in requests per second
69 | successful_requests = status_codes.count(200)
70 | throughput_per_user = (successful_requests / total_benchmark_time) / concurrency_level # Requests per second
71 |
72 | # Create a dictionary with the metrics
73 | metrics = {
74 | "Total Requests": num_requests,
75 | "Concurrency Level": concurrency_level,
76 | "Total Benchmark Time (seconds)": total_benchmark_time,
77 | "Average Response Time (ms)": avg_time * 1000,
78 | "Success Rate (%)": success_rate,
79 | "Requests Per Second (RPS)": rps,
80 | "Latency per Request (ms)": avg_latency_per_request,
81 | "Throughput per Concurrent User (requests/second)": throughput_per_user,
82 | }
83 | try:
84 | gpu_stats = gpustat.GPUStatCollection.new_query()
85 | metrics["GPU Utilization"] = sum([gpu.utilization for gpu in gpu_stats.gpus]) # / len(gpu_stats.gpus)
86 | except Exception:
87 | metrics["GPU Utilization"] = -1
88 | metrics["CPU Usage"] = psutil.cpu_percent(0.5)
89 |
90 | # Print the metrics
91 | if print_metrics:
92 | for key, value in metrics.items():
93 | print(f"{key}: {value}")
94 | print("-" * 50)
95 |
96 | return metrics
97 |
--------------------------------------------------------------------------------
/tests/perf_test/stream/run_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # 1. Test server streams data very fast
3 |
4 | # Function to clean up server process
5 | cleanup() {
6 | pkill -f "python tests/perf_test/stream/stream_speed/server.py"
7 | }
8 |
9 | # Trap script exit to run cleanup
10 | trap cleanup EXIT
11 |
12 | # Start the server in the background and capture its PID
13 | python tests/perf_test/stream/stream_speed/server.py &
14 | SERVER_PID=$!
15 |
16 | echo "Server started with PID $SERVER_PID"
17 |
18 | # Run your benchmark script
19 | echo "Preparing to run benchmark.py..."
20 |
21 | export PYTHONPATH=$PWD && python tests/perf_test/stream/stream_speed/benchmark.py
22 |
23 | # Check if benchmark.py exited successfully
24 | if [ $? -ne 0 ]; then
25 | echo "benchmark.py failed to run successfully."
26 | exit 1
27 | else
28 | echo "benchmark.py ran successfully."
29 | fi
30 |
--------------------------------------------------------------------------------
/tests/perf_test/stream/stream_speed/benchmark.py:
--------------------------------------------------------------------------------
1 | """Consume 10K tokens from the stream endpoint and measure the speed."""
2 |
3 | import logging
4 | import time
5 |
6 | import requests
7 | from tenacity import retry, stop_after_attempt
8 |
9 | logger = logging.getLogger(__name__)
10 | # Configuration
11 | SERVER_URL = "http://0.0.0.0:8000/predict"
12 | TOTAL_TOKENS = 10000
13 | EXPECTED_TTFT = 0.005 # time to first token
14 |
15 | # minimum acceptable tokens per second
16 | MAX_SPEED = 3600  # ~3600 on GitHub CI, ~10000 on an M3 Pro
17 |
18 | session = requests.Session()
19 |
20 |
21 | def speed_test():
22 | start = time.time()
23 | resp = session.post(SERVER_URL, stream=True, json={"input": 1})
24 | num_tokens = 0
25 | ttft = None # time to first token
26 | for line in resp.iter_lines():
27 | if not line:
28 | continue
29 | if ttft is None:
30 | ttft = time.time() - start
31 | print(f"Time to first token: {ttft}")
32 |             assert ttft < EXPECTED_TTFT, f"Expected time to first token to be under {EXPECTED_TTFT} seconds but got {ttft}"
33 | num_tokens += 1
34 | end = time.time()
35 | resp.raise_for_status()
36 | assert num_tokens == TOTAL_TOKENS, f"Expected {TOTAL_TOKENS} tokens, got {num_tokens}"
37 | speed = num_tokens / (end - start)
38 | return {"speed": speed, "time": end - start}
39 |
40 |
41 | @retry(stop=stop_after_attempt(10))
42 | def main():
43 | for i in range(10):
44 | try:
45 | resp = requests.get("http://localhost:8000/health")
46 | if resp.status_code == 200:
47 | break
48 | except requests.exceptions.ConnectionError as e:
49 | logger.error(f"Error connecting to server: {e}")
50 | time.sleep(10)
51 | data = speed_test()
52 | speed = data["speed"]
53 | print(data)
54 |     assert speed >= MAX_SPEED, f"Expected streaming speed to be at least {MAX_SPEED} tokens/sec, got {speed}"
55 |
56 |
57 | if __name__ == "__main__":
58 | main()
59 |
--------------------------------------------------------------------------------
/tests/perf_test/stream/stream_speed/server.py:
--------------------------------------------------------------------------------
1 | import litserve as ls
2 |
3 |
4 | class SimpleStreamingAPI(ls.LitAPI):
5 | def setup(self, device) -> None:
6 | self.model = None
7 |
8 | def decode_request(self, request):
9 | return request["input"]
10 |
11 | def predict(self, x):
12 | yield from range(10000)
13 |
14 | def encode_response(self, output_stream):
15 | for output in output_stream:
16 | yield {"output": output}
17 |
18 |
19 | if __name__ == "__main__":
20 | api = SimpleStreamingAPI()
21 | server = ls.LitServer(
22 | api,
23 | stream=True,
24 | )
25 | server.run(port=8000, generate_client_file=False)
26 |
--------------------------------------------------------------------------------
/tests/simple_server.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from fastapi import Request, Response
15 |
16 | from litserve.api import LitAPI
17 | from litserve.server import LitServer
18 |
19 |
20 | class SimpleLitAPI(LitAPI):
21 | def setup(self, device):
22 | self.model = lambda x: x**2
23 |
24 | def decode_request(self, request: Request):
25 | return request["input"]
26 |
27 | def predict(self, x):
28 | return self.model(x)
29 |
30 | def encode_response(self, output) -> Response:
31 | return {"output": output}
32 |
33 |
34 | if __name__ == "__main__":
35 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=10)
36 | server.run()
37 |
--------------------------------------------------------------------------------
/tests/simple_server_diff_port.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from litserve.api import LitAPI
15 | from litserve.server import LitServer
16 |
17 |
18 | class SimpleLitAPI(LitAPI):
19 | def setup(self, device):
20 | self.model = lambda x: x**2
21 |
22 | def decode_request(self, request):
23 | return request["input"]
24 |
25 | def predict(self, x):
26 | return self.model(x)
27 |
28 | def encode_response(self, output):
29 | return {"output": output}
30 |
31 |
32 | if __name__ == "__main__":
33 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=10)
34 | server.run(port=8080)
35 |
--------------------------------------------------------------------------------
/tests/test_auth.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from fastapi import Depends, HTTPException, Request, Response
16 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
17 | from fastapi.testclient import TestClient
18 |
19 | import litserve.server
20 | from litserve import LitAPI, LitServer
21 | from litserve.utils import wrap_litserve_start
22 |
23 |
24 | class SimpleAuthedLitAPI(LitAPI):
25 | def setup(self, device):
26 | self.model = lambda x: x**2
27 |
28 | def decode_request(self, request: Request):
29 | return request["input"]
30 |
31 | def predict(self, x):
32 | return self.model(x)
33 |
34 | def encode_response(self, output) -> Response:
35 | return {"output": output}
36 |
37 | def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
38 | if auth.scheme != "Bearer" or auth.credentials != "1234":
39 | raise HTTPException(status_code=401, detail="Bad token")
40 |
41 |
42 | def test_authorized_custom():
43 | server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)
44 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
45 | input = {"input": 4.0}
46 | response = client.post("/predict", headers={"Authorization": "Bearer 1234"}, json=input)
47 | assert response.status_code == 200
48 |
49 |
50 | def test_not_authorized_custom():
51 | server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)
52 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
53 | input = {"input": 4.0}
54 | response = client.post("/predict", headers={"Authorization": "Bearer wrong"}, json=input)
55 | assert response.status_code == 401
56 |
57 |
58 | class SimpleLitAPI(LitAPI):
59 | def setup(self, device):
60 | self.model = lambda x: x**2
61 |
62 | def decode_request(self, request: Request):
63 | return request["input"]
64 |
65 | def predict(self, x):
66 | return self.model(x)
67 |
68 | def encode_response(self, output) -> Response:
69 | return {"output": output}
70 |
71 |
72 | def test_authorized_api_key():
73 | litserve.server.LIT_SERVER_API_KEY = "abcd"
74 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)
75 |
76 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
77 | input = {"input": 4.0}
78 | response = client.post("/predict", headers={"X-API-Key": "abcd"}, json=input)
79 | assert response.status_code == 200
80 |
81 | litserve.server.LIT_SERVER_API_KEY = None
82 |
83 |
84 | def test_not_authorized_api_key():
85 | litserve.server.LIT_SERVER_API_KEY = "abcd"
86 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)
87 |
88 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
89 | input = {"input": 4.0}
90 | response = client.post("/predict", headers={"X-API-Key": "wrong"}, json=input)
91 | assert response.status_code == 401
92 |
93 | litserve.server.LIT_SERVER_API_KEY = None
94 |
--------------------------------------------------------------------------------
/tests/test_callbacks.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import re
3 | import time
4 |
5 | import pytest
6 | from asgi_lifespan import LifespanManager
7 | from fastapi.testclient import TestClient
8 | from httpx import ASGITransport, AsyncClient
9 |
10 | import litserve as ls
11 | from litserve.callbacks import CallbackRunner, EventTypes
12 | from litserve.callbacks.defaults import PredictionTimeLogger
13 | from litserve.callbacks.defaults.metric_callback import RequestTracker
14 | from litserve.utils import wrap_litserve_start
15 |
16 |
17 | async def run_simple_request(server, num_requests=1):
18 | with wrap_litserve_start(server) as server:
19 | async with LifespanManager(server.app) as manager, AsyncClient(
20 | transport=ASGITransport(app=manager.app), base_url="http://test"
21 | ) as ac:
22 | responses = [ac.post("/predict", json={"input": 4.0}) for _ in range(num_requests)]
23 | responses = await asyncio.gather(*responses)
24 | for response in responses:
25 | assert response.json() == {"output": 16.0}, "Unexpected response"
26 |
27 |
28 | class SlowAPI(ls.test_examples.SimpleLitAPI):
29 | def predict(self, x):
30 | time.sleep(1)
31 | return super().predict(x)
32 |
33 |
34 | def test_callback_runner():
35 | cb_runner = CallbackRunner()
36 | assert cb_runner._callbacks == [], "Callbacks list must be empty"
37 |
38 | cb = PredictionTimeLogger()
39 | cb_runner._add_callbacks(cb)
40 | assert cb_runner._callbacks == [cb], "Callback not added to runner"
41 |
42 |
43 | def test_callback(capfd):
44 | lit_api = ls.test_examples.SimpleLitAPI()
45 | server = ls.LitServer(lit_api, callbacks=[PredictionTimeLogger()])
46 |
47 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
48 | response = client.post("/predict", json={"input": 4.0})
49 | assert response.json() == {"output": 16.0}
50 |
51 | captured = capfd.readouterr()
52 | pattern = r"Prediction took \d+\.\d{2} seconds"
53 | assert re.search(pattern, captured.out), f"Expected pattern not found in output: {captured.out}"
54 |
55 |
56 | def test_metric_logger(capfd):
57 | cb = PredictionTimeLogger()
58 | cb_runner = CallbackRunner()
59 | cb_runner._add_callbacks(cb)
60 | assert cb_runner._callbacks == [cb], "Callback not added to runner"
61 | cb_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=None)
62 | cb_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=None)
63 |
64 | captured = capfd.readouterr()
65 | pattern = r"Prediction took \d+\.\d{2} seconds"
66 | assert re.search(pattern, captured.out), f"Expected pattern not found in output: {captured.out}"
67 |
68 |
69 | @pytest.mark.asyncio
70 | async def test_request_tracker(capfd):
71 | lit_api = SlowAPI()
72 |
73 | server = ls.LitServer(lit_api, track_requests=False, callbacks=[RequestTracker()])
74 | await run_simple_request(server, 1)
75 | captured = capfd.readouterr()
76 | assert "Active requests: None" in captured.out, f"Expected pattern not found in output: {captured.out}"
77 |
78 | server = ls.LitServer(lit_api, track_requests=True, callbacks=[RequestTracker()])
79 | await run_simple_request(server, 4)
80 | captured = capfd.readouterr()
81 | assert "Active requests: 4" in captured.out, f"Expected pattern not found in output: {captured.out}"
82 |
--------------------------------------------------------------------------------
/tests/test_cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from unittest.mock import MagicMock, patch
4 |
5 | import pytest
6 |
7 | from litserve.__main__ import main
8 | from litserve.cli import _ensure_lightning_installed
9 | from litserve.cli import main as cli_main
10 |
11 |
12 | def test_dockerize_help(monkeypatch, capsys):
13 | monkeypatch.setattr("sys.argv", ["litserve", "dockerize", "--help"])
14 | # argparse calls sys.exit() after printing help
15 | with pytest.raises(SystemExit):
16 | main()
17 | captured = capsys.readouterr()
18 | assert "usage:" in captured.out, "CLI did not print help message"
19 | assert "The path to the server file." in captured.out, "CLI did not print help message"
20 |
21 |
22 | def test_dockerize_command(monkeypatch, capsys):
23 | # Assuming you have a dummy server file for testing
24 | dummy_server_file = "dummy_server.py"
25 | with open(dummy_server_file, "w") as f:
26 | f.write("# Dummy server file for testing\n")
27 |
28 | monkeypatch.setattr("sys.argv", ["litserve", "dockerize", dummy_server_file])
29 | main()
30 | captured = capsys.readouterr()
31 | os.remove(dummy_server_file)
32 | assert "Dockerfile created successfully" in captured.out, "CLI did not create Dockerfile"
33 | assert os.path.exists("Dockerfile"), "CLI did not create Dockerfile"
34 |
35 |
36 | @patch("importlib.util.find_spec")
37 | @patch("subprocess.check_call")
38 | def test_ensure_lightning_installed(mock_check_call, mock_find_spec):
39 | mock_find_spec.return_value = False
40 | _ensure_lightning_installed()
41 | mock_check_call.assert_called_once_with([sys.executable, "-m", "pip", "install", "-U", "lightning-sdk"])
42 |
43 |
44 | # TODO: Remove this once we have a fix for Python 3.9 and 3.10
45 | @pytest.mark.skipif(sys.version_info[:2] in [(3, 9), (3, 10)], reason="Test fails on Python 3.9 and 3.10")
46 | @patch("importlib.util.find_spec")
47 | @patch("subprocess.check_call")
48 | @patch("builtins.__import__")
49 | def test_cli_main_lightning_not_installed(mock_import, mock_check_call, mock_find_spec):
50 | # Create a mock for the lightning_sdk module and its components
51 | mock_lightning_sdk = MagicMock()
52 | mock_lightning_sdk.cli.entrypoint.main_cli = MagicMock()
53 |
54 | # Configure __import__ to return our mock when lightning_sdk is imported
55 | def side_effect(name, *args, **kwargs):
56 | if name == "lightning_sdk.cli.entrypoint":
57 | return mock_lightning_sdk
58 | return __import__(name, *args, **kwargs)
59 |
60 | mock_import.side_effect = side_effect
61 |
62 | # Test when lightning_sdk is not installed but gets installed dynamically
63 | mock_find_spec.side_effect = [False, True] # First call returns False, second call returns True
64 | test_args = ["lightning", "run", "app", "app.py"]
65 |
66 | with patch.object(sys, "argv", test_args):
67 | cli_main()
68 |
69 | mock_check_call.assert_called_once_with([sys.executable, "-m", "pip", "install", "-U", "lightning-sdk"])
70 |
71 |
72 | @pytest.mark.skipif(sys.version_info[:2] in [(3, 9), (3, 10)], reason="Test fails on Python 3.9 and 3.10")
73 | @patch("importlib.util.find_spec")
74 | @patch("builtins.__import__")
75 | def test_cli_main_import_error(mock_import, mock_find_spec, capsys):
76 | # Set up the mock to raise ImportError specifically for lightning_sdk import
77 | def import_mock(name, *args, **kwargs):
78 | if name == "lightning_sdk.cli.entrypoint":
79 | raise ImportError("Module not found")
80 | return __import__(name, *args, **kwargs)
81 |
82 | mock_import.side_effect = import_mock
83 |
84 | # Mock find_spec to return True so we attempt the import
85 | mock_find_spec.return_value = True
86 | test_args = ["lightning", "deploy", "api", "app.py"]
87 |
88 | with patch.object(sys, "argv", test_args): # noqa: SIM117
89 | with pytest.raises(SystemExit) as excinfo:
90 | cli_main()
91 |
92 | assert excinfo.value.code == 1
93 | captured = capsys.readouterr()
94 | assert "Error importing lightning_sdk CLI" in captured.out
95 |
--------------------------------------------------------------------------------
/tests/test_compression.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from fastapi import Request, Response
16 | from fastapi.testclient import TestClient
17 |
18 | from litserve import LitAPI, LitServer
19 | from litserve.utils import wrap_litserve_start
20 |
21 | # trivially compressible content
22 | test_output = {"result": "0" * 100000}
23 |
24 |
25 | class LargeOutputLitAPI(LitAPI):
26 | def setup(self, device):
27 | pass
28 |
29 | def decode_request(self, request: Request):
30 | pass
31 |
32 | def predict(self, x):
33 | pass
34 |
35 | def encode_response(self, output) -> Response:
36 | return test_output
37 |
38 |
39 | def test_compression():
40 | server = LitServer(LargeOutputLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)
41 |
42 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
43 | # compressed
44 | response = client.post("/predict", headers={"Accept-Encoding": "gzip"}, json={})
45 | assert response.status_code == 200
46 | assert response.headers["Content-Encoding"] == "gzip"
47 | content_length = int(response.headers["Content-Length"])
48 | assert 0 < content_length < 100000
49 | assert response.json() == test_output
50 |
51 | # uncompressed
52 | response = client.post("/predict", headers={"Accept-Encoding": ""}, json={})
53 | assert response.status_code == 200
54 | assert "Content-Encoding" not in response.headers
55 | content_length = int(response.headers["Content-Length"])
56 | assert content_length > 100000
57 | assert response.json() == test_output
58 |
--------------------------------------------------------------------------------
/tests/test_connector.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from unittest.mock import patch
15 |
16 | import pytest
17 | import torch
18 |
19 | from litserve.connector import _Connector, check_cuda_with_nvidia_smi
20 |
21 |
22 | @pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU")
23 | def test_check_cuda_with_nvidia_smi():
24 | assert check_cuda_with_nvidia_smi() == torch.cuda.device_count()
25 |
26 |
27 | @pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Non Nvidia GPU only")
28 | @patch(
29 | "litserve.connector.subprocess.check_output",
30 | return_value=b"GPU 0: NVIDIA GeForce RTX 4090 (UUID: GPU-rb438fre-0ar-9702-de35-ref4rjn34omk3 )",
31 | )
32 | def test_check_cuda_with_nvidia_smi_mock_gpu(mock_subprocess):
33 | check_cuda_with_nvidia_smi.cache_clear()
34 | assert check_cuda_with_nvidia_smi() == 1
35 | check_cuda_with_nvidia_smi.cache_clear()
36 |
37 |
38 | @pytest.mark.parametrize(
39 | ("input_accelerator", "expected_accelerator", "expected_devices"),
40 | [
41 | ("cpu", "cpu", 1),
42 | pytest.param(
43 | "cuda",
44 | "cuda",
45 | torch.cuda.device_count(),
46 | marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"),
47 | ),
48 | pytest.param(
49 | "gpu",
50 | "cuda",
51 | torch.cuda.device_count(),
52 | marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"),
53 | ),
54 | pytest.param(
55 | None,
56 | "cuda",
57 | torch.cuda.device_count(),
58 | marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"),
59 | ),
60 | pytest.param(
61 | "auto",
62 | "cuda",
63 | torch.cuda.device_count(),
64 | marks=pytest.mark.skipif(torch.cuda.device_count() == 0, reason="Only tested on Nvidia GPU"),
65 | ),
66 | pytest.param(
67 | "auto",
68 | "mps",
69 | 1,
70 | marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"),
71 | ),
72 | pytest.param(
73 | "gpu",
74 | "mps",
75 | 1,
76 | marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"),
77 | ),
78 | pytest.param(
79 | "mps",
80 | "mps",
81 | 1,
82 | marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"),
83 | ),
84 | pytest.param(
85 | None,
86 | "mps",
87 | 1,
88 | marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="Only tested on Apple MPS"),
89 | ),
90 | ],
91 | )
92 | def test_connector(input_accelerator, expected_accelerator, expected_devices):
93 | check_cuda_with_nvidia_smi.cache_clear()
94 | connector = _Connector(accelerator=input_accelerator)
95 | assert connector.accelerator == expected_accelerator, (
96 | f"accelerator mismatch - expected: {expected_accelerator}, actual: {connector.accelerator}"
97 | )
98 |
99 | assert connector.devices == expected_devices, (
100 | f"devices mismatch - expected {expected_devices}, actual: {connector.devices}"
101 | )
102 |
103 | with pytest.raises(ValueError, match="accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'"):
104 | _Connector(accelerator="SUPER_CHIP")
105 |
106 |
107 | def test__sanitize_accelerator():
108 | assert _Connector._sanitize_accelerator(None) == "auto"
109 | assert _Connector._sanitize_accelerator("CPU") == "cpu"
110 | with pytest.raises(ValueError, match="accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'"):
111 | _Connector._sanitize_accelerator("SUPER_CHIP")
112 |
--------------------------------------------------------------------------------
/tests/test_docker_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import pytest
15 |
16 | import litserve as ls
17 | from litserve import docker_builder
18 |
19 |
20 | def test_color():
21 | assert docker_builder.color("hi", docker_builder.RED) == f"{docker_builder.RED}hi{docker_builder.RESET}"
22 |
23 | expected = f"{docker_builder.INFO} {docker_builder.RED}hi{docker_builder.RESET}"
24 | assert docker_builder.color("hi", docker_builder.RED, docker_builder.INFO) == expected
25 |
26 |
27 | EXPECTED_CONTENT = f"""ARG PYTHON_VERSION=3.12
28 | FROM python:$PYTHON_VERSION-slim
29 |
30 | ####### Add your own installation commands here #######
31 | # RUN pip install some-package
32 | # RUN wget https://path/to/some/data/or/weights
33 | # RUN apt-get update && apt-get install -y
34 |
35 | WORKDIR /app
36 | COPY . /app
37 |
38 | # Install litserve and requirements
39 | RUN pip install --no-cache-dir litserve=={ls.__version__} -r requirements.txt
40 | EXPOSE 8000
41 | CMD ["python", "/app/app.py"]
42 | """
43 |
44 |
45 | EXPECTED_GPU_DOCKERFILE = f"""# Change CUDA and cuDNN version here
46 | FROM nvidia/cuda:12.4.1-base-ubuntu22.04
47 | ARG PYTHON_VERSION=3.12
48 |
49 | ENV DEBIAN_FRONTEND=noninteractive
50 | RUN apt-get update && apt-get install -y --no-install-recommends \\
51 | software-properties-common \\
52 | wget \\
53 | && add-apt-repository ppa:deadsnakes/ppa \\
54 | && apt-get update && apt-get install -y --no-install-recommends \\
55 | python$PYTHON_VERSION \\
56 | python$PYTHON_VERSION-dev \\
57 | python$PYTHON_VERSION-venv \\
58 | && wget https://bootstrap.pypa.io/get-pip.py -O get-pip.py \\
59 | && python$PYTHON_VERSION get-pip.py \\
60 | && rm get-pip.py \\
61 | && ln -sf /usr/bin/python$PYTHON_VERSION /usr/bin/python \\
62 | && ln -sf /usr/local/bin/pip$PYTHON_VERSION /usr/local/bin/pip \\
63 | && python --version \\
64 | && pip --version \\
65 | && apt-get purge -y --auto-remove software-properties-common \\
66 | && apt-get clean \\
67 | && rm -rf /var/lib/apt/lists/*
68 |
69 | ####### Add your own installation commands here #######
70 | # RUN pip install some-package
71 | # RUN wget https://path/to/some/data/or/weights
72 | # RUN apt-get update && apt-get install -y
73 |
74 | WORKDIR /app
75 | COPY . /app
76 |
77 | # Install litserve and requirements
78 | RUN pip install --no-cache-dir litserve=={ls.__version__} -r requirements.txt
79 | EXPOSE 8000
80 | CMD ["python", "/app/app.py"]
81 | """
82 |
83 |
84 | def test_dockerize(tmp_path, monkeypatch):
85 | with open(tmp_path / "app.py", "w") as f:
86 | f.write("print('hello')")
87 |
88 | # Temporarily change the current working directory to tmp_path
89 | monkeypatch.chdir(tmp_path)
90 |
91 | with pytest.warns(UserWarning, match="Make sure to install the required packages in the Dockerfile."):
92 | docker_builder.dockerize("app.py", 8000)
93 |
94 | with open(tmp_path / "requirements.txt", "w") as f:
95 | f.write("lightning")
96 |
97 | docker_builder.dockerize("app.py", 8000)
98 | with open("Dockerfile") as f:
99 | content = f.read()
100 | assert content == EXPECTED_CONTENT
101 |
102 | docker_builder.dockerize("app.py", 8000, gpu=True)
103 | with open("Dockerfile") as f:
104 | content = f.read()
105 | assert content == EXPECTED_GPU_DOCKERFILE
106 |
107 | with pytest.raises(FileNotFoundError, match="must be in the current directory"):
108 | docker_builder.dockerize("random_file_name.py", 8000)
109 |
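A usage sketch of the behaviour asserted above, assuming it is run from the project directory that contains app.py and requirements.txt:

from litserve import docker_builder

docker_builder.dockerize("app.py", 8000)            # CPU Dockerfile
docker_builder.dockerize("app.py", 8000, gpu=True)  # CUDA base-image variant

# Both calls write ./Dockerfile next to app.py. A missing requirements.txt only
# triggers a UserWarning, while a missing app file raises FileNotFoundError.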
--------------------------------------------------------------------------------
/tests/test_form.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from fastapi import Request, Response
16 | from fastapi.testclient import TestClient
17 |
18 | from litserve import LitAPI, LitServer
19 | from litserve.utils import wrap_litserve_start
20 |
21 |
22 | class SimpleFileLitAPI(LitAPI):
23 | def setup(self, device):
24 | self.model = lambda x: x**2
25 |
26 | def decode_request(self, request: Request):
27 | return len(request["input"].file.read().decode("utf-8"))
28 |
29 | def predict(self, x):
30 | return self.model(x)
31 |
32 | def encode_response(self, output) -> Response:
33 | return {"output": output}
34 |
35 |
36 | def test_multipart_form_data(tmp_path):
37 | file_length = 1024 * 1024 * 100
38 |
39 | server = LitServer(
40 | SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length * 2)
41 | )
42 |
43 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
44 | file_path = f"{tmp_path}/big_file.txt"
45 | with open(file_path, "wb") as f:
46 | f.write(bytearray([1] * file_length))
47 | with open(file_path, "rb") as f:
48 | file = {"input": f}
49 | response = client.post("/predict", files=file)
50 | assert response.json() == {"output": file_length**2}
51 |
52 |
53 | def test_file_too_big(tmp_path):
54 | file_length = 1024 * 1024 * 100
55 |
56 | server = LitServer(
57 | SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length / 2)
58 | )
59 |
60 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
61 | file_path = f"{tmp_path}/big_file.txt"
62 | with open(file_path, "wb") as f:
63 | f.write(bytearray([1] * file_length))
64 | with open(file_path, "rb") as f:
65 | file = {"input": f}
66 | response = client.post("/predict", files=file)
67 | assert response.status_code == 413
68 |
69 | # spoof content-length size
70 | response = client.post("/predict", files=file, headers={"content-length": "1024"})
71 | assert response.status_code == 413
72 |
73 |
74 | class SimpleFormLitAPI(LitAPI):
75 | def setup(self, device):
76 | self.model = lambda x: x**2
77 |
78 | def decode_request(self, request: Request):
79 | return float(request["input"])
80 |
81 | def predict(self, x):
82 | return self.model(x)
83 |
84 | def encode_response(self, output) -> Response:
85 | return {"output": output}
86 |
87 |
88 | def test_urlencoded_form_data():
89 | server = LitServer(SimpleFormLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)
90 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
91 | file = {"input": "4.0"}
92 | response = client.post("/predict", data=file)
93 | assert response.json() == {"output": 16.0}
94 |
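A client-side sketch of the multipart upload exercised above, assuming a server built from SimpleFileLitAPI is already running locally on port 8000 (port and file name are illustrative):

import requests

with open("big_file.txt", "rb") as f:
    resp = requests.post("http://127.0.0.1:8000/predict", files={"input": f})
print(resp.json())  # {"output": <file length squared>}

# Uploads larger than the server's max_payload_size are rejected with HTTP 413.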
--------------------------------------------------------------------------------
/tests/test_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import threading
15 | import time
16 | from unittest.mock import MagicMock
17 |
18 | import pytest
19 | from fastapi.testclient import TestClient
20 |
21 | import litserve as ls
22 | from litserve.loggers import Logger, _LoggerConnector
23 | from litserve.utils import wrap_litserve_start
24 |
25 |
26 | class TestLogger(Logger):
27 | def process(self, key, value):
28 | self.processed_data = (key, value)
29 |
30 |
31 | @pytest.fixture
32 | def mock_lit_server():
33 | mock_server = MagicMock()
34 | mock_server.log_queue.get = MagicMock(return_value=("test_key", "test_value"))
35 | return mock_server
36 |
37 |
38 | @pytest.fixture
39 | def test_logger():
40 | return TestLogger()
41 |
42 |
43 | @pytest.fixture
44 | def logger_connector(mock_lit_server, test_logger):
45 | return _LoggerConnector(mock_lit_server, [test_logger])
46 |
47 |
48 | def test_logger_mount(test_logger):
49 | mock_app = MagicMock()
50 | test_logger.mount("/test", mock_app)
51 | assert test_logger._config["mount"]["path"] == "/test"
52 | assert test_logger._config["mount"]["app"] == mock_app
53 |
54 |
55 | def test_connector_add_logger(logger_connector):
56 | new_logger = TestLogger()
57 | logger_connector.add_logger(new_logger)
58 | assert new_logger in logger_connector._loggers
59 |
60 |
61 | def test_connector_mount(mock_lit_server, test_logger, logger_connector):
62 | mock_app = MagicMock()
63 | test_logger.mount("/test", mock_app)
64 | logger_connector.add_logger(test_logger)
65 | mock_lit_server.app.mount.assert_called_with("/test", mock_app)
66 |
67 |
68 | def test_invalid_loggers():
69 | _LoggerConnector(None, TestLogger())
70 | with pytest.raises(ValueError, match="Logger must be an instance of litserve.Logger"):
71 | _ = _LoggerConnector(None, [MagicMock()])
72 |
73 | with pytest.raises(ValueError, match="loggers must be a list or an instance of litserve.Logger"):
74 | _ = _LoggerConnector(None, MagicMock())
75 |
76 |
77 | class LoggerAPI(ls.test_examples.SimpleLitAPI):
78 | def predict(self, input):
79 | result = super().predict(input)
80 | for i in range(1, 5):
81 | self.log("time", i * 0.1)
82 | return result
83 |
84 |
85 | def test_server_wo_logger():
86 | api = LoggerAPI()
87 | server = ls.LitServer(api)
88 |
89 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
90 | response = client.post("/predict", json={"input": 4.0})
91 | assert response.json() == {"output": 16.0}
92 |
93 |
94 | class FileLogger(ls.Logger):
95 | def __init__(self, path="test_logger_temp.txt"):
96 | super().__init__()
97 | self.path = path
98 |
99 | def process(self, key, value):
100 | with open(self.path, "a+") as f:
101 | f.write(f"{key}: {value:.1f}\n")
102 |
103 |
104 | def test_logger_with_api(tmpdir):
105 | path = str(tmpdir / "test_logger_temp.txt")
106 | api = LoggerAPI()
107 | server = ls.LitServer(api, loggers=[FileLogger(path)])
108 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
109 | response = client.post("/predict", json={"input": 4.0})
110 | assert response.json() == {"output": 16.0}
111 | # Wait for FileLogger to write to file
112 | time.sleep(0.5)
113 | with open(path) as f:
114 | data = f.readlines()
115 | assert data == [
116 | "time: 0.1\n",
117 | "time: 0.2\n",
118 | "time: 0.3\n",
119 | "time: 0.4\n",
120 | ], f"Expected metric not found in logger file {data}"
121 |
122 |
123 | class PredictionTimeLogger(ls.Callback):
124 | def on_after_predict(self, lit_api):
125 | for i in range(1, 5):
126 | lit_api.log("time", i * 0.1)
127 |
128 |
129 | def test_logger_with_callback(tmp_path):
130 | path = str(tmp_path / "test_logger_temp.txt")
131 | api = ls.test_examples.SimpleLitAPI()
132 | server = ls.LitServer(api, loggers=[FileLogger(path)], callbacks=[PredictionTimeLogger()])
133 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
134 | response = client.post("/predict", json={"input": 4.0})
135 | assert response.json() == {"output": 16.0}
136 | # Wait for FileLogger to write to file
137 | time.sleep(0.5)
138 | with open(path) as f:
139 | data = f.readlines()
140 | assert data == [
141 | "time: 0.1\n",
142 | "time: 0.2\n",
143 | "time: 0.3\n",
144 | "time: 0.4\n",
145 | ], f"Expected metric not found in logger file {data}"
146 |
147 |
148 | class NonPickleableLogger(ls.Logger):
149 | # This is a logger that contains a non-picklable resource
150 | def __init__(self, *args, **kwargs):
151 | super().__init__(*args, **kwargs)
152 | self._lock = threading.Lock() # Non-picklable resource
153 |
154 | def process(self, key, value):
155 | with self._lock:
156 | print(f"Logged {key}: {value}", flush=True)
157 |
158 |
159 | class PickleTestAPI(ls.test_examples.SimpleLitAPI):
160 | def predict(self, x):
161 | self.log("my-key", x)
162 | return super().predict(x)
163 |
164 |
165 | def test_pickle_safety(capfd):
166 | api = PickleTestAPI()
167 | server = ls.LitServer(api, loggers=NonPickleableLogger())
168 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
169 | response = client.post("/predict", json={"input": 4.0})
170 | assert response.json() == {"output": 16.0}
171 | time.sleep(0.5)
172 | captured = capfd.readouterr()
173 | assert "Logged my-key: 4.0" in captured.out, f"Expected log not found in captured output {captured}"
174 |
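The pieces exercised above compose into the following minimal pattern, restated as a sketch (the logger behaviour, metric key, and port are illustrative):

import litserve as ls

class MetricsLogger(ls.Logger):
    def process(self, key, value):
        # Receives every (key, value) pair passed to LitAPI.log().
        print(f"{key}={value}", flush=True)

class MyAPI(ls.test_examples.SimpleLitAPI):
    def predict(self, x):
        self.log("inference_input", x)  # forwarded to all attached loggers
        return super().predict(x)

if __name__ == "__main__":
    server = ls.LitServer(MyAPI(), loggers=[MetricsLogger()])
    server.run(port=8000)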
--------------------------------------------------------------------------------
/tests/test_logging.py:
--------------------------------------------------------------------------------
1 | import io
2 | import logging
3 |
4 | import pytest
5 |
6 | from litserve.utils import add_log_handler, configure_logging, set_log_level
7 |
8 |
9 | @pytest.fixture
10 | def log_stream():
11 | return io.StringIO()
12 |
13 |
14 | def test_configure_logging(log_stream):
15 | # Configure logging with test stream
16 | configure_logging(level=logging.DEBUG, stream=log_stream)
17 |
18 | # Get logger and log a test message
19 | logger = logging.getLogger("litserve")
20 | test_message = "Test debug message"
21 | logger.debug(test_message)
22 |
23 | # Verify log output
24 | log_contents = log_stream.getvalue()
25 | assert test_message in log_contents
26 | assert "DEBUG" in log_contents
27 | assert logger.propagate is False
28 |
29 |
30 | def test_set_log_level():
31 | # Set log level to WARNING
32 | set_log_level(logging.WARNING)
33 |
34 | # Verify logger level
35 | logger = logging.getLogger("litserve")
36 | assert logger.level == logging.WARNING
37 |
38 |
39 | def test_add_log_handler():
40 | # Create and add a custom handler
41 | stream = io.StringIO()
42 | custom_handler = logging.StreamHandler(stream)
43 | add_log_handler(custom_handler)
44 |
45 | # Verify handler is added
46 | logger = logging.getLogger("litserve")
47 | assert custom_handler in logger.handlers
48 |
49 | # Test the handler works
50 | test_message = "Test handler message"
51 | logger.info(test_message)
52 | assert test_message in stream.getvalue()
53 |
54 |
55 | @pytest.fixture(autouse=True)
56 | def cleanup_logger():
57 | yield
58 | logger = logging.getLogger("litserve")
59 | logger.handlers.clear()
60 | logger.setLevel(logging.INFO)
61 |
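The helpers tested above can be combined in application code roughly as follows (the chosen level and extra handler are illustrative):

import logging

from litserve.utils import add_log_handler, configure_logging, set_log_level

configure_logging(level=logging.DEBUG)                # configure the "litserve" logger
add_log_handler(logging.FileHandler("litserve.log"))  # also write records to a file
set_log_level(logging.WARNING)                        # later, raise the threshold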
--------------------------------------------------------------------------------
/tests/test_middlewares.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import pytest
15 | from fastapi.testclient import TestClient
16 | from starlette.middleware.base import BaseHTTPMiddleware
17 | from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
18 | from starlette.middleware.trustedhost import TrustedHostMiddleware
19 | from starlette.types import ASGIApp
20 |
21 | import litserve as ls
22 | from litserve.utils import wrap_litserve_start
23 |
24 |
25 | class RequestIdMiddleware(BaseHTTPMiddleware):
26 | def __init__(self, app: ASGIApp, length: int) -> None:
27 | self.app = app
28 | self.length = length
29 | super().__init__(app)
30 |
31 | async def dispatch(self, request, call_next):
32 | response = await call_next(request)
33 | response.headers["X-Request-Id"] = "0" * self.length
34 | return response
35 |
36 |
37 | def test_custom_middleware():
38 | server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=[(RequestIdMiddleware, {"length": 5})])
39 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
40 | response = client.post("/predict", json={"input": 4.0})
41 | assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
42 | assert response.json() == {"output": 16.0}, "server didn't return expected output"
43 | assert response.headers["X-Request-Id"] == "00000"
44 |
45 |
46 | def test_starlette_middlewares():
47 | middlewares = [
48 | (
49 | TrustedHostMiddleware,
50 | {
51 | "allowed_hosts": ["localhost", "127.0.0.1"],
52 | },
53 | ),
54 | HTTPSRedirectMiddleware,
55 | ]
56 | server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=middlewares)
57 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
58 | response = client.post("/predict", json={"input": 4.0}, headers={"Host": "localhost"})
59 | assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
60 | assert response.json() == {"output": 16.0}, "server didn't return expected output"
61 |
62 | response = client.post("/predict", json={"input": 4.0}, headers={"Host": "not-trusted-host"})
63 | assert response.status_code == 400, f"Expected response to be 400 but got {response.status_code}"
64 |
65 |
66 | def test_middlewares_inputs():
67 | server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=[])
68 | assert len(server.middlewares) == 1, "Default middleware should be present"
69 |
70 | server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=[], max_payload_size=1000)
71 | assert len(server.middlewares) == 2, "Default and max payload size middlewares should be present"
72 |
73 | server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=None)
74 | assert len(server.middlewares) == 1, "Default middleware should be present"
75 |
76 | with pytest.raises(ValueError, match="middlewares must be a list of tuples"):
77 | ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5}))
78 |
--------------------------------------------------------------------------------
/tests/test_multiple_endpoints.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from asgi_lifespan import LifespanManager
3 | from httpx import ASGITransport, AsyncClient
4 |
5 | import litserve as ls
6 | from litserve.utils import wrap_litserve_start
7 |
8 |
9 | class InferencePipeline(ls.LitAPI):
10 | def __init__(self, name=None, *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 | self.name = name
13 |
14 | def setup(self, device):
15 | self.model = lambda x: x**2
16 |
17 | def decode_request(self, request):
18 | return request["input"]
19 |
20 | def predict(self, x):
21 | return self.model(x)
22 |
23 | def encode_response(self, output):
24 | return {"output": output, "name": self.name}
25 |
26 |
27 | @pytest.mark.asyncio
28 | async def test_multiple_endpoints():
29 | api1 = InferencePipeline(name="api1", api_path="/api1")
30 | api2 = InferencePipeline(name="api2", api_path="/api2")
31 | server = ls.LitServer([api1, api2])
32 |
33 | with wrap_litserve_start(server) as server:
34 | async with LifespanManager(server.app) as manager, AsyncClient(
35 | transport=ASGITransport(app=manager.app), base_url="http://test"
36 | ) as ac:
37 | resp = await ac.post("/api1", json={"input": 2.0}, timeout=10)
38 | assert resp.status_code == 200, "Server response should be 200 (OK)"
39 | assert resp.json()["output"] == 4.0, "output must be the square of the input"
40 | assert resp.json()["name"] == "api1", "response must include the name of the API that served it"
41 |
42 | resp = await ac.post("/api2", json={"input": 5.0}, timeout=10)
43 | assert resp.status_code == 200, "Server response should be 200 (OK)"
44 | assert resp.json()["output"] == 25.0, "output must be the square of the input"
45 | assert resp.json()["name"] == "api2", "response must include the name of the API that served it"
46 |
47 |
48 | def test_multiple_endpoints_with_same_path():
49 | api1 = InferencePipeline(name="api1", api_path="/api1")
50 | api2 = InferencePipeline(name="api2", api_path="/api1")
51 | with pytest.raises(ValueError, match="api_path /api1 is already in use by"):
52 | ls.LitServer([api1, api2])
53 |
54 |
55 | def test_reserved_paths():
56 | api1 = InferencePipeline(name="api1", api_path="/health")
57 | api2 = InferencePipeline(name="api2", api_path="/info")
58 | with pytest.raises(ValueError, match="api_path /health is already in use by LitServe healthcheck"):
59 | ls.LitServer([api1, api2])
60 |
--------------------------------------------------------------------------------
/tests/test_pydantic.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from fastapi.testclient import TestClient
15 | from pydantic import BaseModel
16 |
17 | from litserve import LitAPI, LitServer
18 | from litserve.utils import wrap_litserve_start
19 |
20 |
21 | class PredictRequest(BaseModel):
22 | input: float
23 |
24 |
25 | class PredictResponse(BaseModel):
26 | output: float
27 |
28 |
29 | class SimpleLitAPI(LitAPI):
30 | def setup(self, device):
31 | self.model = lambda x: x**2
32 |
33 | def decode_request(self, request: PredictRequest) -> float:
34 | return request.input
35 |
36 | def predict(self, x):
37 | return self.model(x)
38 |
39 | def encode_response(self, output: float) -> PredictResponse:
40 | return PredictResponse(output=output)
41 |
42 |
43 | def test_pydantic():
44 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=5)
45 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
46 | response = client.post("/predict", json={"input": 4.0})
47 | assert response.json() == {"output": 16.0}
48 |
--------------------------------------------------------------------------------
/tests/test_readme.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Code extraction adapted from https://github.com/tassaron/get_code_from_markdown
15 | import re
16 | import selectors
17 | import subprocess
18 | import sys
19 | import time
20 | from typing import List
21 |
22 | import pytest
23 | from tqdm import tqdm
24 |
25 | uvicorn_msg = "Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)"
26 |
27 |
28 | def extract_code_blocks(lines: List[str]) -> List[str]:
29 | language = "python"
30 | regex = re.compile(
31 | r"(?P<start>^```(?P<block_language>(\w|-)+)\n)(?P<code>.*?\n)(?P<end>```)",
32 | re.DOTALL | re.MULTILINE,
33 | )
34 | blocks = [(match.group("block_language"), match.group("code")) for match in regex.finditer("".join(lines))]
35 | return [block for block_language, block in blocks if block_language == language]
36 |
37 |
38 | def get_code_blocks(file: str) -> List[str]:
39 | with open(file) as f:
40 | lines = list(f)
41 | return extract_code_blocks(lines)
42 |
43 |
44 | def get_extra_time(content: str) -> int:
45 | if "torch" in content or "transformers" in content:
46 | return 5
47 |
48 | return 0
49 |
50 |
51 | def run_script_with_timeout(file, timeout, extra_time, killall):
52 | sel = selectors.DefaultSelector()
53 | try:
54 | process = subprocess.Popen(
55 | ["python", str(file)],
56 | stdout=subprocess.PIPE,
57 | stderr=subprocess.PIPE,
58 | bufsize=1, # Line-buffered
59 | universal_newlines=True, # Decode bytes to string
60 | )
61 |
62 | stdout_lines = []
63 | stderr_lines = []
64 | end_time = time.time() + timeout + extra_time
65 |
66 | sel.register(process.stdout, selectors.EVENT_READ)
67 | sel.register(process.stderr, selectors.EVENT_READ)
68 |
69 | while True:
70 | timeout_remaining = end_time - time.time()
71 | if timeout_remaining <= 0:
72 | killall(process)
73 | break
74 |
75 | events = sel.select(timeout=timeout_remaining)
76 | for key, _ in events:
77 | if key.fileobj is process.stdout:
78 | line = process.stdout.readline()
79 | if line:
80 | stdout_lines.append(line)
81 | elif key.fileobj is process.stderr:
82 | line = process.stderr.readline()
83 | if line:
84 | stderr_lines.append(line)
85 |
86 | if process.poll() is not None:
87 | break
88 |
89 | output = "".join(stdout_lines)
90 | errors = "".join(stderr_lines)
91 |
92 | # Get the return code of the process
93 | returncode = process.returncode
94 |
95 | except Exception as e:
96 | output = ""
97 | errors = str(e)
98 | returncode = -1 # Indicate failure in running the process
99 |
100 | return returncode, output, errors
101 |
102 |
103 | @pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows CI is slow and this test is just a sanity check.")
104 | def test_readme(tmp_path, killall):
105 | d = tmp_path / "readme_codes"
106 | d.mkdir(exist_ok=True)
107 | code_blocks = get_code_blocks("README.md")
108 | assert len(code_blocks) > 0, "No code block found in README.md"
109 |
110 | for i, code in enumerate(tqdm(code_blocks)):
111 | file = d / f"{i}.py"
112 | file.write_text(code)
113 | extra_time = get_extra_time(code)
114 |
115 | returncode, stdout, stderr = run_script_with_timeout(file, timeout=5, extra_time=extra_time, killall=killall)
116 |
117 | if "server.run" in code:
118 | assert uvicorn_msg in stderr, f"Expected to run uvicorn server.\nCode:\n {code}\n\nCode output: {stderr}"
119 | elif "requests.post" in code:
120 | assert "ConnectionError" in stderr, (
121 | f"Client examples should fail with a ConnectionError because there is no server running.\nCode:\n{code}"
122 | )
123 | else:
124 | assert returncode == 0, (
125 | f"Code exited with {returncode}.\n"
126 | f"Error: {stderr}\n"
127 | f"Please check the code for correctness:\n```\n{code}\n```"
128 | )
129 |
--------------------------------------------------------------------------------
/tests/test_request_handlers.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import json
15 | from queue import Queue
16 | from unittest import mock
17 | from unittest.mock import AsyncMock, MagicMock, patch
18 |
19 | import pytest
20 | from fastapi import Request
21 |
22 | from litserve.server import BaseRequestHandler, RegularRequestHandler
23 | from litserve.test_examples import SimpleLitAPI
24 | from litserve.utils import LitAPIStatus
25 |
26 |
27 | @pytest.fixture
28 | def mock_lit_api():
29 | return SimpleLitAPI()
30 |
31 |
32 | class MockServer:
33 | def __init__(self, lit_api):
34 | self.lit_api = lit_api
35 | self.response_buffer = {}
36 | self.request_queue = Queue()
37 | self._callback_runner = mock.MagicMock()
38 | self.app = mock.MagicMock()
39 | self.app.response_queue_id = 0
40 | self.active_requests = 0
41 |
42 | def _get_request_queue(self, api_path):
43 | return self.request_queue
44 |
45 |
46 | class MockRequest:
47 | """Mock FastAPI Request object for testing."""
48 |
49 | def __init__(self, json_data=None, form_data=None, content_type="application/json"):
50 | self._json_data = json_data or {}
51 | self._form_data = form_data or {}
52 | self.headers = {"Content-Type": content_type}
53 |
54 | async def json(self):
55 | if self._json_data is None:
56 | raise json.JSONDecodeError("Invalid JSON", "", 0)
57 | return self._json_data
58 |
59 | async def form(self):
60 | return self._form_data
61 |
62 |
63 | class TestRequestHandler(BaseRequestHandler):
64 | def __init__(self, lit_api, server):
65 | super().__init__(lit_api, server)
66 | self.litapi_request_queues = {"/predict": Queue()}
67 |
68 | async def handle_request(self, request, request_type):
69 | payload = await self._prepare_request(request, request_type)
70 | uid, response_queue_id = await self._submit_request(payload)
71 | return response_queue_id
72 |
73 |
74 | @pytest.mark.asyncio
75 | async def test_request_handler(mock_lit_api):
76 | mock_server = MockServer(mock_lit_api)
77 | handler = TestRequestHandler(mock_lit_api, mock_server)
78 | mock_request = MockRequest()
79 | response_queue_id = await handler.handle_request(mock_request, Request)
80 | assert response_queue_id == 0
81 |
82 |
83 | @pytest.mark.asyncio
84 | @patch("litserve.server.asyncio.Event")
85 | async def test_request_handler_streaming(mock_event, mock_lit_api):
86 | mock_event.return_value = AsyncMock()
87 | mock_server = MockServer(mock_lit_api)
88 | mock_request = MockRequest()
89 | mock_server.response_buffer = MagicMock()
90 | mock_server.response_buffer.pop.return_value = ("test-response", LitAPIStatus.OK)
91 | handler = RegularRequestHandler(mock_lit_api, mock_server)
92 | response = await handler.handle_request(mock_request, Request)
93 | assert mock_server.request_queue.qsize() == 1
94 | assert response == "test-response"
95 |
--------------------------------------------------------------------------------
/tests/test_schema.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import base64
15 | import io
16 | import os
17 |
18 | import numpy as np
19 | from fastapi.testclient import TestClient
20 | from PIL import Image
21 |
22 | import litserve as ls
23 | from litserve.schema.image import ImageInput, ImageOutput
24 | from litserve.utils import wrap_litserve_start
25 |
26 |
27 | class ImageAPI(ls.LitAPI):
28 | def setup(self, device):
29 | self.model = lambda x: np.array(x) * 2
30 |
31 | def decode_request(self, request: ImageInput):
32 | return request.get_image()
33 |
34 | def predict(self, x):
35 | return self.model(x)
36 |
37 | def encode_response(self, numpy_image) -> ImageOutput:
38 | output = Image.fromarray(np.uint8(numpy_image)).convert("RGB")
39 | return ImageOutput(image=output)
40 |
41 |
42 | def test_image_input_output(tmpdir):
43 | path = os.path.join(tmpdir, "test.png")
44 | server = ls.LitServer(ImageAPI(), accelerator="cpu", devices=1, workers_per_device=1)
45 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
46 | Image.new("RGB", (32, 32)).save(path)
47 | with open(path, "rb") as image_file:
48 | encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
49 | response = client.post("/predict", json={"image_data": encoded_string})
50 |
51 | assert response.status_code == 200, f"Unexpected status code: {response.status_code}"
52 | image_data = response.json()["image"]
53 | image = Image.open(io.BytesIO(base64.b64decode(image_data)))
54 | assert image.size == (32, 32), f"Unexpected image size: {image.size}"
55 |
56 |
57 | class MultiImageInputModel(ImageInput):
58 | image_0: str
59 | image_1: str
60 | image_2: str
61 |
62 |
63 | class MultiImageInputAPI(ImageAPI):
64 | def decode_request(self, request: MultiImageInputModel):
65 | images = [request.get_image(f"image_{i}") for i in range(3)]
66 | for image in images:
67 | assert isinstance(image, Image.Image)
68 | return images[0]
69 |
70 |
71 | def test_multiple_image_input(tmpdir):
72 | path = os.path.join(tmpdir, "test.png")
73 | server = ls.LitServer(MultiImageInputAPI(), accelerator="cpu", devices=1, workers_per_device=1)
74 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
75 | data = {}
76 | for i in range(3):
77 | Image.new("RGB", (32, 32)).save(path)
78 | with open(path, "rb") as image_file:
79 | encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
80 | data[f"image_{i}"] = encoded_string
81 | response = client.post("/predict", json=data)
82 |
83 | assert response.status_code == 200, f"Unexpected status code: {response.status_code}"
84 | image_data = response.json()["image"]
85 | image = Image.open(io.BytesIO(base64.b64decode(image_data)))
86 | assert image.size == (32, 32), f"Unexpected image size: {image.size}"
87 |
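A client-side sketch of the image round trip tested above, assuming an ImageAPI server is running locally on port 8000 (port and file names are illustrative):

import base64
import io

import requests
from PIL import Image

with open("test.png", "rb") as f:
    payload = {"image_data": base64.b64encode(f.read()).decode("utf-8")}

resp = requests.post("http://127.0.0.1:8000/predict", json=payload)
# The response carries the processed image as base64 under the "image" key.
out = Image.open(io.BytesIO(base64.b64decode(resp.json()["image"])))
out.save("output.png")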
--------------------------------------------------------------------------------
/tests/test_torch.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning AI team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import pytest
15 | import torch
16 | import torch.nn as nn
17 | from fastapi import Request, Response
18 | from fastapi.testclient import TestClient
19 |
20 | from litserve import LitAPI, LitServer
21 | from litserve.utils import wrap_litserve_start
22 |
23 |
24 | class Linear(nn.Module):
25 | def __init__(self):
26 | super().__init__()
27 | self.linear = nn.Linear(1, 1)
28 | self.linear.weight.data.fill_(2.0)
29 | self.linear.bias.data.fill_(1.0)
30 |
31 | def forward(self, x):
32 | return self.linear(x)
33 |
34 |
35 | class SimpleLitAPI(LitAPI):
36 | def setup(self, device):
37 | self.model = Linear().to(device)
38 | self.device = device
39 |
40 | def decode_request(self, request: Request):
41 | content = request["input"]
42 | return torch.tensor([content], device=self.device)
43 |
44 | def predict(self, x):
45 | return self.model(x[None, :])
46 |
47 | def encode_response(self, output) -> Response:
48 | return {"output": float(output)}
49 |
50 |
51 | def test_torch():
52 | server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=10)
53 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
54 | response = client.post("/predict", json={"input": 4.0})
55 | assert response.json() == {"output": 9.0}
56 |
57 |
58 | @pytest.mark.skipif(torch.cuda.device_count() == 0, reason="requires CUDA to be available")
59 | def test_torch_gpu():
60 | server = LitServer(SimpleLitAPI(), accelerator="cuda", devices=1, timeout=10)
61 | with wrap_litserve_start(server) as server, TestClient(server.app) as client:
62 | response = client.post("/predict", json={"input": 4.0})
63 | assert response.json() == {"output": 9.0}
64 |
--------------------------------------------------------------------------------
/tests/test_transport.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import multiprocessing as mp
3 | from queue import Empty
4 | from unittest.mock import MagicMock, patch
5 |
6 | import pytest
7 |
8 | from litserve.transport.factory import TransportConfig, create_transport_from_config
9 | from litserve.transport.process_transport import MPQueueTransport
10 |
11 |
12 | class TestMPQueueTransport:
13 | @pytest.fixture
14 | def manager(self):
15 | manager = mp.Manager()
16 | yield manager
17 | manager.shutdown()
18 |
19 | @pytest.fixture
20 | def queues(self, manager):
21 | return [manager.Queue() for _ in range(2)]
22 |
23 | @pytest.fixture
24 | def transport(self, manager, queues):
25 | return MPQueueTransport(manager, queues)
26 |
27 | def test_init(self, transport, queues):
28 | """Test that the transport initializes correctly."""
29 | assert transport._queues == queues
30 | assert transport._closed is False
31 |
32 | def test_send(self, transport, queues):
33 | test_item = {"test": "data"}
34 | consumer_id = 0
35 |
36 | transport.send(test_item, consumer_id)
37 |
38 | assert queues[consumer_id].get() == test_item
39 |
40 | def test_send_when_closed(self, transport):
41 | transport._closed = True
42 | test_item = {"test": "data"}
43 | consumer_id = 0
44 |
45 | result = transport.send(test_item, consumer_id)
46 |
47 | assert result is None
48 |
49 | @pytest.mark.asyncio
50 | async def test_areceive(self, transport, queues):
51 | test_item = {"test": "data"}
52 | consumer_id = 0
53 | queues[consumer_id].put(test_item)
54 |
55 | result = await transport.areceive(consumer_id)
56 |
57 | assert result == test_item
58 |
59 | @pytest.mark.asyncio
60 | async def test_areceive_when_closed(self, transport):
61 | transport._closed = True
62 | consumer_id = 0
63 |
64 | with pytest.raises(asyncio.CancelledError, match="Transport closed"):
65 | await transport.areceive(consumer_id)
66 |
67 | @pytest.mark.asyncio
68 | async def test_areceive_timeout(self, transport):
69 | consumer_id = 0
70 | timeout = 0.1
71 |
72 | with pytest.raises(Empty):
73 | await transport.areceive(consumer_id, timeout=timeout)
74 |
75 | @pytest.mark.asyncio
76 | async def test_areceive_cancellation(self, transport):
77 | consumer_id = 0
78 |
79 | with patch("asyncio.to_thread", side_effect=asyncio.CancelledError), pytest.raises(asyncio.CancelledError):
80 | await transport.areceive(consumer_id)
81 |
82 | def test_close(self, transport, queues):
83 | transport.close()
84 |
85 | assert transport._closed is True
86 | for queue in queues:
87 | assert queue.get() is None
88 |
89 | def test_reduce(self, transport, queues):
90 | cls, args = transport.__reduce__()
91 |
92 | assert cls == MPQueueTransport
93 | assert args == (None, queues)
94 |
95 |
96 | class TestTransportFactory:
97 | @pytest.fixture
98 | def mock_manager(self):
99 | return MagicMock()
100 |
101 | def test_create_mp_transport(self, mock_manager):
102 | config_dict = {"transport_type": "mp", "num_consumers": 2}
103 | config = TransportConfig(**config_dict)
104 |
105 | with patch("litserve.transport.factory._create_mp_transport") as mock_create:
106 | mock_create.return_value = MPQueueTransport(mock_manager, [MagicMock(), MagicMock()])
107 |
108 | transport = create_transport_from_config(config)
109 |
110 | assert isinstance(transport, MPQueueTransport)
111 | mock_create.assert_called_once()
112 |
113 | def test_create_transport_invalid_type(self):
114 | with patch("litserve.transport.factory.TransportConfig.model_validate") as mock_validate:
115 | mock_validate.return_value = MagicMock(transport_type="invalid")
116 |
117 | with pytest.raises(ValueError, match="Invalid transport type"):
118 | create_transport_from_config(mock_validate.return_value)
119 |
120 |
121 | @pytest.mark.integration
122 | class TestTransportIntegration:
123 | """Integration tests for the transport system."""
124 |
125 | @pytest.fixture
126 | def mock_transport(self):
127 | transport = MagicMock()
128 | transport._closed = False
129 | transport._waiting_tasks = []
130 |
131 | transport.send = MagicMock()
132 |
133 | async def mock_areceive(consumer_id, timeout=None, block=True):
134 | current_task = asyncio.current_task()
135 | transport._waiting_tasks.append(current_task)
136 |
137 | try:
138 | if transport._closed:
139 | raise asyncio.CancelledError("Transport closed")
140 |
141 | await asyncio.sleep(10) # Long sleep to ensure we'll be cancelled
142 |
143 | # This should only be reached if not cancelled
144 | return ("test_id", {"test": "data"})
145 | finally:
146 | # Clean up task reference
147 | if current_task in transport._waiting_tasks:
148 | transport._waiting_tasks.remove(current_task)
149 |
150 | transport.areceive = mock_areceive
151 |
152 | def mock_close():
153 | transport._closed = True
154 | for task in transport._waiting_tasks:
155 | task.cancel()
156 |
157 | transport.close = mock_close
158 |
159 | return transport
160 |
161 | @pytest.mark.asyncio
162 | async def test_send_receive_cycle(self, mock_transport):
163 | """Test a complete send-receive cycle."""
164 | # Arrange
165 | test_item = ("test_id", {"test": "data"})
166 | consumer_id = 0
167 |
168 | # Act - Send
169 | mock_transport.send(test_item, consumer_id)
170 |
171 | # Act - Receive
172 | result = await mock_transport.areceive(consumer_id)
173 |
174 | # Assert
175 | assert result == test_item
176 | mock_transport.send.assert_called_once_with(test_item, consumer_id)
177 |
178 | @pytest.mark.asyncio
179 | async def test_shutdown_sequence(self, mock_transport):
180 | """Test the shutdown sequence works correctly."""
181 | # Arrange
182 | consumer_id = 0
183 |
184 | async def receive_task():
185 | try:
186 | await mock_transport.areceive(consumer_id)
187 | return False # Should not reach here if cancelled
188 | except asyncio.CancelledError:
189 | return True # Successfully cancelled
190 |
191 | task = asyncio.create_task(receive_task())
192 | await asyncio.sleep(0.1)
193 |
194 | # Act
195 | mock_transport.close()
196 | result = await asyncio.wait_for(task, timeout=2.0)
197 |
198 | # Assert
199 | assert result is True
200 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 | import sys
5 | from unittest import mock
6 | from unittest.mock import MagicMock
7 |
8 | import pytest
9 | from fastapi import HTTPException
10 |
11 | from litserve.utils import (
12 | call_after_stream,
13 | configure_logging,
14 | dump_exception,
15 | generate_random_zmq_address,
16 | set_trace_if_debug,
17 | )
18 |
19 |
20 | def test_dump_exception():
21 | e1 = dump_exception(HTTPException(status_code=404, detail="Not Found"))
22 | assert isinstance(e1, bytes)
23 |
24 | exc = HTTPException(400, "Custom Lit error")
25 | assert isinstance(pickle.loads(dump_exception(exc)), HTTPException)
26 | assert pickle.loads(dump_exception(exc)).detail == "Custom Lit error"
27 | assert pickle.loads(dump_exception(exc)).status_code == 400
28 |
29 |
30 | async def dummy_streamer():
31 | for i in range(10):
32 | yield i
33 |
34 |
35 | @pytest.mark.asyncio
36 | async def test_call_after_stream():
37 | callback = MagicMock()
38 | callback.return_value = None
39 | streamer = dummy_streamer()
40 | async for _ in call_after_stream(streamer, callback, "first_arg", random_arg="second_arg"):
41 | pass
42 | callback.assert_called()
43 | callback.assert_called_with("first_arg", random_arg="second_arg")
44 |
45 |
46 | @pytest.mark.skipif(sys.platform == "win32", reason="This test is for non-Windows platforms only.")
47 | def test_generate_random_zmq_address_non_windows(tmpdir):
48 | """Test generate_random_zmq_address on non-Windows platforms."""
49 |
50 | temp_dir = str(tmpdir)
51 | address1 = generate_random_zmq_address(temp_dir=temp_dir)
52 | address2 = generate_random_zmq_address(temp_dir=temp_dir)
53 |
54 | assert address1.startswith("ipc://"), "Address should start with 'ipc://'"
55 | assert address2.startswith("ipc://"), "Address should start with 'ipc://'"
56 | assert address1 != address2, "Addresses should be unique"
57 |
58 | # Verify the path exists within the specified temp_dir
59 | assert os.path.commonpath([temp_dir, address1[6:]]) == temp_dir
60 | assert os.path.commonpath([temp_dir, address2[6:]]) == temp_dir
61 |
62 |
63 | def test_configure_logging():
64 | configure_logging(use_rich=False)
65 | assert logging.getLogger("litserve").handlers[0].__class__.__name__ == "StreamHandler"
66 |
67 |
68 | def test_configure_logging_rich_not_installed():
69 | # patch builtins.__import__ to raise ImportError
70 | with mock.patch("builtins.__import__", side_effect=ImportError):
71 | configure_logging(use_rich=True)
72 | assert logging.getLogger("litserve").handlers[0].__class__.__name__ == "StreamHandler"
73 |
74 |
75 | @mock.patch("litserve.utils.set_trace")
76 | def test_set_trace_if_debug(mock_set_trace):
77 | # mock environ
78 | with mock.patch("litserve.utils.os.environ", {"LITSERVE_DEBUG": "1"}):
79 | set_trace_if_debug()
80 | mock_set_trace.assert_called_once()
81 |
82 |
83 | @mock.patch("litserve.utils.ForkedPdb")
84 | def test_set_trace_if_debug_not_set(mock_forked_pdb):
85 | with mock.patch("litserve.utils.os.environ", {"LITSERVE_DEBUG": "0"}):
86 | set_trace_if_debug()
87 | mock_forked_pdb.assert_not_called()
88 |
--------------------------------------------------------------------------------
/tests/test_zmq_queue.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import pickle
3 | from queue import Empty
4 | from unittest.mock import AsyncMock, Mock, patch
5 |
6 | import pytest
7 | import zmq
8 |
9 | from litserve.transport.zmq_queue import AsyncConsumer, Broker, Producer
10 |
11 |
12 | @pytest.fixture
13 | def mock_context():
14 | with patch("zmq.Context") as mock_ctx:
15 | socket = Mock()
16 | mock_ctx.return_value.socket.return_value = socket
17 | yield mock_ctx, socket
18 |
19 |
20 | @pytest.fixture
21 | def mock_async_context():
22 | with patch("zmq.asyncio.Context") as mock_ctx:
23 | socket = AsyncMock() # Use AsyncMock for async methods
24 | mock_ctx.return_value.socket.return_value = socket
25 | yield mock_ctx, socket
26 |
27 |
28 | def test_broker_start_stop(mock_context):
29 | _, socket = mock_context
30 | broker = Broker(use_process=False)
31 |
32 | # Start broker
33 | broker.start()
34 | assert socket.bind.call_count == 2 # Should bind both frontend and backend
35 |
36 | # Stop broker
37 | broker.stop()
38 | assert socket.close.call_count == 2 # Should close both sockets
39 |
40 |
41 | def test_broker_error_handling(mock_context):
42 | """Test broker handles ZMQ errors."""
43 | _, socket = mock_context
44 | socket.bind.side_effect = zmq.ZMQError("Test error")
45 |
46 | broker = Broker(use_process=False)
47 | broker.start()
48 | broker.stop()
49 |
50 | assert socket.close.called # Should clean up even on error
51 |
52 |
53 | def test_producer_send(mock_context):
54 | _, socket = mock_context
55 | producer = Producer(address="test_addr")
56 |
57 | # Test sending simple data
58 | producer.put("test_data", consumer_id=1)
59 | sent_message = socket.send.call_args[0][0]
60 | consumer_id, data = sent_message.split(b"|", 1)
61 | assert consumer_id == b"1"
62 | assert pickle.loads(data) == "test_data"
63 |
64 | # Test sending complex data
65 | complex_data = {"key": [1, 2, 3]}
66 | producer.put(complex_data, consumer_id=2)
67 | sent_message = socket.send.call_args[0][0]
68 | consumer_id, data = sent_message.split(b"|", 1)
69 | assert consumer_id == b"2"
70 | assert pickle.loads(data) == complex_data
71 |
72 |
73 | def test_producer_error_handling(mock_context):
74 | _, socket = mock_context
75 | producer = Producer(address="test_addr")
76 |
77 | # Test ZMQ error
78 | socket.send.side_effect = zmq.ZMQError("Test error")
79 | with pytest.raises(zmq.ZMQError):
80 | producer.put("data", consumer_id=1)
81 |
82 | # Test unpickleable object
83 | class Unpickleable:
84 | def __reduce__(self):
85 | raise pickle.PickleError("Can't pickle this!")
86 |
87 | with pytest.raises(pickle.PickleError):
88 | producer.put(Unpickleable(), consumer_id=1)
89 |
90 |
91 | def test_producer_wait_for_subscribers(mock_context):
92 | _, socket = mock_context
93 | producer = Producer(address="test_addr")
94 |
95 | # Test successful wait
96 | assert producer.wait_for_subscribers(timeout=0.1)
97 | assert socket.send.called
98 |
99 | # Test timeout
100 | socket.send.side_effect = zmq.ZMQError("Would block")
101 | assert not producer.wait_for_subscribers(timeout=0.1)
102 |
103 |
104 | @pytest.mark.parametrize("timeout", [1.0, None])
105 | @pytest.mark.asyncio
106 | async def test_async_consumer(mock_async_context, timeout):
107 | _, socket = mock_async_context
108 | consumer = AsyncConsumer(consumer_id=1, address="test_addr")
109 |
110 | # Setup mock received data
111 | test_data = {"test": "data"}
112 | message = b"1|" + pickle.dumps(test_data)
113 | socket.recv.return_value = message
114 |
115 | # Test receiving
116 | received = await consumer.get(timeout=timeout)
117 | assert received == test_data
118 |
119 | # Test timeout
120 | socket.recv.side_effect = asyncio.TimeoutError()
121 | with pytest.raises(Empty):
122 | await consumer.get(timeout=timeout)
123 |
124 |
125 | @pytest.mark.asyncio
126 | async def test_async_consumer_cleanup():
127 | with patch("zmq.asyncio.Context") as mock_ctx:
128 | socket = AsyncMock()
129 | mock_ctx.return_value.socket.return_value = socket
130 |
131 | consumer = AsyncConsumer(consumer_id=1, address="test_addr")
132 | consumer.close()
133 |
134 | assert socket.close.called
135 | assert mock_ctx.return_value.term.called
136 |
137 |
138 | def test_producer_cleanup():
139 | with patch("zmq.Context") as mock_ctx:
140 | socket = Mock()
141 | mock_ctx.return_value.socket.return_value = socket
142 |
143 | producer = Producer(address="test_addr")
144 | producer.close()
145 |
146 | assert socket.close.called
147 | assert mock_ctx.return_value.term.called
148 |
--------------------------------------------------------------------------------