├── .dockerignore
├── .github
└── workflows
│ ├── check_homepage_build.yaml
│ ├── deploy_homepage.yaml
│ ├── publish_pypi.yaml
│ ├── push_docker.yaml
│ └── python_lint.yaml
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── docker
├── dev.Dockerfile
├── eric.Dockerfile
├── gateway.Dockerfile
├── resource-manager.Dockerfile
├── sidecar.Dockerfile
├── task-dispatcher.Dockerfile
├── task-manager.Dockerfile
└── vllm.Dockerfile
├── docs
├── CNAME
├── architecture
│ ├── eric.md
│ ├── index.md
│ ├── sidecar.md
│ └── task.md
├── assets
│ ├── css
│ │ └── extra.css
│ ├── img
│ │ └── favicon.png
│ ├── js
│ │ └── mathjax.js
│ └── video
│ │ └── cornserve.mp4
├── contributor_guide
│ ├── eric.md
│ ├── index.md
│ ├── kubernetes.md
│ ├── sidecar.md
│ └── tracing.md
├── getting_started
│ ├── building_apps.md
│ ├── cornserve.md
│ ├── index.md
│ ├── jupyter.ipynb
│ └── registering_apps.md
├── index.md
└── requirements.txt
├── examples
├── mllm
│ ├── app.py
│ └── requirements.txt
└── notebook.ipynb
├── kubernetes
├── README.md
├── k3s
│ ├── agent-config.yaml
│ ├── registries.yaml
│ └── server-config.yaml
├── kustomize
│ ├── cornserve-system
│ │ ├── base
│ │ │ ├── jaeger
│ │ │ │ ├── configmap.yaml
│ │ │ │ ├── deployment.yaml
│ │ │ │ ├── kustomization.yaml
│ │ │ │ └── service.yaml
│ │ │ ├── kustomization.yaml
│ │ │ └── namespace.yaml
│ │ └── overlays
│ │ │ ├── dev
│ │ │ ├── jaeger
│ │ │ │ ├── kustomization.yaml
│ │ │ │ ├── node-port-service.yaml
│ │ │ │ └── volume-patch.yaml
│ │ │ └── kustomization.yaml
│ │ │ ├── local
│ │ │ ├── jaeger
│ │ │ │ ├── kustomization.yaml
│ │ │ │ ├── node-port-service.yaml
│ │ │ │ └── volume-patch.yaml
│ │ │ └── kustomization.yaml
│ │ │ └── minikube
│ │ │ ├── jaeger
│ │ │ ├── kustomization.yaml
│ │ │ └── node-port-service.yaml
│ │ │ └── kustomization.yaml
│ └── cornserve
│ │ ├── base
│ │ ├── gateway
│ │ │ ├── deployment.yaml
│ │ │ ├── kustomization.yaml
│ │ │ └── service.yaml
│ │ ├── kustomization.yaml
│ │ ├── namespace.yaml
│ │ ├── resource-manager
│ │ │ ├── deployment.yaml
│ │ │ ├── kustomization.yaml
│ │ │ ├── role-binding.yaml
│ │ │ ├── role.yaml
│ │ │ ├── service-account.yaml
│ │ │ └── service.yaml
│ │ ├── sidecar
│ │ │ ├── kustomization.yaml
│ │ │ ├── role-binding.yaml
│ │ │ ├── role.yaml
│ │ │ ├── service-account.yaml
│ │ │ ├── service.yaml
│ │ │ └── statefulset.yaml
│ │ ├── task-dispatcher
│ │ │ ├── deployment.yaml
│ │ │ ├── kustomization.yaml
│ │ │ └── service.yaml
│ │ └── task-manager
│ │ │ ├── kustomization.yaml
│ │ │ ├── role-binding.yaml
│ │ │ ├── role.yaml
│ │ │ └── service-account.yaml
│ │ └── overlays
│ │ ├── dev
│ │ ├── gateway
│ │ │ ├── kustomization.yaml
│ │ │ └── node-port-service.yaml
│ │ ├── image-pull-policy.yaml
│ │ └── kustomization.yaml
│ │ ├── local
│ │ ├── gateway
│ │ │ ├── kustomization.yaml
│ │ │ └── node-port-service.yaml
│ │ ├── image-pull-policy.yaml
│ │ └── kustomization.yaml
│ │ ├── minikube
│ │ ├── gateway
│ │ │ ├── kustomization.yaml
│ │ │ └── node-port-service.yaml
│ │ └── kustomization.yaml
│ │ └── prod
│ │ └── kustomization.yaml
├── registry.sh
└── set_registry.sh
├── mkdocs.yml
├── proto
└── v1
│ ├── common.proto
│ ├── resource_manager.proto
│ ├── sidecar.proto
│ ├── task_dispatcher.proto
│ └── task_manager.proto
├── python
├── cornserve
│ ├── __init__.py
│ ├── app
│ │ ├── __init__.py
│ │ └── base.py
│ ├── cli.py
│ ├── constants.py
│ ├── frontend.py
│ ├── logging.py
│ ├── services
│ │ ├── __init__.py
│ │ ├── gateway
│ │ │ ├── __init__.py
│ │ │ ├── app
│ │ │ │ ├── manager.py
│ │ │ │ └── models.py
│ │ │ ├── entrypoint.py
│ │ │ ├── models.py
│ │ │ ├── router.py
│ │ │ ├── session.py
│ │ │ └── task_manager.py
│ │ ├── resource_manager
│ │ │ ├── __init__.py
│ │ │ ├── entrypoint.py
│ │ │ ├── grpc.py
│ │ │ ├── manager.py
│ │ │ └── resource.py
│ │ ├── sidecar
│ │ │ ├── __init__.py
│ │ │ ├── launch.py
│ │ │ ├── receiver.py
│ │ │ ├── scheduler.py
│ │ │ ├── schema.py
│ │ │ ├── sender.py
│ │ │ ├── server.py
│ │ │ └── shm_manager.py
│ │ ├── task_dispatcher
│ │ │ ├── __init__.py
│ │ │ ├── dispatcher.py
│ │ │ ├── entrypoint.py
│ │ │ ├── grpc.py
│ │ │ └── router.py
│ │ └── task_manager
│ │ │ ├── __init__.py
│ │ │ ├── entrypoint.py
│ │ │ ├── grpc.py
│ │ │ └── manager.py
│ ├── sidecar
│ │ ├── __init__.py
│ │ ├── api.py
│ │ ├── constants.py
│ │ ├── schema.py
│ │ ├── serde.py
│ │ └── utils.py
│ ├── task
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── builtins
│ │ │ ├── __init__.py
│ │ │ ├── encoder.py
│ │ │ ├── llm.py
│ │ │ └── mllm.py
│ │ ├── forward.py
│ │ └── registry.py
│ ├── task_executors
│ │ ├── __init__.py
│ │ ├── descriptor
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── builtins
│ │ │ │ ├── __init__.py
│ │ │ │ ├── encoder.py
│ │ │ │ └── llm.py
│ │ │ └── registry.py
│ │ └── eric
│ │ │ ├── __init__.py
│ │ │ ├── api.py
│ │ │ ├── config.py
│ │ │ ├── distributed
│ │ │ ├── parallel.py
│ │ │ └── shm_broadcast.py
│ │ │ ├── engine
│ │ │ ├── __init__.py
│ │ │ ├── client.py
│ │ │ ├── core.py
│ │ │ └── scheduler.py
│ │ │ ├── entrypoint.py
│ │ │ ├── executor
│ │ │ ├── __init__.py
│ │ │ ├── executor.py
│ │ │ ├── loader.py
│ │ │ └── worker.py
│ │ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── gemma3.py
│ │ │ ├── layers
│ │ │ │ ├── __init__.py
│ │ │ │ ├── activations.py
│ │ │ │ ├── attention.py
│ │ │ │ ├── linear.py
│ │ │ │ ├── norm.py
│ │ │ │ ├── siglip.py
│ │ │ │ └── vit.py
│ │ │ ├── llava_onevision.py
│ │ │ ├── qwen2_5_omni.py
│ │ │ ├── qwen2_5_vl.py
│ │ │ ├── qwen2_vl.py
│ │ │ └── registry.py
│ │ │ ├── router
│ │ │ ├── __init__.py
│ │ │ ├── app.py
│ │ │ └── processor.py
│ │ │ ├── schema.py
│ │ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── distributed.py
│ │ │ ├── network.py
│ │ │ ├── package.py
│ │ │ ├── process.py
│ │ │ ├── serde.py
│ │ │ └── zmq.py
│ └── tracing.py
├── pyproject.toml
├── scripts
│ └── lint.sh
└── tests
│ ├── services
│ ├── __init__.py
│ └── sidecar
│ │ ├── __init__.py
│ │ ├── test_sidecar.py
│ │ └── utils.py
│ ├── task
│ ├── __init__.py
│ ├── builtins
│ │ ├── __init__.py
│ │ └── test_mllm.py
│ ├── test_base.py
│ └── test_registry.py
│ └── task_executors
│ └── eric
│ ├── __init__.py
│ ├── engine
│ ├── __init__.py
│ └── test_scheduler.py
│ ├── models
│ ├── __init__.py
│ ├── conftest.py
│ ├── test_gemma3.py
│ ├── test_llava_onevision.py
│ ├── test_qwen2_5_omni.py
│ ├── test_qwen2_5_vl.py
│ └── test_qwen2_vl.py
│ └── utils.py
└── scripts
├── build_export_images.sh
└── generate_pb.sh
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Cornserve repo-specific
2 | docs/
3 | examples/
4 | scratchpad/
5 | scripts/
6 |
7 | # Git
8 | .git
9 | .gitignore
10 | .gitattributes
11 |
12 |
13 | # CI
14 | .codeclimate.yml
15 | .travis.yml
16 | .taskcluster.yml
17 |
18 | # Docker and Kubernetes
19 | docker-compose.yml
20 | Dockerfile
21 | *.Dockerfile
22 | .docker
23 | .dockerignore
24 | kubernetes/
25 | docker/
26 |
27 | # Byte-compiled / optimized / DLL / type stub files
28 | **/__pycache__/
29 | **/*.py[codi]
30 |
31 | # C extensions
32 | *.so
33 |
34 | # Distribution / packaging
35 | .Python
36 | env/
37 | build/
38 | develop-eggs/
39 | dist/
40 | downloads/
41 | eggs/
42 | lib/
43 | lib64/
44 | parts/
45 | sdist/
46 | var/
47 | *.egg-info/
48 | .installed.cfg
49 | *.egg
50 |
51 | # PyInstaller
52 | # Usually these files are written by a python script from a template
53 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
54 | *.manifest
55 | *.spec
56 |
57 | # Installer logs
58 | pip-log.txt
59 | pip-delete-this-directory.txt
60 |
61 | # Unit test / coverage reports
62 | htmlcov/
63 | .tox/
64 | .coverage
65 | .cache
66 | nosetests.xml
67 | coverage.xml
68 |
69 | # Translations
70 | *.mo
71 | *.pot
72 |
73 | # Django stuff:
74 | *.log
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Virtual environment
83 | .env
84 | .envrc
85 | .venv/
86 | venv/
87 |
88 | # PyCharm
89 | .idea
90 |
91 | # Python mode for VIM
92 | .ropeproject
93 | **/.ropeproject
94 |
95 | # Vim swap files
96 | **/*.swp
97 |
98 | # VS Code
99 | .vscode/
100 |
101 | # Protobuf
102 | proto/
103 | *.proto
104 |
105 | # Text files
106 | *.md
107 |
--------------------------------------------------------------------------------
/.github/workflows/check_homepage_build.yaml:
--------------------------------------------------------------------------------
1 | name: Check homepage build
2 |
3 | on:
4 | pull_request:
5 | paths:
6 | - 'examples/**'
7 | - 'docs/**'
8 | - 'mkdocs.yml'
9 | - '.github/workflows/check_homepage_build.yaml'
10 |
11 | concurrency:
12 | group: ${{ github.ref }}-check-homepage-build
13 | cancel-in-progress: true
14 |
15 | jobs:
16 | check:
17 | runs-on: ubuntu-latest
18 | if: github.event.repository.fork == false
19 | steps:
20 | - name: Checkout repository
21 | uses: actions/checkout@v4
22 | - name: Setup Python
23 | uses: actions/setup-python@v5
24 | with:
25 | python-version: 3.11
26 | cache: 'pip'
27 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
28 | - uses: actions/cache@v4
29 | with:
30 | key: mkdocs-material-${{ env.cache_id }}
31 | path: .cache
32 | restore-keys: |
33 | mkdocs-material-
34 | - name: Install homepage build dependencies
35 | run: pip install -r docs/requirements.txt
36 | - name: Build homepage
37 | run: mkdocs build --verbose --strict
38 | env:
39 | BUILD_SOCIAL_CARD: true
40 |
--------------------------------------------------------------------------------
/.github/workflows/deploy_homepage.yaml:
--------------------------------------------------------------------------------
1 | name: Deploy homepage
2 | on:
3 | push:
4 | branches:
5 | - master
6 | paths:
7 | - 'examples/**'
8 | - 'docs/**'
9 | - 'mkdocs.yml'
10 | - '.github/workflows/deploy_homepage.yaml'
11 |
12 | env:
13 | SITE_ANALYTICS: google
14 |
15 | jobs:
16 | deploy:
17 | runs-on: ubuntu-latest
18 | if: github.repository_owner == 'cornserve-ai'
19 | steps:
20 | - name: Checkout repository
21 | uses: actions/checkout@v4
22 | - name: Setup Python
23 | uses: actions/setup-python@v5
24 | with:
25 | python-version: 3.11
26 | cache: 'pip'
27 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
28 | - uses: actions/cache@v4
29 | with:
30 | key: mkdocs-material-${{ env.cache_id }}
31 | path: .cache
32 | restore-keys: |
33 | mkdocs-material-
34 | - name: Install homepage build dependencies
35 | run: pip install -r docs/requirements.txt
36 | - name: Build and deploy homepage
37 | run: mkdocs gh-deploy --force
38 | env:
39 | BUILD_SOCIAL_CARD: true
40 |
--------------------------------------------------------------------------------
/.github/workflows/publish_pypi.yaml:
--------------------------------------------------------------------------------
1 | name: Publish Python package to PyPI
2 |
3 | on:
4 | push:
5 | tags:
6 | - v*
7 |
8 | jobs:
9 | pypi-publish:
10 | runs-on: ubuntu-latest
11 | if: github.repository_owner == 'cornserve-ai'
12 | permissions:
13 | id-token: write
14 | steps:
15 | - name: Checkout repository
16 | uses: actions/checkout@v4
17 | with:
18 | submodules: recursive
19 | token: ${{ secrets.SUBMODULE_TOKEN }}
20 | - name: Setup Python
21 | uses: actions/setup-python@v5
22 | with:
23 | python-version: 3.11
24 | cache: 'pip'
25 | - name: Install protobuf dependencies
26 | run: pip install grpcio-tools
27 | - name: Generate protobuf files
28 | run: bash scripts/generate_pb.sh
29 | - name: Build source distribution
30 | run: cd python && pip install build && python -m build
31 | - name: Publish to PyPI
32 | uses: pypa/gh-action-pypi-publish@release/v1
33 | with:
34 | packages-dir: python/dist/
35 |
36 |
--------------------------------------------------------------------------------
/.github/workflows/push_docker.yaml:
--------------------------------------------------------------------------------
1 | name: Push Docker image
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | tags:
8 | - v*
9 | paths:
10 | - '.github/workflows/push_docker.yaml'
11 | - '.gitmodules'
12 | - 'docker/**'
13 | - 'python/**'
14 | - 'proto/**'
15 | - 'third_party/**'
16 | - 'scripts/generate_pb.sh'
17 | - '.dockerignore'
18 | - 'LICENSE'
19 | - 'setup.py'
20 | - 'pyproject.toml'
21 |
22 | env:
23 | NAMESPACE: cornserve
24 |
25 | jobs:
26 | build_and_push:
27 | if: github.repository_owner == 'cornserve-ai'
28 | runs-on: ${{ (matrix.component == 'vllm' || matrix.component == 'eric') && 'blacksmith-2vcpu-ubuntu-2404' || 'ubuntu-latest' }}
29 | strategy:
30 | matrix:
31 | component:
32 | - sidecar
33 | - task-dispatcher
34 | - task-manager
35 | - resource-manager
36 | - gateway
37 | - eric
38 | - vllm
39 | include:
40 | - component: eric
41 | build_args: |
42 | max_jobs=4
43 | build_target: eric
44 | - component: vllm
45 | build_target: vllm
46 |
47 | steps:
48 | - name: Remove unnecessary files
49 | run: |
50 | sudo rm -rf /usr/share/dotnet
51 | sudo rm -rf /opt/ghc
52 | sudo rm -rf "/usr/local/share/boost"
53 | sudo rm -rf "$AGENT_TOOLSDIRECTORY"
54 |
55 | - name: Checkout repository
56 | uses: actions/checkout@v4
57 | with:
58 | submodules: recursive
59 |
60 | - name: Setup Python for protobuf generation
61 | uses: actions/setup-python@v5
62 | with:
63 | python-version: 3.11
64 |
65 | - name: Install protobuf dependencies
66 | run: pip install grpcio-tools==1.71.0
67 |
68 | - name: Generate protobuf files
69 | run: bash scripts/generate_pb.sh
70 |
71 | - name: Docker Hub login
72 | uses: docker/login-action@v3
73 | with:
74 | username: ${{ secrets.DOCKER_HUB_USERNAME }}
75 | password: ${{ secrets.DOCKER_HUB_TOKEN }}
76 |
77 | - name: Set up Docker Buildx
78 | uses: docker/setup-buildx-action@v3
79 |
80 | - name: Generate image metadata
81 | id: meta
82 | uses: docker/metadata-action@v5
83 | with:
84 | images: |
85 | ${{ env.NAMESPACE }}/${{ matrix.component }}
86 | tags: |
87 | type=raw,value=latest,enable={{is_default_branch}}
88 | type=match,pattern=v.*
89 |
90 | - name: Build and push Docker image
91 | uses: docker/build-push-action@v6
92 | with:
93 | context: .
94 | file: docker/${{ matrix.component }}.Dockerfile
95 | push: true
96 | tags: ${{ steps.meta.outputs.tags }}
97 | labels: ${{ steps.meta.outputs.labels }}
98 | cache-from: type=registry,ref=${{ env.NAMESPACE }}/${{ matrix.component }}:buildcache
99 | cache-to: type=registry,ref=${{ env.NAMESPACE }}/${{ matrix.component }}:buildcache,mode=max
100 | platforms: linux/amd64
101 | build-args: ${{ matrix.build_args }}
102 | target: ${{ matrix.build_target || '' }}
103 |
--------------------------------------------------------------------------------
/.github/workflows/python_lint.yaml:
--------------------------------------------------------------------------------
1 | name: Python format and lint check
2 |
3 | on:
4 | pull_request:
5 | paths:
6 | - '.github/workflows/python_lint.yaml'
7 | - 'python/**'
8 | - 'proto/**'
9 | - 'pyproject.toml'
10 | push:
11 | paths:
12 | - '.github/workflows/python_lint.yaml'
13 | - 'python/**'
14 | - 'proto/**'
15 | - 'pyproject.toml'
16 |
17 | # Jobs initiated by previous pushes get cancelled by a new push.
18 | concurrency:
19 | group: ${{ github.ref }}-python-format-and-lint
20 | cancel-in-progress: true
21 |
22 | jobs:
23 | format_lint:
24 | if: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name != github.repository }}
25 | runs-on: ubuntu-latest
26 | steps:
27 | - name: Checkout repository
28 | uses: actions/checkout@v4
29 | - name: Setup Python
30 | uses: actions/setup-python@v5
31 | with:
32 | python-version: 3.11
33 | cache: 'pip'
34 | - name: Install library and lint tools
35 | run: pip install -U pip && pip install "./python[dev-no-gpu]"
36 | - name: Generate protobuf files
37 | run: bash scripts/generate_pb.sh
38 | - name: Check format and lint
39 | run: bash python/scripts/lint.sh
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # UV
99 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | #uv.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
117 | .pdm.toml
118 | .pdm-python
119 | .pdm-build/
120 |
121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
122 | __pypackages__/
123 |
124 | # Celery stuff
125 | celerybeat-schedule
126 | celerybeat.pid
127 |
128 | # SageMath parsed files
129 | *.sage.py
130 |
131 | # Environments
132 | .env
133 | .envrc
134 | .venv
135 | env/
136 | venv/
137 | ENV/
138 | env.bak/
139 | venv.bak/
140 |
141 | # Spyder project settings
142 | .spyderproject
143 | .spyproject
144 |
145 | # Rope project settings
146 | .ropeproject
147 |
148 | # mkdocs documentation
149 | /site
150 |
151 | # mypy
152 | .mypy_cache/
153 | .dmypy.json
154 | dmypy.json
155 |
156 | # Pyre type checker
157 | .pyre/
158 |
159 | # pytype static type analyzer
160 | .pytype/
161 |
162 | # Cython debug symbols
163 | cython_debug/
164 |
165 | # PyCharm
166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
168 | # and can be added to the global gitignore or merged into this file. For a more nuclear
169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
170 | #.idea/
171 |
172 | # Ruff stuff:
173 | .ruff_cache/
174 |
175 | # PyPI configuration file
176 | .pypirc
177 |
178 | # Protobuf
179 | *pb2.py
180 | *pb2.pyi
181 | *pb2_grpc.py
182 |
183 | # Workspace
184 | /old
185 | /tmp
186 | /dev
187 |
188 | # MacOS
189 | .DS_Store
190 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/vllm"]
2 | path = third_party/vllm
3 | url = https://github.com/cornserve-ai/vllm.git
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
Cornserve: Easy, Fast, and Scalable Multimodal AI
3 |
4 | [](https://hub.docker.com/r/cornserve/gateway)
5 | [](https://cornserve.ai/)
6 | [](/LICENSE)
7 |
8 |
9 | https://github.com/user-attachments/assets/6dd12ad6-2307-4457-ae70-96e1cecc5ece
10 |
11 | Cornserve is an execution platform for multimodal AI.
12 | Split complex models into smaller, separately scalable components (**model fission**) and share common components across multiple applications (**sharing**), all on your own infrastructure.
13 |
14 |
15 | ## Getting Started
16 |
17 | You can quickly try out Cornserve on top of Minikube. Check out our [getting started guide](https://cornserve.ai/getting_started/)!
18 |
19 | Cornserve can be deployed on Kubernetes with a single command. More on our [docs](https://cornserve.ai/getting_started/).
20 |
21 |
22 | ## Contributing
23 |
24 | Cornserve is an open-source project, and we welcome contributions!
25 | Please check out our [contributor guide](https://cornserve.ai/contributor_guide/) for more information on how to get started.
26 |
--------------------------------------------------------------------------------
/docker/dev.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.7.0-cuda12.6-cudnn9-devel
2 |
3 | RUN apt-get update && apt-get upgrade -y
4 | RUN apt-get install wget build-essential librdmacm-dev net-tools -y
5 |
6 | ########### Install UCX 1.18.0 ###########
7 | RUN wget https://github.com/openucx/ucx/releases/download/v1.18.0/ucx-1.18.0.tar.gz
8 | RUN tar xzf ucx-1.18.0.tar.gz
9 | WORKDIR /workspace/ucx-1.18.0
10 | RUN mkdir build
11 | RUN cd build && \
12 | ../configure --build=x86_64-unknown-linux-gnu --host=x86_64-unknown-linux-gnu --program-prefix= --disable-dependency-tracking \
13 | --prefix=/usr --exec-prefix=/usr --bindir=/usr/bin --sbindir=/usr/sbin --sysconfdir=/etc --datadir=/usr/share --includedir=/usr/include \
14 | --libdir=/usr/lib64 --libexecdir=/usr/libexec --localstatedir=/var --sharedstatedir=/var/lib --mandir=/usr/share/man --infodir=/usr/share/info \
15 | --disable-logging --disable-debug --disable-assertions --enable-mt --disable-params-check --without-go --without-java --enable-cma \
16 | --with-verbs --with-mlx5 --with-rdmacm --without-rocm --with-xpmem --without-fuse3 --without-ugni --without-mad --without-ze && \
17 | make -j$(nproc) && make install
18 |
19 | ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
20 | ENV LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
21 |
22 | # UCX logging
23 | ENV UCX_LOG_LEVEL=trace
24 | # UCX_LOG_LEVEL to be one of: fatal, error, warn, info, debug, trace, req, data, async, func, poll
25 |
26 | # UCX transports
27 | ENV UCX_TLS=rc,ib,tcp
28 | ########### End Install UCX ###########
29 |
30 | ADD . /workspace/cornserve
31 | WORKDIR /workspace/cornserve/python
32 |
33 | RUN pip install -e '.[dev]'
34 |
35 | # UCXX logging
36 | ENV UCXPY_LOG_LEVEL=DEBUG
37 | # python log level syntax: DEBUG, INFO, WARNING, ERROR, CRITICAL
38 |
39 | # Disable OpenTelemetry
40 | ENV OTEL_SDK_DISABLED=true
41 |
42 | CMD ["bash"]
43 |
--------------------------------------------------------------------------------
/docker/eric.Dockerfile:
--------------------------------------------------------------------------------
1 | # Build flash-attn wheel inside the `devel` image which has `nvcc`.
2 | FROM pytorch/pytorch:2.7.0-cuda12.6-cudnn9-devel AS builder
3 |
4 | ARG max_jobs=64
5 | ENV MAX_JOBS=${max_jobs}
6 | ENV NVCC_THREADS=8
7 | RUN pip wheel -w /tmp/wheels --no-build-isolation --no-deps --verbose flash-attn
8 |
9 | # Actual Eric runs inside the `runtime` image. Just copy over the flash-attn wheel.
10 | FROM pytorch/pytorch:2.7.0-cuda12.6-cudnn9-runtime AS eric
11 |
12 | COPY --from=builder /tmp/wheels/*.whl /tmp/wheels/
13 | RUN pip install --no-cache-dir /tmp/wheels/*.whl && rm -rf /tmp/wheels
14 |
15 | RUN apt-get update \
16 | && apt-get install -y --no-install-recommends build-essential \
17 | && rm -rf /var/lib/apt/lists/*
18 |
19 | ADD . /workspace/cornserve
20 |
21 | WORKDIR /workspace/cornserve/python
22 | RUN pip install -e '.[eric]'
23 |
24 | ENTRYPOINT ["python", "-u", "-m", "cornserve.task_executors.eric.entrypoint"]
25 |
26 | # Eric that has audio support.
27 | FROM eric AS eric-audio
28 |
29 | RUN pip install -e '.[audio]'
30 |
--------------------------------------------------------------------------------
/docker/gateway.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11.11
2 |
3 | ADD . /workspace/cornserve
4 |
5 | WORKDIR /workspace/cornserve/python
6 | RUN pip install -e .[gateway]
7 |
8 | ENTRYPOINT ["python", "-m", "cornserve.services.gateway.entrypoint"]
9 |
--------------------------------------------------------------------------------
/docker/resource-manager.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11.11
2 |
3 | ADD . /workspace/cornserve
4 |
5 | WORKDIR /workspace/cornserve/python
6 | RUN pip install -e .[resource-manager]
7 |
8 | ENTRYPOINT ["python", "-m", "cornserve.services.resource_manager.entrypoint"]
9 |
--------------------------------------------------------------------------------
/docker/sidecar.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.7.0-cuda12.6-cudnn9-runtime
2 |
3 | RUN apt-get update && apt-get upgrade -y
4 | RUN apt-get install wget build-essential librdmacm-dev net-tools -y
5 |
6 | ########### Install UCX 1.18.0 ###########
7 | RUN wget https://github.com/openucx/ucx/releases/download/v1.18.0/ucx-1.18.0.tar.gz
8 | RUN tar xzf ucx-1.18.0.tar.gz
9 | WORKDIR /workspace/ucx-1.18.0
10 | RUN mkdir build
11 | RUN cd build && \
12 | ../configure --build=x86_64-unknown-linux-gnu --host=x86_64-unknown-linux-gnu --program-prefix= --disable-dependency-tracking \
13 | --prefix=/usr --exec-prefix=/usr --bindir=/usr/bin --sbindir=/usr/sbin --sysconfdir=/etc --datadir=/usr/share --includedir=/usr/include \
14 | --libdir=/usr/lib64 --libexecdir=/usr/libexec --localstatedir=/var --sharedstatedir=/var/lib --mandir=/usr/share/man --infodir=/usr/share/info \
15 | --disable-logging --disable-debug --disable-assertions --enable-mt --disable-params-check --without-go --without-java --enable-cma \
16 | --with-verbs --with-mlx5 --with-rdmacm --without-rocm --with-xpmem --without-fuse3 --without-ugni --without-mad --without-ze && \
17 | make -j$(nproc) && make install
18 |
19 | ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
20 | ENV LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
21 | ########### End Install UCX ###########
22 |
23 | ADD . /workspace/cornserve
24 |
25 | WORKDIR /workspace/cornserve/python
26 | RUN pip install -e '.[sidecar]'
27 |
28 | ENTRYPOINT ["python", "-u", "-m", "cornserve.services.sidecar.server"]
29 |
--------------------------------------------------------------------------------
/docker/task-dispatcher.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11.11
2 |
3 | ADD . /workspace/cornserve
4 |
5 | WORKDIR /workspace/cornserve/python
6 | RUN pip install -e .[task-dispatcher]
7 |
8 | ENTRYPOINT ["python", "-m", "cornserve.services.task_dispatcher.entrypoint"]
9 |
--------------------------------------------------------------------------------
/docker/task-manager.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11.11
2 |
3 | ADD . /workspace/cornserve
4 |
5 | WORKDIR /workspace/cornserve/python
6 | RUN pip install -e .[task-manager]
7 |
8 | ENTRYPOINT ["python", "-m", "cornserve.services.task_manager.entrypoint"]
9 |
--------------------------------------------------------------------------------
/docker/vllm.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.8.1-devel-ubuntu22.04 AS base
2 |
3 | RUN apt-get update -y \
4 | && apt-get install -y git curl wget \
5 | && curl -LsSf https://astral.sh/uv/install.sh | sh
6 |
7 | ENV PATH="/root/.local/bin:$PATH"
8 | ENV VIRTUAL_ENV="/opt/venv"
9 | RUN uv venv --python 3.11 --seed ${VIRTUAL_ENV}
10 | ENV PATH="$VIRTUAL_ENV/bin:$PATH"
11 |
12 | ADD . /workspace/cornserve
13 | WORKDIR /workspace/cornserve/third_party/vllm
14 |
15 | # Install CORNSERVE sidecars
16 | RUN cd ../.. && uv pip install './python[sidecar-api]'
17 |
18 | RUN uv pip install -r requirements/common.txt
19 | RUN uv pip install -r requirements/cuda.txt
20 | ENV SETUPTOOLS_SCM_PRETEND_VERSION=0.0.1.dev
21 |
22 | ENV VLLM_USE_PRECOMPILED=1
23 | ENV VLLM_COMMIT=6b6d4961147220fb80f9cc7dcb74db478f9c9a23
24 | ENV VLLM_PRECOMPILED_WHEEL_LOCATION=https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
25 |
26 | # Intermediate vllm stage without audio
27 | FROM base AS vllm
28 | RUN uv pip install -e .
29 |
30 | ENV VLLM_USE_V1=1
31 | ENV HF_HOME="/root/.cache/huggingface"
32 | ENTRYPOINT ["vllm", "serve"]
33 |
34 | # Default final stage with audio support
35 | FROM vllm AS vllm-audio
36 | RUN uv pip install -e .[audio]
37 |
--------------------------------------------------------------------------------
/docs/CNAME:
--------------------------------------------------------------------------------
1 | cornserve.ai
2 |
--------------------------------------------------------------------------------
/docs/architecture/eric.md:
--------------------------------------------------------------------------------
1 | # Eric: Multimodal Data Embedding Server
2 |
3 | > **Mosharaf**: Hey, what is this "Eric" thing in the architecture diagram?
4 | **Jae-Won**: Oh, uh no, it says "Enc." For Encoder.
5 | **Mosharaf**: Oh.
6 | **Jae-Won**: Now it's Eric.
7 |
8 | Package: `cornserve.task_executors.eric`
9 |
10 | Eric is a multimodal data embedding server that takes in a list of multimodal data (e.g., images, videos) and computes the multimodal embedding of the input data.
11 |
12 | ## Architecture
13 |
14 | Below, components are divided at the process boundary.
15 |
16 | ### Router
17 |
18 | Eric's router is an async FastAPI server that (1) receives modality encoding requests and (2) preprocesses modality data before running the encoder.
19 | Preprocessing is done asynchronously in a thread pool by the `eric.router.processor.Processor` class.
20 |
21 | Each model processes different modality data differently, so the router must instantiate the correct model-specific preprocessor.
22 | Instantiating and invoking these model- and modality-specific preprocessors are implemented in the class `eric.models.[model_module].ModalityProcessor`, which is a subclass of `eric.models.base.BaseModalityProcessor`.
23 |
24 | When modality preprocessing is complete, the router submits the embedding request to the engine.
25 | The router and the engine communicate through ZMQ sockets. In particular, the router holds an instance of the engine client (`eric.engine.client.EngineClient`), which is used to send requests to the engine and receive responses.
26 |
27 | ### Engine
28 |
29 | From the engine and below, everything is synchronous Python (i.e., not `asyncio`).
30 |
31 | The Engine constantly receives embedding requests from the router, runs the request scheduler to create an `eric.schema.Batch`, and invokes the model executor (`eric.executor.executor.ModelExecutor`) to compute the multimodal embedding.
32 | The model executor provides the `execute_model` method, which broadcasts input batch data to all Workers via shared memory.
33 |
34 | The engine currently only batches data of the same modality together. This is because there are models that have different code paths for different modalities. Furthermore, due to the compute-intensive nature of multimodal encoders, it is unlikely we will scale to large batch sizes.
35 |
36 | ### Workers
37 |
38 | There is one worker (`eric.executor.worker.Worker`) process per GPU. The number of workers is the tensor parallelism degree.
39 | When spawned, the workers initialize PyTorch distributed and instantiate the model from weights downloaded from the Hugging Face Hub.
40 | Each worker then waits for the model executor to dispatch a batch to it, runs tensor parallel inference, and dispatches tensor communication to the destination Task Executor via the [sidecar](sidecar.md).
41 |
--------------------------------------------------------------------------------
/docs/assets/css/extra.css:
--------------------------------------------------------------------------------
1 | /* Override mkdocs-video plugin styles for responsive margins */
2 | .video-container {
3 | margin: 0 auto !important;
4 | }
5 |
6 | /* Desktop screens */
7 | @media (min-width: 1024px) {
8 | .video-container {
9 | padding: 0 12% !important;
10 | }
11 | }
12 |
13 | /* Mobile - use plugin default */
14 | @media (max-width: 767px) {
15 | .video-container {
16 | padding: 0 !important;
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/docs/assets/img/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/docs/assets/img/favicon.png
--------------------------------------------------------------------------------
/docs/assets/js/mathjax.js:
--------------------------------------------------------------------------------
1 | window.MathJax = {
2 | tex: {
3 | inlineMath: [["\\(", "\\)"]],
4 | displayMath: [["\\[", "\\]"]],
5 | processEscapes: true,
6 | processEnvironments: true
7 | },
8 | options: {
9 | ignoreHtmlClass: ".*|",
10 | processHtmlClass: "arithmatex"
11 | }
12 | };
13 |
14 | document$.subscribe(() => {
15 | MathJax.typesetPromise()
16 | })
17 |
--------------------------------------------------------------------------------
/docs/assets/video/cornserve.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/docs/assets/video/cornserve.mp4
--------------------------------------------------------------------------------
/docs/contributor_guide/eric.md:
--------------------------------------------------------------------------------
1 | # Eric developer guide
2 |
3 | ## Docker container
4 |
5 | All code is to be run inside a Docker container, including tests.
6 |
7 | ```bash
8 | docker build -t cornserve/eric:latest -f docker/eric.Dockerfile .
9 | docker run -it --gpus all --entrypoint bash --ipc host --rm --name eric-dev -v $PWD:/workspace/cornserve -v $HF_CACHE:/root/.cache/huggingface cornserve/eric:latest
10 | ```
11 |
12 | ## Editable installation
13 |
14 | ```bash
15 | pip install -e 'python[dev]'
16 | ```
17 |
18 | ## Testing
19 |
20 | We use `pytest`. Tests use GPUs.
21 |
22 | ```bash
23 | pytest
24 | ```
25 |
26 | Set the `CORNSERVE_TEST_DUMP_TENSOR_DIR` environment variable to an existing directory when running `pytest`.
27 | This will dump output embedding tensors to the specified directory.
28 | Refer to `build_batch` in `tests/task_executors/eric/utils.py`.
29 |
30 | ```bash
31 | export CORNSERVE_TEST_DUMP_TENSOR_DIR=/path/to/dump
32 | pytest python/tests/task_executors/eric/models/test_llava_onevision.py::test_image_inference
33 | ```
34 |
--------------------------------------------------------------------------------
/docs/contributor_guide/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | description: Cornserve contributor guide
3 | ---
4 |
5 | # Contributor Guide
6 |
7 | Here, we provide more info for contributors.
8 | General principles are here, and child pages discuss specific topics in more detail.
9 |
10 | We have a few principles for developing Cornserve:
11 |
12 | 1. **Strict type annotations**: We enforce strict type annotation everywhere in the Python codebase, which leads to numerous benefits including better reliability, readability, and editor support. We use `pyright` for type checking.
13 | 1. **Automated testing**: We don't aim for 100% test coverage, but non-trivial and/or critical features should be tested with `pytest`.
14 |
15 | ## Contributing process
16 |
17 | !!! Important
18 | By contributing to Cornserve, you agree that your code will be licensed with Apache 2.0.
19 |
20 | If the feature is not small or requires broad changes over the codebase, please **open an issue** at our GitHub repository to discuss with us.
21 |
22 | 1. Fork our GitHub repository. Make sure you clone with `--recurse-submodules` to get the submodules.
23 | 1. Create a new Conda environment with something along the lines of `conda create -n cornserve python=3.11` and activate it with something like `conda activate cornserve`.
24 | 1. Install Cornserve in editable mode with `pip install -e 'python[dev]'`. If your environment does not have GPUs, you can use `pip install -e 'python[dev-no-gpu]'`.
25 | 1. Generate Python bindings for Protobuf files with `bash scripts/generate_pb.sh`.
26 | 1. Implement changes in your branch and add tests as needed.
27 | 1. Ensure `bash python/scripts/lint.sh` and `pytest` pass. Note that many of our tests require GPUs.
28 | 1. Submit a PR to the main repository. Please ensure that CI (GitHub Actions) passes.
29 |
30 | ## Developing on Kubernetes
31 |
32 | Cornserve runs on top of Kubernetes, which introduces some complexity in development.
33 | Please refer to the guide on [Local and Distributed Development on Kubernetes](kubernetes.md) for more details.
34 |
35 | ## Documentation
36 |
37 | The documentation is written in Markdown and is located in the `docs` folder.
38 | We use MkDocs to build the documentation and use the `mkdocs-material` theme.
39 |
40 | To install documentation build dependencies:
41 |
42 | ```bash
43 | pip install -r docs/requirements.txt
44 | ```
45 |
46 | To build and preview the documentation:
47 |
48 | ```bash
49 | mkdocs serve
50 | ```
51 |
--------------------------------------------------------------------------------
/docs/contributor_guide/kubernetes.md:
--------------------------------------------------------------------------------
1 | ## Local and Distributed Development on Kubernetes
2 |
3 | ### Local development
4 |
5 | You are developing on a single node.
6 | In this case, we don't need a registry.
7 | Instead, we build containers directly within the containerd runtime of K3s.
8 |
9 | First, follow [this guide](https://blog.otvl.org/blog/k3s-loc-sp) (Section "Switching from Docker to Containerd") to set up Nerdctl and BuildKit on your local development machine.
10 |
11 | After that, you can use Nerdctl to build images directly within K3s containerd, and no pull is necessary whatsoever.
12 | Use the `build_export_images.sh` script with the `REGISTRY` environment variable set to `local` (a special case):
13 |
14 | ```bash
15 | REGISTRY=local bash scripts/build_export_images.sh
16 | ```
17 |
18 | Use the `local` overlay to deploy Cornserve:
19 |
20 | ```bash
21 | kubectl apply -k kustomize/cornserve-system/overlays/local
22 | kubectl apply -k kustomize/cornserve/overlays/local
23 | ```
24 |
25 | The `local` overlay specifies `imagePullPolicy: Never`: if an image is not found locally, it was not built yet, and Kubernetes correctly raises an error.
26 |
27 | !!! NOTE
28 | You can use the `local` overlay for the quick Minikube demo as well.
29 |
30 | ### Distributed development
31 |
32 | You are developing on a multi-node cluster.
33 |
34 | (1) Now, you do need a registry to push images to, so that remote nodes can pull them:
35 |
36 | ```bash
37 | bash kubernetes/registry.sh
38 | REGISTRY=myregistry.com:5000 bash kubernetes/set_registry.sh # (1)!
39 | ```
40 |
41 | 1. Modifies `kustomization.yaml` and `k3s/registries.yaml`
42 | If you're on this dev workflow with a *single* node cluster, you can skip `kubernetes/set_registry.sh` because things default to `localhost:5000`.
43 |
44 | (2) For K3s to work with insecure (i.e., HTTP) registries, you need to set up the `registries.yaml` file in `/etc/rancher/k3s/` on **all** nodes (master and worker) before starting K3s:
45 |
46 | ```bash
47 | sudo cp kubernetes/k3s/registries.yaml /etc/rancher/k3s/registries.yaml
48 | sudo systemctl start k3s # or k3s-agent
49 | ```
50 |
51 | (3) Build and push images to the registry using the `build_export_images.sh` script with the `REGISTRY` environment variable set to the registry address:
52 |
53 | ```bash
54 | REGISTRY=myregistry.com:5000 bash scripts/build_export_images.sh
55 | ```
56 |
57 | (4) Use the `dev` overlay (which specifies `imagePullPolicy: Always`) to deploy Cornserve:
58 |
59 | ```bash
60 | kubectl apply -k kustomize/cornserve-system/overlays/dev
61 | kubectl apply -k kustomize/cornserve/overlays/dev
62 | ```
63 |
--------------------------------------------------------------------------------
/docs/contributor_guide/sidecar.md:
--------------------------------------------------------------------------------
1 | # Sidecar developer guide
2 |
3 | ## Docker container
4 |
5 | It is recommended to run everything inside Docker. Sidecar uses `UCX` as its backend,
6 | so you might find the `docker/dev.Dockerfile` helpful. Additionally, Sidecar has
7 | a dependency on `ucxx-cu12`, meaning you need to develop on an NVIDIA
8 | GPU-enabled machine at the moment.
9 |
10 | Specifying `--shm-size` with at least 4 GB and `--ipc host` is required.
11 |
12 | ## Editable installation
13 |
14 | ```bash
15 | pip install -e 'python[dev]'
16 | ```
17 |
18 | ## Testing
19 |
20 | We use pytest.
21 |
22 | ```bash
23 | pytest python/tests/services/sidecar/test_sidecar.py
24 | ```
25 |
26 | When testing locally with task executors, you can `export SIDECAR_IS_LOCAL=true` to
27 | route all communications through `localhost` instead of the k8s network.
28 |
29 |
30 | ## Debugging
31 |
32 | To debug UCX-related errors, you can set `UCX_LOG_LEVEL=trace` and `UCXPY_LOG_LEVEL=DEBUG`.
33 |
--------------------------------------------------------------------------------
/docs/contributor_guide/tracing.md:
--------------------------------------------------------------------------------
1 | # Tracing developer guide
2 |
3 | We employ OpenTelemetry for observability. Below are some of the conventions we use.
4 |
5 | Generally, we use auto-instrumentation provided by OpenTelemetry, e.g., FastAPI, gRPC, HTTPX.
6 |
7 | ## Spans
8 |
9 | Usually named with `ClassName.function_name`.
10 |
11 | ## Attributes
12 |
13 | Usually named with `namespace.subroutine.attribute_name`.
14 | `namespace` is typically the name of the service, like `gateway`.
15 |
16 | ## Events
17 |
18 | Usually named with `action.event_name`.
19 | Use spans for things that happen over time (e.g., a subroutine), where tracking the start and end is important.
20 | On the other hand, use events for singular occurrences that happen at a specific moment in time.
21 |
22 | ## Testing
23 | When testing locally, you can disable OTEL tracing by setting `OTEL_SDK_DISABLED=true`.
24 |
--------------------------------------------------------------------------------
/docs/getting_started/cornserve.md:
--------------------------------------------------------------------------------
1 | # Deploying Cornserve
2 |
3 | Cornserve can be deployed on a GPU cluster managed by Kubernetes.
4 |
5 | !!! Note
6 | The `cornserve` namespace is used for most of our control plane and data plane objects.
7 | On the other hand, the `cornserve-system` namespace is used for components that look over and manage the Cornserve system itself (under `cornserve`), like Jaeger and Prometheus.
8 | If you already have a Kubernetes cluster running, you can deploy Cornserve on it with the `prod` overlay (see below).
9 |
10 | ## Deploying K3s
11 |
12 | !!! Tip
13 | If you have a Kubernetes cluster running, you can skip this section.
14 |
15 | If you don't have a Kubernetes cluster running, you can deploy Cornserve on a K3s cluster.
16 | We also use the [K3s](https://k3s.io/) distribution of Kubernetes for our development.
17 | Refer to their [Documentation](https://docs.k3s.io/quick-start/) for more details.
18 |
19 | !!! Tip
20 | If you're deploying on-premise with k3s, make sure you have plenty of disk space under `/var/lib/rancher` because `containerd` stores images there.
21 | If not, you can create a directory in a secondary storage (e.g., `/mnt/data/rancher`) and symlink it to `/var/lib/rancher` prior to starting k3s.
22 |
23 | ### NVIDIA Device Plugin
24 |
25 | The [NVIDIA GPU Device Plugin](https://github.com/NVIDIA/k8s-device-plugin) is required to expose GPUs to the Kubernetes cluster as resources.
26 | You can deploy a specific version like this:
27 |
28 | ```bash
29 | kubectl create -f https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.17.2/deployments/static/nvidia-device-plugin.yml
30 | ```
31 |
32 | ### Clone the Repository
33 |
34 | ```bash
35 | git clone https://github.com/cornserve-ai/cornserve.git
36 | cd cornserve/kubernetes
37 | ```
38 |
39 | ### Master Node
40 |
41 | Install and start K3s:
42 |
43 | ```bash
44 | curl -sfL https://get.k3s.io | INSTALL_K3S_SKIP_ENABLE=true sh -
45 | sudo mkdir -p /etc/rancher/k3s
46 | sudo cp k3s/server-config.yaml /etc/rancher/k3s/config.yaml
47 | sudo systemctl start k3s
48 | ```
49 |
50 | Note the master node address (`$MASTER_ADDRESS`) and the node token (`$NODE_TOKEN`):
51 |
52 | ```bash
53 | NODE_TOKEN="$(sudo cat /var/lib/rancher/k3s/server/node-token)"
54 | ```
55 |
56 | ### Worker Nodes
57 |
58 | Install and start K3s:
59 |
60 | ```bash
61 | curl -sfL https://get.k3s.io | K3S_URL=https://$MASTER_ADDRESS:6443 K3S_TOKEN=$NODE_TOKEN INSTALL_K3S_SKIP_ENABLE=true sh -
62 | sudo mkdir -p /etc/rancher/k3s
63 | sudo cp k3s/agent-config.yaml /etc/rancher/k3s/config.yaml
64 | sudo systemctl start k3s-agent
65 | ```
66 |
67 | ## Deploying Cornserve
68 |
69 | If you haven't already, clone the Cornserve repository:
70 |
71 | ```bash
72 | git clone git@github.com:cornserve-ai/cornserve.git
73 | cd cornserve
74 | ```
75 |
76 | On top of a Kubernetes cluster, you can deploy Cornserve with a single command:
77 |
78 | ```bash
79 | kubectl apply -k kubernetes/kustomize/cornserve-system/base
80 | kubectl apply -k kubernetes/kustomize/cornserve/overlays/prod
81 | ```
82 |
83 | !!! Note
84 | The `cornserve` namespace is used for most of our control plane and data plane objects.
85 | On the other hand, the `cornserve-system` namespace is used for components that look over and manage the Cornserve system itself (under `cornserve`), like Jaeger and Prometheus.
86 |
--------------------------------------------------------------------------------
/docs/getting_started/registering_apps.md:
--------------------------------------------------------------------------------
1 | # Deploying Apps to Cornserve and Invoking Them
2 |
3 | Once you've written your app, you can deploy it to Cornserve.
4 | The current deployment process is as follows:
5 |
6 | 1. Save the app code in a single Python file (e.g., `image_chat.py`).
7 | 2. Register & deploy the app to the Cornserve Gateway for validation and deployment:
8 | ```bash
9 | export CORNSERVE_GATEWAY_URL=[...]
10 | cornserve register image_chat.py
11 | ```
12 | 3. When validation succeeds, the Cornserve Gateway will deploy the app and all its subtasks on the Cornserve data plane, and the `cornserve` CLI invocation will return with the app's ID.
13 | 4. Finally, you can send requests to the Cornserve Gateway with your choice of HTTP client.
14 | ```python
15 | response = requests.post(
16 | f"{CORNSERVE_GATEWAY_URL}/app/invoke/{APP_ID}",
17 | json={
18 | "request_data": {
19 | "image_url": "https://example.com/image.jpg",
20 | "prompt": "Describe the image.",
21 | }
22 | },
23 | )
24 | ```
25 | Notice that what comes within the `"request_data"` key is the JSON representation of your `Request` class defined in our [previous example](building_apps.md#app).
26 |
27 | ## Next Steps
28 |
29 | To dive deeper into the architecture of Cornserve, check out our [architecture guide](../architecture/index.md).
30 |
31 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | description: "Easy, fast, and scalable multimodal AI"
3 | hide:
4 | - navigation
5 | - toc
6 | ---
7 |
8 |
9 |
Cornserve: Easy, Fast, and Scalable Multimodal AI
10 |
11 |
12 |
13 |
22 |
23 |
24 |
43 |
44 | Cornserve is an execution platform for multimodal AI.
45 | It performs **model fission** and **automatic sharing** of common components across applications on your infrastructure.
46 |
47 |
48 |
49 | - :material-vector-intersection:{ .lg .middle } **Model fission**
50 |
51 | ---
52 |
53 | Split up your complex models into smaller components and
54 | scale them independently.
55 |
56 | - :material-share-variant:{ .lg .middle } **Automatic sharing**
57 |
58 | ---
59 |
60 | Common model components are automatically shared across applications.
61 |
62 | - :material-hub:{ .lg .middle } **Multimodal-native**
63 |
64 | ---
65 |
66 | Cornserve is built multimodal-native from the ground up. Image, video,
67 | audio, and text are all first-class citizens.
68 |
69 | - :material-kubernetes:{ .lg .middle } **Deploy to K8s with one command**
70 |
71 | ---
72 |
73 | One-command deployment to Kubernetes with [Kustomize](https://kustomize.io/).
74 |
75 | - :simple-opentelemetry:{ .lg .middle } **Observability**
76 |
77 | ---
78 |
79 | Built-in support for [OpenTelemetry](https://opentelemetry.io/)
80 | to monitor your apps and requests.
81 |
82 | - :material-scale-balance:{ .lg .middle } **Open Source, Apache-2.0**
83 |
84 | ---
85 |
86 | Cornserve is open-source with the Apache 2.0 license and is available on
87 | [GitHub](https://github.com/cornserve-ai/cornserve).
88 |
89 |
90 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | mkdocs
2 | mkdocs-material[imaging]
3 | mkdocs-video
4 | mkdocs-autorefs
5 | mkdocs-jupyter
6 |
--------------------------------------------------------------------------------
/examples/mllm/app.py:
--------------------------------------------------------------------------------
1 | """An app that runs a Multimodal LLM task."""
2 |
3 | from __future__ import annotations
4 |
5 | from cornserve.app.base import AppRequest, AppResponse, AppConfig
6 | from cornserve.task.builtins.mllm import MLLMInput, MLLMTask, Modality
7 |
8 |
9 | class Request(AppRequest):
10 | """App request model.
11 |
12 | Attributes:
13 | prompt: The prompt to send to the LLM.
14 | multimodal_data: List of tuples (modality, data URL).
15 | "image", "video", etc. for modality.
16 | """
17 |
18 | prompt: str
19 | multimodal_data: list[tuple[str, str]] = []
20 |
21 |
22 | class Response(AppResponse):
23 | """App response model.
24 |
25 | Attributes:
26 | response: The response from the LLM.
27 | """
28 |
29 | response: str
30 |
31 |
32 | mllm = MLLMTask(
33 | # model_id="Qwen/Qwen2-VL-7B-Instruct",
34 | model_id="google/gemma-3-4b-it",
35 | modalities=[Modality.IMAGE],
36 | )
37 |
38 |
39 | class Config(AppConfig):
40 | """App configuration model."""
41 |
42 | tasks = {"mllm": mllm}
43 |
44 |
45 | async def serve(request: Request) -> Response:
46 | """Main serve function for the app."""
47 | mllm_input = MLLMInput(prompt=request.prompt, multimodal_data=request.multimodal_data)
48 | mllm_output = await mllm(mllm_input)
49 | return Response(response=mllm_output.response)
50 |
--------------------------------------------------------------------------------
/examples/mllm/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | cornserve
3 |
--------------------------------------------------------------------------------
/kubernetes/README.md:
--------------------------------------------------------------------------------
1 | # Kubernetes
2 |
3 | Please refer to our documentation for guides on [deploying](https://cornserve.ai/getting_started/cornserve/) and [developing](https://cornserve.ai/contributor_guide/kubernetes/) on Kubernetes.
4 |
--------------------------------------------------------------------------------
/kubernetes/k3s/agent-config.yaml:
--------------------------------------------------------------------------------
1 | default-runtime: "nvidia"
2 |
--------------------------------------------------------------------------------
/kubernetes/k3s/registries.yaml:
--------------------------------------------------------------------------------
1 | mirrors:
2 | "localhost:5000":
3 | endpoint:
4 | - "http://localhost:5000"
5 |
--------------------------------------------------------------------------------
/kubernetes/k3s/server-config.yaml:
--------------------------------------------------------------------------------
1 | write-kubeconfig-mode: "0644"
2 | default-runtime: "nvidia"
3 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/base/jaeger/configmap.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: ConfigMap
3 | metadata:
4 | name: jaeger-config
5 | namespace: cornserve-system
6 | data:
7 | config.yaml: |
8 | service:
9 | extensions: [jaeger_storage, jaeger_query, healthcheckv2]
10 | pipelines:
11 | traces:
12 | receivers: [otlp]
13 | processors: [batch]
14 | exporters: [jaeger_storage_exporter]
15 | telemetry:
16 | resource:
17 | service.name: jaeger
18 | metrics:
19 | level: detailed
20 | logs:
21 | level: info
22 |
23 | extensions:
24 | healthcheckv2:
25 | use_v2: true
26 | http:
27 |
28 | jaeger_query:
29 | max_clock_skew_adjust: 30s
30 | storage:
31 | traces: badger_store
32 |
33 | jaeger_storage:
34 | backends:
35 | badger_store:
36 | badger:
37 | directories:
38 | keys: "/badger/keys"
39 | values: "/badger/values"
40 | ephemeral: false
41 |
42 | receivers:
43 | otlp:
44 | protocols:
45 | grpc:
46 | endpoint: 0.0.0.0:4317
47 | http:
48 | endpoint: 0.0.0.0:4318
49 |
50 | processors:
51 | batch:
52 |
53 | exporters:
54 | jaeger_storage_exporter:
55 | trace_storage: badger_store
56 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/base/jaeger/deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: jaeger
5 | namespace: cornserve-system
6 | labels:
7 | app: jaeger
8 | spec:
9 | selector:
10 | matchLabels:
11 | app: jaeger
12 | replicas: 1
13 | template:
14 | metadata:
15 | labels:
16 | app: jaeger
17 | spec:
18 | nodeSelector:
19 | node-role.kubernetes.io/control-plane: "true"
20 | containers:
21 | - name: jaeger
22 | image: jaegertracing/jaeger:2.4.0
23 | imagePullPolicy: IfNotPresent
24 | securityContext:
25 | runAsUser: 0
26 | runAsGroup: 0
27 | args:
28 | - "--config=/config/config.yaml"
29 | ports:
30 | - containerPort: 16686
31 | name: query
32 | - containerPort: 4317
33 | name: otlp-grpc
34 | - containerPort: 4318
35 | name: otlp-http
36 | volumeMounts:
37 | - name: config-volume
38 | mountPath: /config
39 | resources:
40 | limits:
41 | cpu: 500m
42 | memory: 1Gi
43 | requests:
44 | cpu: 100m
45 | memory: 200Mi
46 | volumes:
47 | - name: config-volume
48 | configMap:
49 | name: jaeger-config
50 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/base/jaeger/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - configmap.yaml
5 | - deployment.yaml
6 | - service.yaml
7 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/base/jaeger/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: jaeger-collector
5 | namespace: cornserve-system
6 | labels:
7 | app: jaeger
8 | spec:
9 | selector:
10 | app: jaeger
11 | ports:
12 | - port: 4317
13 | targetPort: 4317
14 | name: otlp-grpc
15 | - port: 4318
16 | targetPort: 4318
17 | name: otlp-http
18 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/base/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - namespace.yaml
5 | - jaeger
6 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/base/namespace.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Namespace
3 | metadata:
4 | creationTimestamp: null
5 | name: cornserve-system
6 | spec: {}
7 | status: {}
8 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/dev/jaeger/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1alpha1
2 | kind: Component
3 | resources:
4 | - node-port-service.yaml
5 | patches:
6 | - path: volume-patch.yaml
7 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/dev/jaeger/node-port-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: jaeger-query
5 | namespace: cornserve-system
6 | labels:
7 | app: jaeger
8 | spec:
9 | selector:
10 | app: jaeger
11 | type: NodePort
12 | ports:
13 | - port: 16686
14 | targetPort: 16686
15 | nodePort: 30686
16 | name: query
17 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/dev/jaeger/volume-patch.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: jaeger
5 | namespace: cornserve-system
6 | spec:
7 | template:
8 | spec:
9 | containers:
10 | - name: jaeger
11 | volumeMounts:
12 | - name: host-storage
13 | mountPath: /badger
14 | volumes:
15 | - name: host-storage
16 | hostPath:
17 | path: /data/cornserve/jaeger-badger
18 | type: DirectoryOrCreate
19 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/dev/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - ../../base
5 | components:
6 | - jaeger
7 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/local/jaeger/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1alpha1
2 | kind: Component
3 | resources:
4 | - node-port-service.yaml
5 | patches:
6 | - path: volume-patch.yaml
7 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/local/jaeger/node-port-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: jaeger-query
5 | namespace: cornserve-system
6 | labels:
7 | app: jaeger
8 | spec:
9 | selector:
10 | app: jaeger
11 | type: NodePort
12 | ports:
13 | - port: 16686
14 | targetPort: 16686
15 | nodePort: 30686
16 | name: query
17 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/local/jaeger/volume-patch.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: jaeger
5 | namespace: cornserve-system
6 | spec:
7 | template:
8 | spec:
9 | containers:
10 | - name: jaeger
11 | volumeMounts:
12 | - name: host-storage
13 | mountPath: /badger
14 | volumes:
15 | - name: host-storage
16 | hostPath:
17 | path: /data/cornserve/jaeger-badger
18 | type: DirectoryOrCreate
19 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/local/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - ../../base
5 | components:
6 | - jaeger
7 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/minikube/jaeger/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1alpha1
2 | kind: Component
3 | resources:
4 | - node-port-service.yaml
5 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/minikube/jaeger/node-port-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: jaeger-query
5 | namespace: cornserve-system
6 | labels:
7 | app: jaeger
8 | spec:
9 | selector:
10 | app: jaeger
11 | type: NodePort
12 | ports:
13 | - port: 16686
14 | targetPort: 16686
15 | nodePort: 30686
16 | name: query
17 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve-system/overlays/minikube/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - ../../base
5 | patches:
6 | - target:
7 | kind: Deployment
8 | labelSelector: app=jaeger
9 | patch: |-
10 | - op: remove
11 | path: /spec/template/spec/nodeSelector
12 | components:
13 | - jaeger
14 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/gateway/deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: gateway
5 | namespace: cornserve
6 | labels:
7 | app: gateway
8 | spec:
9 | selector:
10 | matchLabels:
11 | app: gateway
12 | replicas: 1
13 | template:
14 | metadata:
15 | labels:
16 | app: gateway
17 | spec:
18 | nodeSelector:
19 | node-role.kubernetes.io/control-plane: "true"
20 | containers:
21 | - name: gateway
22 | image: cornserve/gateway:latest
23 | imagePullPolicy: IfNotPresent
24 | ports:
25 | - containerPort: 8000
26 | name: http
27 | envFrom:
28 | - configMapRef:
29 | name: cornserve-config
30 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/gateway/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - deployment.yaml
5 | - service.yaml
6 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/gateway/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: gateway
5 | namespace: cornserve
6 | spec:
7 | selector:
8 | app: gateway
9 | ports:
10 | - name: gateway
11 | port: 8000
12 | targetPort: http
13 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - gateway
5 | - resource-manager
6 | - sidecar
7 | - task-dispatcher
8 | - task-manager
9 | - namespace.yaml
10 | images:
11 | - name: cornserve/gateway
12 | newName: cornserve/gateway
13 | - name: cornserve/resource-manager
14 | newName: cornserve/resource-manager
15 | - name: cornserve/sidecar
16 | newName: cornserve/sidecar
17 | - name: cornserve/task-dispatcher
18 | newName: cornserve/task-dispatcher
19 | - name: cornserve/task-manager
20 | newName: cornserve/task-manager
21 | configMapGenerator:
22 | - name: cornserve-config
23 | namespace: cornserve
24 | literals:
25 | - CORNSERVE_IMAGE_PREFIX=docker.io/cornserve
26 | - CORNSERVE_IMAGE_TAG=latest
27 | - CORNSERVE_IMAGE_PULL_POLICY=IfNotPresent
28 | generatorOptions:
29 | disableNameSuffixHash: true
30 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/namespace.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Namespace
3 | metadata:
4 | creationTimestamp: null
5 | name: cornserve
6 | spec: {}
7 | status: {}
8 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/resource-manager/deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: resource-manager
5 | namespace: cornserve
6 | labels:
7 | app: resource-manager
8 | spec:
9 | selector:
10 | matchLabels:
11 | app: resource-manager
12 | replicas: 1
13 | template:
14 | metadata:
15 | labels:
16 | app: resource-manager
17 | spec:
18 | nodeSelector:
19 | node-role.kubernetes.io/control-plane: "true"
20 | serviceAccountName: resource-manager
21 | containers:
22 | - name: resource-manager
23 | image: cornserve/resource-manager:latest
24 | imagePullPolicy: IfNotPresent
25 | ports:
26 | - containerPort: 50051
27 | name: grpc
28 | envFrom:
29 | - configMapRef:
30 | name: cornserve-config
31 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/resource-manager/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - deployment.yaml
5 | - role-binding.yaml
6 | - role.yaml
7 | - service-account.yaml
8 | - service.yaml
9 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/resource-manager/role-binding.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: ClusterRoleBinding
3 | metadata:
4 | name: resource-manager-binding
5 | subjects:
6 | - kind: ServiceAccount
7 | name: resource-manager
8 | namespace: cornserve
9 | roleRef:
10 | kind: ClusterRole
11 | name: resource-manager-role
12 | apiGroup: rbac.authorization.k8s.io
13 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/resource-manager/role.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: ClusterRole
3 | metadata:
4 | name: resource-manager-role
5 | rules:
6 | - apiGroups: [""]
7 | resources: ["pods", "services"]
8 | verbs: ["*"]
9 | - apiGroups: [""]
10 | resources: ["nodes"]
11 | verbs: ["list"]
12 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/resource-manager/service-account.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: ServiceAccount
3 | metadata:
4 | name: resource-manager
5 | namespace: cornserve
6 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/resource-manager/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: resource-manager
5 | namespace: cornserve
6 | spec:
7 | selector:
8 | app: resource-manager
9 | ports:
10 | - name: resource-manager
11 | port: 50051
12 | targetPort: grpc
13 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/sidecar/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - service.yaml
5 | - role-binding.yaml
6 | - role.yaml
7 | - service-account.yaml
8 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/sidecar/role-binding.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: RoleBinding
3 | metadata:
4 | name: sidecar-binding
5 | namespace: cornserve
6 | subjects:
7 | - kind: ServiceAccount
8 | name: sidecar
9 | roleRef:
10 | kind: Role
11 | name: sidecar-role
12 | apiGroup: rbac.authorization.k8s.io
13 |
14 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/sidecar/role.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: Role
3 | metadata:
4 | name: sidecar-role
5 | namespace: cornserve
6 | rules:
7 | - apiGroups: [""]
8 | resources: ["pods"]
9 | verbs: ["list", "get"]
10 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/sidecar/service-account.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: ServiceAccount
3 | metadata:
4 | name: sidecar
5 | namespace: cornserve
6 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/sidecar/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: sidecar
5 | namespace: cornserve
6 | labels:
7 | app: sidecar
8 | spec:
9 | clusterIP: None
10 | selector:
11 | app: sidecar
12 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/sidecar/statefulset.yaml:
--------------------------------------------------------------------------------
1 | # This file is not used by Kustomize; sidecars are deployed by the Resource Manager.
2 | # We keep it here for reference. We might want to deploy only the sidecars, for instance.
3 | apiVersion: apps/v1
4 | kind: StatefulSet
5 | metadata:
6 | name: sidecar
7 | namespace: cornserve
8 | labels:
9 | app: sidecar
10 | spec:
11 | serviceName: sidecar
12 | selector:
13 | matchLabels:
14 | app: sidecar
15 | replicas: 4
16 | template:
17 | metadata:
18 | labels:
19 | app: sidecar
20 | spec:
21 | hostPID: true
22 | hostIPC: true
23 | topologySpreadConstraints:
24 | - maxSkew: 1
25 | topologyKey: "kubernetes.io/hostname"
26 | whenUnsatisfiable: DoNotSchedule
27 | labelSelector:
28 | matchLabels:
29 | app: sidecar
30 | runtimeClassName: nvidia
31 | serviceAccountName: sidecar
32 | containers:
33 | - name: sidecar
34 | image: cornserve/sidecar:latest
35 | imagePullPolicy: IfNotPresent
36 | securityContext:
37 | privileged: true
38 | env:
39 | - name: SIDECAR_WORLD_SIZE
40 | value: "4"
41 | - name: SIDECAR_POD_NAME
42 | valueFrom:
43 | fieldRef:
44 | fieldPath: metadata.name
45 | envFrom:
46 | - configMapRef:
47 | name: cornserve-config
48 | volumeMounts:
49 | - name: shm-volume
50 | mountPath: /dev/shm
51 | - name: infiniband-class
52 | mountPath: /sys/class/infiniband
53 | - name: infiniband-dev
54 | mountPath: /dev/infiniband
55 | volumes:
56 | - name: shm-volume
57 | hostPath:
58 | path: /dev/shm
59 | type: Directory
60 | - name: infiniband-class
61 | hostPath:
62 | path: /sys/class/infiniband
63 | type: Directory
64 | - name: infiniband-dev
65 | hostPath:
66 | path: /dev/infiniband
67 | type: Directory
68 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/task-dispatcher/deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: task-dispatcher
5 | namespace: cornserve
6 | labels:
7 | app: task-dispatcher
8 | spec:
9 | selector:
10 | matchLabels:
11 | app: task-dispatcher
12 | replicas: 1
13 | template:
14 | metadata:
15 | labels:
16 | app: task-dispatcher
17 | spec:
18 | nodeSelector:
19 | node-role.kubernetes.io/control-plane: "true"
20 | containers:
21 | - name: task-dispatcher
22 | image: cornserve/task-dispatcher:latest
23 | imagePullPolicy: IfNotPresent
24 | ports:
25 | - containerPort: 50051
26 | name: grpc
27 | - containerPort: 8000
28 | name: http
29 | envFrom:
30 | - configMapRef:
31 | name: cornserve-config
32 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/task-dispatcher/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - deployment.yaml
5 | - service.yaml
6 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/task-dispatcher/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: task-dispatcher
5 | namespace: cornserve
6 | spec:
7 | selector:
8 | app: task-dispatcher
9 | ports:
10 | - name: grpc
11 | port: 50051
12 | targetPort: grpc
13 | - name: http
14 | port: 8000
15 | targetPort: http
16 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/task-manager/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - role-binding.yaml
5 | - role.yaml
6 | - service-account.yaml
7 |
8 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/task-manager/role-binding.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: RoleBinding
3 | metadata:
4 | name: task-manager-binding
5 | namespace: cornserve
6 | subjects:
7 | - kind: ServiceAccount
8 | name: task-manager
9 | roleRef:
10 | kind: Role
11 | name: task-manager-role
12 | apiGroup: rbac.authorization.k8s.io
13 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/task-manager/role.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: Role
3 | metadata:
4 | name: task-manager-role
5 | namespace: cornserve
6 | rules:
7 | - apiGroups: [""]
8 | resources: ["pods", "services"]
9 | verbs: ["*"]
10 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/base/task-manager/service-account.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: ServiceAccount
3 | metadata:
4 | name: task-manager
5 | namespace: cornserve
6 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/dev/gateway/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - node-port-service.yaml
5 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/dev/gateway/node-port-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: gateway-node-port
5 | namespace: cornserve
6 | spec:
7 | type: NodePort
8 | selector:
9 | app: gateway
10 | ports:
11 | - name: gateway
12 | port: 8000
13 | targetPort: http
14 | nodePort: 30080
15 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/dev/image-pull-policy.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: gateway
5 | namespace: cornserve
6 | spec:
7 | template:
8 | spec:
9 | containers:
10 | - name: gateway
11 | imagePullPolicy: Always
12 | ---
13 | apiVersion: apps/v1
14 | kind: Deployment
15 | metadata:
16 | name: resource-manager
17 | namespace: cornserve
18 | spec:
19 | template:
20 | spec:
21 | containers:
22 | - name: resource-manager
23 | imagePullPolicy: Always
24 | ---
25 | apiVersion: apps/v1
26 | kind: Deployment
27 | metadata:
28 | name: task-dispatcher
29 | namespace: cornserve
30 | spec:
31 | template:
32 | spec:
33 | containers:
34 | - name: task-dispatcher
35 | imagePullPolicy: Always
36 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/dev/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - ../../base
5 | - gateway
6 | patches:
7 | - path: image-pull-policy.yaml
8 | images:
9 | - name: cornserve/gateway
10 | newName: localhost:5000/cornserve/gateway
11 | newTag: latest
12 | - name: cornserve/resource-manager
13 | newName: localhost:5000/cornserve/resource-manager
14 | newTag: latest
15 | - name: cornserve/sidecar
16 | newName: localhost:5000/cornserve/sidecar
17 | newTag: latest
18 | - name: cornserve/task-dispatcher
19 | newName: localhost:5000/cornserve/task-dispatcher
20 | newTag: latest
21 | - name: cornserve/task-manager
22 | newName: localhost:5000/cornserve/task-manager
23 | newTag: latest
24 | configMapGenerator:
25 | - name: cornserve-config
26 | namespace: cornserve
27 | behavior: merge
28 | literals:
29 | - CORNSERVE_IMAGE_PREFIX=localhost:5000/cornserve
30 | - CORNSERVE_IMAGE_PULL_POLICY=Always
31 | - CORNSERVE_IMAGE_TAG=latest
32 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/local/gateway/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - node-port-service.yaml
5 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/local/gateway/node-port-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: gateway-node-port
5 | namespace: cornserve
6 | spec:
7 | type: NodePort
8 | selector:
9 | app: gateway
10 | ports:
11 | - name: gateway
12 | port: 8000
13 | targetPort: http
14 | nodePort: 30080
15 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/local/image-pull-policy.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: gateway
5 | namespace: cornserve
6 | spec:
7 | template:
8 | spec:
9 | containers:
10 | - name: gateway
11 | imagePullPolicy: Never
12 | ---
13 | apiVersion: apps/v1
14 | kind: Deployment
15 | metadata:
16 | name: resource-manager
17 | namespace: cornserve
18 | spec:
19 | template:
20 | spec:
21 | containers:
22 | - name: resource-manager
23 | imagePullPolicy: Never
24 | ---
25 | apiVersion: apps/v1
26 | kind: Deployment
27 | metadata:
28 | name: task-dispatcher
29 | namespace: cornserve
30 | spec:
31 | template:
32 | spec:
33 | containers:
34 | - name: task-dispatcher
35 | imagePullPolicy: Never
36 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/local/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - ../../base
5 | - gateway
6 | patches:
7 | - path: image-pull-policy.yaml
8 | images:
9 | - name: cornserve/gateway
10 | newName: docker.io/cornserve/gateway
11 | newTag: latest
12 | - name: cornserve/resource-manager
13 | newName: docker.io/cornserve/resource-manager
14 | newTag: latest
15 | - name: cornserve/sidecar
16 | newName: docker.io/cornserve/sidecar
17 | newTag: latest
18 | - name: cornserve/task-dispatcher
19 | newName: docker.io/cornserve/task-dispatcher
20 | newTag: latest
21 | - name: cornserve/task-manager
22 | newName: docker.io/cornserve/task-manager
23 | newTag: latest
24 | configMapGenerator:
25 | - name: cornserve-config
26 | namespace: cornserve
27 | behavior: merge
28 | literals:
29 | - CORNSERVE_IMAGE_PREFIX=docker.io/cornserve
30 | - CORNSERVE_IMAGE_PULL_POLICY=Never
31 | - CORNSERVE_IMAGE_TAG=latest
32 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/minikube/gateway/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - node-port-service.yaml
5 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/minikube/gateway/node-port-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: gateway-node-port
5 | namespace: cornserve
6 | spec:
7 | type: NodePort
8 | selector:
9 | app: gateway
10 | ports:
11 | - name: gateway
12 | port: 8000
13 | targetPort: http
14 | nodePort: 30080
15 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/minikube/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - ../../base
5 | - gateway
6 | patches:
7 | - target:
8 | kind: Deployment
9 | patch: |-
10 | - op: remove
11 | path: /spec/template/spec/nodeSelector
12 | images:
13 | - name: cornserve/gateway
14 | newName: cornserve/gateway
15 | newTag: v0.0.1.post2
16 | - name: cornserve/resource-manager
17 | newName: cornserve/resource-manager
18 | newTag: v0.0.1.post2
19 | - name: cornserve/sidecar
20 | newName: cornserve/sidecar
21 | newTag: v0.0.1.post2
22 | - name: cornserve/task-dispatcher
23 | newName: cornserve/task-dispatcher
24 | newTag: v0.0.1.post2
25 | - name: cornserve/task-manager
26 | newName: cornserve/task-manager
27 | newTag: v0.0.1.post2
28 | configMapGenerator:
29 | - name: cornserve-config
30 | namespace: cornserve
31 | behavior: merge
32 | literals:
33 | - CORNSERVE_IMAGE_PREFIX=cornserve
34 | - CORNSERVE_IMAGE_PULL_POLICY=IfNotPresent
35 | - CORNSERVE_IMAGE_TAG=v0.0.1.post2
36 |
--------------------------------------------------------------------------------
/kubernetes/kustomize/cornserve/overlays/prod/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
resources:
  - ../../base
10 | images:
11 | - name: cornserve/gateway
12 | newName: docker.io/cornserve/gateway
13 | newTag: v0.0.1.post2
14 | - name: cornserve/resource-manager
15 | newName: docker.io/cornserve/resource-manager
16 | newTag: v0.0.1.post2
17 | - name: cornserve/sidecar
18 | newName: docker.io/cornserve/sidecar
19 | newTag: v0.0.1.post2
20 | - name: cornserve/task-dispatcher
21 | newName: docker.io/cornserve/task-dispatcher
22 | newTag: v0.0.1.post2
23 | - name: cornserve/task-manager
24 | newName: docker.io/cornserve/task-manager
25 | newTag: v0.0.1.post2
26 | configMapGenerator:
27 | - name: cornserve-config
28 | namespace: cornserve
29 | behavior: merge
30 | literals:
31 | - CORNSERVE_IMAGE_PREFIX=docker.io/cornserve
32 | - CORNSERVE_IMAGE_PULL_POLICY=IfNotPresent
33 | - CORNSERVE_IMAGE_TAG=v0.0.1.post2
34 |
--------------------------------------------------------------------------------
/kubernetes/registry.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Launch a local Docker image registry container for Cornserve development.
# Adjust the host volume path below to match your machine's storage layout.

set -evo pipefail

docker run \
  --detach \
  --publish 5000:5000 \
  --env REGISTRY_STORAGE_DELETE_ENABLED=true \
  --restart=always \
  --name cornserve-registry \
  --volume /data/cornserve/registry:/var/lib/registry \
  registry:2
15 |
--------------------------------------------------------------------------------
/kubernetes/set_registry.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# NOTE: shebang fixed — `env` is located at /usr/bin/env, not /usr/local/env,
# so the previous shebang made the script fail when executed directly.

# Usage:
#
#   REGISTRY=<host:port> bash kubernetes/set_registry.sh
#
# Update the private registry URL in the cornserve dev overlay kustomization file
# and the k3s registries file (replacing the default "localhost:5000").
# Override `KUSTOMIZATION_FILE` / `K3S_REGISTRIES_FILE` to change the file paths.

set -evo pipefail

# Ensure the environment variable REGISTRY is set
if [[ -z "${REGISTRY}" ]]; then
  echo "Error: REGISTRY environment variable is not set."
  exit 1
fi

KUSTOMIZATION_FILE=${KUSTOMIZATION_FILE:-"kubernetes/kustomize/cornserve/overlays/dev/kustomization.yaml"}
if [[ ! -f "$KUSTOMIZATION_FILE" ]]; then
  echo "Error: Kustomization file '$KUSTOMIZATION_FILE' does not exist."
  exit 1
fi

K3S_REGISTRIES_FILE=${K3S_REGISTRIES_FILE:-"kubernetes/k3s/registries.yaml"}
if [[ ! -f "$K3S_REGISTRIES_FILE" ]]; then
  echo "Error: K3s registries file '$K3S_REGISTRIES_FILE' does not exist."
  exit 1
fi

echo "Editing registry in $KUSTOMIZATION_FILE and $K3S_REGISTRIES_FILE to $REGISTRY"

# The `.bak` suffix keeps `sed -i` portable across GNU and BSD/macOS sed;
# the backup file is removed right after the in-place edit.
sed -i.bak -e "s#localhost:5000#$REGISTRY#g" "$KUSTOMIZATION_FILE"
rm "$KUSTOMIZATION_FILE.bak"

sed -i.bak -e "s#localhost:5000#$REGISTRY#g" "$K3S_REGISTRIES_FILE"
rm "$K3S_REGISTRIES_FILE.bak"
37 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | # Project information
2 | site_name: Cornserve
3 | site_url: https://cornserve.ai
4 | site_author: Cornserve team
5 | site_description: Easy, fast, and scalable multimodal agentic AI
6 | edit_uri: ""
7 |
8 | # Repository
9 | repo_name: cornserve-ai/cornserve
10 | repo_url: https://github.com/cornserve-ai/cornserve
11 |
12 | # Copyright
13 | copyright: Copyright © 2025 Cornserve team
14 |
15 | # Theme configuration
16 | theme:
17 | name: material
18 | favicon: assets/img/favicon.png
19 | icon:
20 | repo: fontawesome/brands/github
21 | logo: material/rocket-launch-outline
22 | features:
23 | - content.code.copy
24 | - content.code.annotate
25 | - search.suggest
26 | - navigation.tabs
27 | - navigation.tabs.sticky
28 | - navigation.top
29 | - navigation.indexes
30 | - content.tooltips
31 | - announce.dismiss
32 | palette:
33 | - scheme: light
34 | primary: black
35 | accent: amber
36 |
37 | # MkDocs plugins
38 | plugins:
39 | - search
40 | - autorefs
41 | - social:
42 | enabled: !ENV [BUILD_SOCIAL_CARD, false]
43 | cards_dir: assets/img/social
44 | - mkdocs-video:
45 | is_video: True
46 | video_autoplay: True
47 | css_style:
48 | width: 60%
49 | - mkdocs-jupyter
50 |
51 | # Extensions
52 | markdown_extensions:
53 | - meta
54 | - abbr
55 | - admonition
56 | - attr_list
57 | - footnotes
58 | - md_in_html
  - pymdownx.snippets
61 | - pymdownx.details
62 | - pymdownx.critic
63 | - pymdownx.arithmatex:
64 | generic: true
65 | - pymdownx.emoji:
66 | emoji_index: !!python/name:material.extensions.emoji.twemoji
67 | emoji_generator: !!python/name:material.extensions.emoji.to_svg
68 | - pymdownx.superfences:
69 | custom_fences:
70 | - name: mermaid
71 | class: mermaid
72 | format: !!python/name:pymdownx.superfences.fence_code_format
73 | - pymdownx.highlight
74 | - pymdownx.inlinehilite
75 |
76 | # Page tree
77 | nav:
78 | - Cornserve: index.md
79 | - Getting Started:
80 | - getting_started/index.md
81 | - Deploying Cornserve: getting_started/cornserve.md
82 | - Building Apps: getting_started/building_apps.md
83 | - Using Jupyter Notebook: getting_started/jupyter.ipynb
84 | - Registering and Invoking Apps: getting_started/registering_apps.md
85 | - Architecture:
86 | - architecture/index.md
87 | - Task: architecture/task.md
88 | - Sidecar: architecture/sidecar.md
89 | - Eric: architecture/eric.md
90 | - Contributor Guide:
91 | - contributor_guide/index.md
92 | - Developing on Kubernetes: contributor_guide/kubernetes.md
93 | - Eric: contributor_guide/eric.md
94 | - Sidecar: contributor_guide/sidecar.md
95 | - Tracing: contributor_guide/tracing.md
96 |
97 | # Exclude file list
98 | exclude_docs: |
99 | requirements.txt
100 |
101 | # For Mathjax
102 | extra_javascript:
103 | - assets/js/mathjax.js
  - https://cdnjs.cloudflare.com/polyfill/v3/polyfill.min.js?features=es6
105 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
106 |
107 | # Extra stuff
108 | extra:
109 | analytics:
110 | provider: !ENV SITE_ANALYTICS
111 | property: G-8YY3G9ZZW5
112 | social:
113 | - name: Cornserve GitHub repository
114 | icon: fontawesome/brands/github
115 | link: https://github.com/cornserve-ai/cornserve
116 |
117 | extra_css:
118 | - assets/css/extra.css
119 |
--------------------------------------------------------------------------------
/proto/v1/common.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

// Shared message types used across Cornserve gRPC services.
package cornserve.common;

// Whether something was successful or not
enum Status {
  STATUS_UNSPECIFIED = 0;
  STATUS_OK = 1;
  STATUS_ERROR = 2;
}

// Concrete task instantiated from a unit task class.
message UnitTask {
  // UnitTask Python class name.
  string task_class_name = 1;

  // JSON-serialized task object.
  string task_config = 2;
}
20 |
--------------------------------------------------------------------------------
/proto/v1/resource_manager.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package cornserve.resource_manager;

import "common.proto";

// Control-plane service that deploys and tears down unit tasks.
service ResourceManager {
  // Deploy a new unit task
  rpc DeployUnitTask(DeployUnitTaskRequest) returns (DeployUnitTaskResponse);

  // Tear down a unit task
  rpc TeardownUnitTask(TeardownUnitTaskRequest) returns (TeardownUnitTaskResponse);

  // Health checking
  rpc Healthcheck(HealthcheckRequest) returns (HealthcheckResponse);
}

// Deploy a new task
message DeployUnitTaskRequest {
  // Task to deploy
  common.UnitTask task = 1;
}

message DeployUnitTaskResponse {
  // Success or failure
  common.Status status = 1;
}

// Tear down a task
message TeardownUnitTaskRequest {
  // Task to tear down
  common.UnitTask task = 1;
}

message TeardownUnitTaskResponse {
  // Success or failure
  common.Status status = 1;
}

// Health checking
// Status of the task manager serving one unit task.
message TaskManagerStatus {
  // The unit task being served.
  common.UnitTask task = 1;
  // Health status of that task's manager.
  common.Status status = 2;
}

message HealthcheckRequest {}

message HealthcheckResponse {
  // Overall resource manager health.
  common.Status status = 1;
  // One entry per managed task manager.
  repeated TaskManagerStatus task_manager_statuses = 2;
}
52 |
--------------------------------------------------------------------------------
/proto/v1/sidecar.proto:
--------------------------------------------------------------------------------
syntax = "proto3";

package cornserve.sidecar;

import "common.proto";

// Data-plane RPCs served by each sidecar process. A worker first calls
// Register, then exchanges serialized data chunks via Send/Receive.
service Sidecar {
  // Register the calling rank and set up its shared-memory slots.
  rpc Register(RegisterRequest) returns (RegisterResponse);
  // Send one chunk of a serialized object to destination rank groups.
  rpc Send(SendRequest) returns (SendResponse);
  // Fetch one chunk of an object by transfer ID.
  rpc Receive(ReceiveRequest) returns (ReceiveResponse);
  // Mark a chunk as done.
  rpc MarkDone(MarkDoneRequest) returns (MarkDoneResponse);
  // Unlink a chunk by transfer ID and chunk index.
  rpc Unlink(UnlinkRequest) returns (UnlinkResponse);
  // Announce an incoming chunk to the receiver, carrying its transfer handle.
  rpc PrepareReceive(PrepareReceiveRequest) returns (PrepareReceiveResponse);

  rpc CheckHealth(CheckHealthRequest) returns (CheckHealthResponse);
  // TODO: add unregister
}

// Registration info for one rank.
message RegisterRequest {
  // Rank of the registering worker.
  int32 rank = 1;
  // All ranks in the worker's communication group.
  repeated int32 group = 2;
  // Element dtype, as a string. NOTE(review): confirm expected format.
  string dtype = 3;
  // Number of elements in the send slot.
  int32 send_slot_numel = 4;
  // Number of elements in the receive slot.
  int32 recv_slot_numel = 5;
  // Whether to enable concurrent copies. NOTE(review): inferred from name.
  bool concurrent_copy = 6;
}

// Registration result plus local placement info for the worker.
message RegisterResponse {
  common.Status status = 1;
  int64 shm_size = 2; // numel in the single sender/receiver slab
  int32 local_rank = 3; // the GPU index to use
  int32 num_local_sidecars = 4; // used for init_shmem
}

// A set of ranks addressed together as one destination group.
message RankGroup {
  repeated int32 ranks = 1;
}

// One chunk of a serialized object to forward.
message SendRequest {
  // Transfer ID shared by all chunks of the same object.
  string id = 1;
  // Destination rank groups to deliver this chunk to.
  repeated RankGroup dst_ranks = 2;
  int32 shard_rank = 3; // tp rank
  bytes data = 4; // serialized obj
  // Index of this chunk and the total chunk count of the transfer.
  int32 chunk_id = 5;
  int32 num_chunks = 6;
}

message SendResponse {
  common.Status status = 1;
}

// Identifies one chunk of a transfer to fetch.
message ReceiveRequest {
  string id = 1;
  int32 chunk_id = 2;
}

message ReceiveResponse {
  common.Status status = 1;
  // Serialized chunk payload.
  bytes data = 2;
}

// Identifies one chunk to mark as done.
message MarkDoneRequest {
  string id = 1;
  int32 chunk_id = 2;
  int32 shard_rank = 3; // tp rank
}

message MarkDoneResponse {
  common.Status status = 1;
}

// Identifies one chunk to unlink.
message UnlinkRequest {
  string id = 1;
  int32 chunk_id = 2;
}

message UnlinkResponse {
  common.Status status = 1;
}


// Announces an incoming chunk to the receiving sidecar.
message PrepareReceiveRequest {
  string id = 1;
  bytes data = 2; // msgpack encoded handle
  // Rank the chunk originates from.
  int32 src_rank = 3;
  int32 chunk_id = 4;
  int32 num_chunks = 5;
}

message PrepareReceiveResponse {
  common.Status status = 1;
}

// Sidecar health as reported by CheckHealth.
enum HealthStatus {
  HEALTH_ALL_GOOD = 0;
  HEALTH_MEMORY_PRESSURE = 1;
  // This is for a revived or uninitialized sidecar.
  HEALTH_OFFLINE = 2;
}

message CheckHealthRequest {
}

message CheckHealthResponse {
  HealthStatus status = 1;
}

// NOTE(review): ReportMemory messages are not referenced by any rpc in this
// file; presumably used elsewhere or reserved for future use — confirm.
message ReportMemoryRequest {
  // Memory pressure level. NOTE(review): units/scale not shown here.
  int32 pressure = 1;
}

message ReportMemoryResponse {
  common.Status status = 1;
}
--------------------------------------------------------------------------------
/proto/v1/task_dispatcher.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package cornserve.task_dispatcher;
4 |
5 | import "common.proto";
6 |
7 | service TaskDispatcher {
8 | // New unit task deployed
9 | rpc NotifyUnitTaskDeployment(NotifyUnitTaskDeploymentRequest) returns (NotifyUnitTaskDeploymentResponse);
10 |
11 | // Existing unit task removed
12 | rpc NotifyUnitTaskTeardown(NotifyUnitTaskTeardownRequest) returns (NotifyUnitTaskTeardownResponse);
13 | }
14 |
15 | // New unit task deployed
16 | message TaskManagerDeployment {
17 | // Task manager URL
18 | string url = 1;
19 | }
20 |
21 | message NotifyUnitTaskDeploymentRequest {
22 | // Task that was deployed
23 | common.UnitTask task = 1;
24 |
25 | // Task manager deployment info
26 | TaskManagerDeployment task_manager = 2;
27 | }
28 |
29 | message NotifyUnitTaskDeploymentResponse {
30 | common.Status status = 1;
31 | }
32 |
33 | // Existing unit task removed
34 | message NotifyUnitTaskTeardownRequest {
35 | common.UnitTask task = 1;
36 | }
37 |
38 | message NotifyUnitTaskTeardownResponse {
39 | common.Status status = 1;
40 | }
41 |
--------------------------------------------------------------------------------
/proto/v1/task_manager.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package cornserve.task_manager;
4 |
5 | import "common.proto";
6 |
7 | service TaskManager {
8 | // Configure the task manager to handle a task
9 | rpc RegisterTask(RegisterTaskRequest) returns (RegisterTaskResponse);
10 |
11 | // Add or remove resources from a task manager
12 | rpc UpdateResources(UpdateResourcesRequest) returns (UpdateResourcesResponse);
13 |
14 | // Shutdown the task manager
15 | rpc Shutdown(ShutdownRequest) returns (ShutdownResponse);
16 |
17 | // Load management
18 | rpc GetTaskProfile(GetTaskProfileRequest) returns (GetTaskProfileResponse);
19 | rpc ReconcileTargetLoad(ReconcileTargetLoadRequest) returns (ReconcileTargetLoadResponse);
20 |
21 | // Request routing
22 | rpc GetRoute(GetRouteRequest) returns (GetRouteResponse);
23 |
24 | // Health checking
25 | rpc Healthcheck(HealthcheckRequest) returns (HealthcheckResponse);
26 | }
27 |
28 | enum ResourceAction {
29 | // Give more resources
30 | ADD = 0;
31 |
32 | // Take away resources
33 | REMOVE = 1;
34 | }
35 |
36 | message GPUResource {
37 | // Whether to add or remove this resource
38 | ResourceAction action = 1;
39 |
40 | // Node ID of the GPU
41 | string node_id = 2;
42 |
43 | // Global rank of the GPU
44 | int32 global_rank = 3;
45 |
46 | // Local rank of the GPU
47 | int32 local_rank = 4;
48 | }
49 |
50 | message RegisterTaskRequest {
51 | // ID of the task manager
52 | string task_manager_id = 1;
53 |
54 | // Unit task instance
55 | common.UnitTask task = 2;
56 |
57 | // Initial set of GPU resources
58 | repeated GPUResource gpus = 3;
59 | }
60 |
61 | message RegisterTaskResponse {
62 | common.Status status = 1;
63 | }
64 |
65 | // Update resources
66 | message UpdateResourcesRequest {
67 | // ID of the task manager
68 | string task_manager_id = 1;
69 |
70 | // Resources to add or remove
71 | repeated GPUResource gpus = 2;
72 | }
73 |
74 | message UpdateResourcesResponse {
75 | common.Status status = 1;
76 | }
77 |
78 | // Shutdown
79 | message ShutdownRequest {}
80 |
81 | message ShutdownResponse {
82 | common.Status status = 1;
83 | }
84 |
85 | // Load management
86 | message ReconcileTargetLoadRequest {
87 | string task_id = 1;
88 | float target_load = 2;
89 | }
90 |
91 | message ReconcileTargetLoadResponse {
92 | common.Status status = 1;
93 | string message = 2;
94 | }
95 |
96 | // Task profiling
97 | message ProfilePoint {
98 | int32 num_gpus = 1;
99 | float max_sustainable_load = 2;
100 | DeploymentConfig deployment_config = 3;
101 | }
102 |
103 | message DeploymentConfig {
104 | int32 num_replicas = 1;
105 | int32 tensor_parallel_degree = 2;
106 | int32 pipeline_parallel_degree = 3;
107 | repeated string gpu_assignments = 4;
108 | }
109 |
110 | message GetTaskProfileRequest {
111 | string task_id = 1;
112 | }
113 |
114 | message GetTaskProfileResponse {
115 | repeated ProfilePoint profile_points = 1;
116 | }
117 |
118 | // Request routing
119 | message GetRouteRequest {
120 | // ID of the request
121 | string request_id = 1;
122 |
123 | // Optional routing hint
124 | // e.g., hash of image URL, system prompt
125 | optional string routing_hint = 2;
126 | }
127 |
128 | message GetRouteResponse {
129 | // URL of the task executor to route the request to
130 | string task_executor_url = 1;
131 |
132 | // Sidecar ranks the task executor is registered with
133 | repeated int32 sidecar_ranks = 2;
134 | }
135 |
136 | // Healthcheck response
137 | message TaskExecutorStatus {
138 | common.Status status = 1;
139 | repeated int32 sidecar_ranks = 2;
140 | }
141 |
142 | message HealthcheckRequest {}
143 |
message HealthcheckResponse {
    common.Status status = 1;

    // Per-executor health, keyed by task executor identifier.
    // NOTE(review): the previous line (`map task_executor_statuses = 2;`) was
    // missing the map's type parameters and is not valid proto3. Key/value
    // types restored from the TaskExecutorStatus message above; confirm the
    // intended key (executor ID vs. URL).
    map<string, TaskExecutorStatus> task_executor_statuses = 2;
}
148 |
--------------------------------------------------------------------------------
/python/cornserve/__init__.py:
--------------------------------------------------------------------------------
"""Easy, fast, and scalable multimodal agentic AI."""

# Single source of truth for the package version string.
__version__ = "0.0.1.post2"
4 |
--------------------------------------------------------------------------------
/python/cornserve/app/__init__.py:
--------------------------------------------------------------------------------
1 | """Cornserve application package."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/app/base.py:
--------------------------------------------------------------------------------
1 | """Base classes for cornserve applications."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import ClassVar
6 |
7 | from pydantic import BaseModel, ConfigDict, Field
8 |
9 | from cornserve.task.base import Task
10 |
11 |
class AppRequest(BaseModel):
    """Base class for application requests.

    All user-defined request classes must inherit from this.
    Subclasses declare their request schema as pydantic fields.
    """
17 |
18 |
class AppResponse(BaseModel):
    """Base class for application responses.

    All user-defined response classes must inherit from this.
    Subclasses declare their response schema as pydantic fields.
    """
24 |
25 |
class AppConfig(BaseModel):
    """Base class for application configuration.

    All user-defined config classes must inherit from this.

    Attributes:
        tasks: Dictionary of tasks that the app requires, overridden at the
            class level by subclasses.
    """

    # Pydantic ignores ClassVar-annotated attributes when building model
    # fields, so the previous `Field(default_factory=dict, ...)` default was
    # never processed -- the class attribute was literally a `FieldInfo`
    # object rather than a dict. A plain empty dict is the correct default.
    tasks: ClassVar[dict[str, Task]] = {}

    model_config = ConfigDict(extra="forbid")
38 |
--------------------------------------------------------------------------------
/python/cornserve/constants.py:
--------------------------------------------------------------------------------
1 | """Constants used throughout Cornserve.
2 |
3 | Environment variables expected:
4 | - `CORNSERVE_IMAGE_PREFIX`: Docker image prefix (default: "docker.io/cornserve")
5 | - `CORNSERVE_IMAGE_TAG`: Docker image tag (default: "latest")
6 | - `CORNSERVE_IMAGE_PULL_POLICY`: Docker image pull policy (default: "IfNotPresent")
7 |
8 | These environment variables are set by different Kustomize overlays depending on
9 | the deployment context (e.g., local, dev, prod).
10 | """
11 |
12 | import os
13 | import warnings
14 | from typing import TYPE_CHECKING, Any
15 |
16 |
17 | def _get_env_warn_default(var_name: str, default: str) -> str:
18 | """Get environment variable with a warning if not set, returning a default value."""
19 | try:
20 | return os.environ[var_name]
21 | except KeyError:
22 | warnings.warn(
23 | f"Environment variable {var_name} not set, using default '{default}'.",
24 | stacklevel=2,
25 | )
26 | return default
27 |
28 |
def _build_image_name(name: str) -> str:
    """Builds a full image name with prefix, tag, and pull policy."""
    prefix = _get_env_warn_default("CORNSERVE_IMAGE_PREFIX", "docker.io/cornserve")
    tag = _get_env_warn_default("CORNSERVE_IMAGE_TAG", "latest")
    # Trailing/leading slashes on the prefix would produce a malformed name.
    return "/".join([prefix.strip("/"), f"{name}:{tag}"])
34 |
35 |
# Cache for lazy-loaded constants.
# Maps constant name -> computed value; filled on first access by __getattr__.
_lazy_cache: dict[str, Any] = {}

# Define which constants should be lazily loaded.
# Each entry maps an exported constant name to a zero-argument factory; the
# lambdas defer reading environment variables until first attribute access
# instead of at import time.
_LAZY_CONSTANTS = {
    "CONTAINER_IMAGE_TASK_MANAGER": lambda: _build_image_name("task-manager"),
    "CONTAINER_IMAGE_SIDECAR": lambda: _build_image_name("sidecar"),
    "CONTAINER_IMAGE_ERIC": lambda: _build_image_name("eric"),
    "CONTAINER_IMAGE_VLLM": lambda: _build_image_name("vllm"),
    "CONTAINER_IMAGE_PULL_POLICY": lambda: _get_env_warn_default("CORNSERVE_IMAGE_PULL_POLICY", "IfNotPresent"),
}
47 |
48 |
def __getattr__(name: str) -> Any:
    """Module-level __getattr__ for lazy loading of image-related constants."""
    factory = _LAZY_CONSTANTS.get(name)
    if factory is None:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
    # Compute once, then serve from the cache on subsequent accesses.
    if name not in _lazy_cache:
        _lazy_cache[name] = factory()
    return _lazy_cache[name]
56 |
57 |
# Kubernetes resources.
K8S_NAMESPACE = "cornserve"
K8S_CORNSERVE_CONFIG_MAP_NAME = "cornserve-config"
K8S_SIDECAR_SERVICE_NAME = "sidecar"
K8S_GATEWAY_SERVICE_HTTP_URL = "http://gateway:8000"
K8S_TASK_DISPATCHER_HTTP_URL = "http://task-dispatcher:8000"
K8S_TASK_DISPATCHER_GRPC_URL = "task-dispatcher:50051"
K8S_RESOURCE_MANAGER_GRPC_URL = "resource-manager:50051"
K8S_OTEL_GRPC_URL = "http://jaeger-collector.cornserve-system.svc.cluster.local:4317"
K8S_TASK_EXECUTOR_SECRET_NAME = "cornserve-env"
K8S_TASK_EXECUTOR_HF_TOKEN_KEY = "hf-token"
# 20 minutes; presumably seconds -- confirm unit at the call site.
K8S_TASK_EXECUTOR_HEALTHY_TIMEOUT = 20 * 60.0

# Volume host paths.
VOLUME_HF_CACHE = "/data/hfcache"
VOLUME_SHM = "/dev/shm"

# Container images name construction.
# These names are materialized lazily via the module-level __getattr__ above;
# the annotations below exist only so static type checkers see the attributes.
if TYPE_CHECKING:
    CONTAINER_IMAGE_TASK_MANAGER: str
    CONTAINER_IMAGE_SIDECAR: str
    CONTAINER_IMAGE_ERIC: str
    CONTAINER_IMAGE_VLLM: str
    CONTAINER_IMAGE_PULL_POLICY: str
82 |
--------------------------------------------------------------------------------
/python/cornserve/logging.py:
--------------------------------------------------------------------------------
1 | """Logging utilities for the Cornserve project."""
2 |
3 | from __future__ import annotations
4 |
5 | import logging
6 | import os
7 | import sys
8 | from collections.abc import MutableMapping
9 | from typing import Any
10 |
11 |
12 | def get_logger(
13 | name: str, adapters: list[type[logging.LoggerAdapter]] | None = None
14 | ) -> logging.Logger | logging.LoggerAdapter:
15 | """Get a logger with the given name with some formatting configs."""
16 | # No need to reconfigure the logger if it was already created
17 | if name in logging.Logger.manager.loggerDict:
18 | return logging.getLogger(name)
19 |
20 | logger = logging.getLogger(name)
21 | logger.setLevel(os.environ.get("CORNSERVE_LOG_LEVEL", logging.INFO))
22 | formatter = logging.Formatter("%(levelname)s %(asctime)s [%(name)s:%(lineno)d] %(message)s")
23 | handler = logging.StreamHandler(sys.stderr)
24 | handler.setFormatter(formatter)
25 | logger.addHandler(handler)
26 | if adapters is not None:
27 | for adapter in adapters:
28 | logger = adapter(logger)
29 | return logger
30 |
31 |
class SidcarAdapter(logging.LoggerAdapter):
    """Adapter that prepends 'Sidecar {rank}' to all messages."""

    def __init__(self, logger: logging.Logger) -> None:
        """Initialize the adapter with the given logger."""
        super().__init__(logger, {})
        self.sidecar_rank = int(os.environ.get("SIDECAR_RANK", "-1"))
        # A negative rank means the environment variable was missing.
        assert self.sidecar_rank >= 0, "SIDECAR_RANK or SIDECAR_POD_NAME must be set."

    def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple:
        """Prepend 'Sidecar {rank}' to the message."""
        prefixed = f"Sidecar {self.sidecar_rank}: {msg}"
        return prefixed, kwargs
44 |
--------------------------------------------------------------------------------
/python/cornserve/services/__init__.py:
--------------------------------------------------------------------------------
1 | """Services that comprise Cornserve."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/services/gateway/__init__.py:
--------------------------------------------------------------------------------
1 | """The gateway service is the main entry point for Cornserve.
2 |
3 | System admins can use the gateway service to register and unregister applications.
4 | Successfully registered applications will be deployed to the cluster and made available
5 | for invocation from users.
6 | """
7 |
--------------------------------------------------------------------------------
/python/cornserve/services/gateway/app/models.py:
--------------------------------------------------------------------------------
1 | """Type definitions for the App Manager."""
2 |
3 | from __future__ import annotations
4 |
5 | import enum
6 | from collections.abc import Callable, Coroutine
7 | from dataclasses import dataclass
8 | from types import ModuleType
9 |
10 | from cornserve.app.base import AppConfig, AppRequest, AppResponse
11 |
12 |
class AppState(enum.StrEnum):
    """Possible states of a registered app.

    As a StrEnum, members compare equal to their human-readable string values.
    """

    NOT_READY = "not ready"
    READY = "ready"
18 |
19 |
@dataclass
class AppClasses:
    """Container for a registered app.

    Attributes:
        request_cls: The class that defines the app's request schema.
        response_cls: The class that defines the app's response schema.
        config_cls: The class that specifies the app's configuration.
        serve_fn: The function that implements the app's logic.
    """

    request_cls: type[AppRequest]
    response_cls: type[AppResponse]
    config_cls: type[AppConfig]
    # Async callable mapping a single app request to a single app response.
    serve_fn: Callable[[AppRequest], Coroutine[None, None, AppResponse]]
35 |
36 |
@dataclass
class AppDefinition:
    """Full definition of a registered app.

    Attributes:
        app_id: The ID of the app.
        module: The module that contains the app's code.
        source_code: The Python source code of the app.
        classes: The classes that define the app's schema and logic.
    """

    app_id: str
    # Live module object created from source_code at registration time.
    module: ModuleType
    source_code: str
    classes: AppClasses
52 |
--------------------------------------------------------------------------------
/python/cornserve/services/gateway/entrypoint.py:
--------------------------------------------------------------------------------
1 | """Spins up the Gateway service."""
2 |
3 | from __future__ import annotations
4 |
5 | import asyncio
6 | import os
7 | import signal
8 | from typing import TYPE_CHECKING
9 |
10 | import uvicorn
11 | from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
12 | from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorClient
13 | from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
14 |
15 | from cornserve.logging import get_logger
16 | from cornserve.services.gateway.router import create_app
17 | from cornserve.tracing import configure_otel
18 |
19 | if TYPE_CHECKING:
20 | from cornserve.services.gateway.app.manager import AppManager
21 |
22 | logger = get_logger("cornserve.services.gateway.entrypoint")
23 |
24 |
async def serve() -> None:
    """Serve the Gateway as a FastAPI app."""
    logger.info("Starting Gateway service")

    configure_otel("gateway")

    app = create_app()
    FastAPIInstrumentor.instrument_app(app)
    GrpcAioInstrumentorClient().instrument()
    HTTPXClientInstrumentor().instrument()

    # Log every HTTP route the app exposes, for operator visibility.
    logger.info("Available routes are:")
    for route in app.routes:
        route_methods = getattr(route, "methods", None)
        route_path = getattr(route, "path", None)
        if route_methods is None or route_path is None:
            continue
        if len(route_methods) == 1:
            methods_repr = list(route_methods)[0]
        else:
            methods_repr = "{" + ",".join(route_methods) + "}"
        logger.info("%s %s", methods_repr, route_path)

    server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=8000))
    app_manager: AppManager = app.state.app_manager

    # `TaskContext` reads this environment variable to determine the URL of the Gateway.
    os.environ["CORNSERVE_GATEWAY_URL"] = "http://localhost:8000"

    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.serve())

    def shutdown() -> None:
        server_task.cancel()

    # Cancel the server task on either termination signal.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, shutdown)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("Shutting down Gateway service")
        await app_manager.shutdown()
        await server.shutdown()
72 |
73 |
74 | if __name__ == "__main__":
75 | asyncio.run(serve())
76 |
--------------------------------------------------------------------------------
/python/cornserve/services/gateway/models.py:
--------------------------------------------------------------------------------
1 | """Gateway request and response models."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Any
6 |
7 | from pydantic import BaseModel
8 |
9 |
class AppRegistrationRequest(BaseModel):
    """Request for registering a new application.

    Attributes:
        source_code: The Python source code of the application.
    """

    # NOTE(review): this code is presumably imported/executed by the gateway
    # during registration -- treat as trusted (admin-only) input.
    source_code: str
18 |
19 |
class AppRegistrationResponse(BaseModel):
    """Response for registering a new application.

    Attributes:
        app_id: The unique identifier for the registered application.
    """

    # Identifier assigned by the gateway for the newly registered app.
    app_id: str
28 |
29 |
class AppInvocationRequest(BaseModel):
    """Request for invoking a registered application.

    Attributes:
        request_data: The input data for the application. Should be a valid
            JSON object that matches the `Request` schema of the application.
    """

    # Free-form payload; the concrete schema is app-specific.
    request_data: dict[str, Any]
39 |
--------------------------------------------------------------------------------
/python/cornserve/services/gateway/session.py:
--------------------------------------------------------------------------------
1 | """A session manager for the Cornserve gateway."""
2 |
3 | import asyncio
4 | import uuid
5 | from dataclasses import dataclass, field
6 |
7 | from cornserve.frontend import TaskRequest, TaskRequestVerb, TaskResponse
8 | from cornserve.logging import get_logger
9 | from cornserve.services.gateway.task_manager import TaskManager
10 | from cornserve.task.base import UnitTask
11 |
12 | logger = get_logger(__name__)
13 |
14 |
@dataclass
class Session:
    """The session state.

    Attributes:
        tasks: A dictionary of tasks that are currently in use by this session.
    """

    # Maps task ID -> UnitTask; populated/pruned by SessionManager.handle_request.
    tasks: dict[str, UnitTask] = field(default_factory=dict)
24 |
25 |
class SessionManager:
    """Manages debug sessions for the Cornserve gateway."""

    def __init__(self, task_manager: TaskManager) -> None:
        """Initialize the session manager."""
        self.task_manager = task_manager
        self.lock = asyncio.Lock()
        self.sessions: dict[str, Session] = {}

    async def create_session(self) -> str:
        """Create a new session."""
        new_id = str(uuid.uuid4())
        async with self.lock:
            # Regenerate on the (vanishingly unlikely) chance of a collision.
            while new_id in self.sessions:
                new_id = str(uuid.uuid4())
            self.sessions[new_id] = Session()
        logger.info("Created session with ID: %s", new_id)
        return new_id

    async def handle_request(self, session_id: str, request: dict) -> TaskResponse:
        """Handle a request for a session.

        Args:
            session_id: The ID of the session.
            request: The request data.
        """
        async with self.lock:
            if session_id not in self.sessions:
                logger.warning("Session ID %s not found", session_id)
                return TaskResponse(status=404, content="Session invalid")

            try:
                task_request = TaskRequest.model_validate(request)
            except Exception:
                logger.exception("Invalid request")
                return TaskResponse(status=400, content="Invalid request")

            verb = task_request.verb
            if verb == TaskRequestVerb.DECLARE_USED:
                logger.info("Declaring tasks as used: %s", task_request.task_list)
                await self.task_manager.declare_used(task_request.get_tasks())
                session = self.sessions[session_id]
                session.tasks.update({task.id: task for task in task_request.get_tasks()})
                return TaskResponse(status=200, content="Tasks declared used")

            if verb == TaskRequestVerb.DECLARE_NOT_USED:
                logger.info("Declaring tasks as not used: %s", task_request.task_list)
                await self.task_manager.declare_not_used(task_request.get_tasks())
                session_tasks = self.sessions[session_id].tasks
                for task in task_request.get_tasks():
                    session_tasks.pop(task.id, None)
                return TaskResponse(status=200, content="Tasks declared not used")

            if verb == TaskRequestVerb.HEARTBEAT:
                return TaskResponse(status=200, content="Session is alive")

            logger.warning("Unknown method %s", verb)
            return TaskResponse(status=400, content="Unknown method")

    async def destroy_session(self, session_id: str) -> bool:
        """Destroy a session. Clean up all tasks in use by this session.

        Args:
            session_id: The ID of the session to destroy.
        """
        async with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return False
            logger.info("Destroying session with ID: %s", session_id)
            await self.task_manager.declare_not_used(list(session.tasks.values()))
            del self.sessions[session_id]
            return True
93 |
--------------------------------------------------------------------------------
/python/cornserve/services/resource_manager/__init__.py:
--------------------------------------------------------------------------------
1 | """The Resource Manager manages resources (GPUs) and task managers.
2 |
3 | It allocates resources to task managers and spawns and kills them.
4 | It also reconciles the state of task managers with the task dispatcher.
5 | """
6 |
--------------------------------------------------------------------------------
/python/cornserve/services/resource_manager/entrypoint.py:
--------------------------------------------------------------------------------
1 | """Entrypoint for the Resource Manager service."""
2 |
3 | from __future__ import annotations
4 |
5 | import asyncio
6 | import signal
7 |
8 | from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorClient, GrpcAioInstrumentorServer
9 |
10 | from cornserve.logging import get_logger
11 | from cornserve.services.resource_manager.grpc import create_server
12 | from cornserve.services.resource_manager.manager import ResourceManager
13 | from cornserve.tracing import configure_otel
14 |
15 | logger = get_logger("cornserve.services.resource_manager.entrypoint")
16 |
17 |
async def serve() -> None:
    """Start the gRPC server."""
    configure_otel("resource_manager")

    GrpcAioInstrumentorServer().instrument()
    GrpcAioInstrumentorClient().instrument()

    resource_manager = await ResourceManager.init()

    server = create_server(resource_manager)
    await server.start()
    logger.info("gRPC server started")

    loop = asyncio.get_running_loop()
    termination_task = loop.create_task(server.wait_for_termination())

    def _request_shutdown() -> None:
        termination_task.cancel()

    # Either termination signal cancels the wait and triggers cleanup below.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, _request_shutdown)

    try:
        await termination_task
    except asyncio.CancelledError:
        logger.info("Shutting down Resource Manager service")
        await server.stop(5)
        logger.info("Shutting down resource manager...")
        await resource_manager.shutdown()
        logger.info("Resource Manager service shutdown complete")
49 |
50 |
51 | if __name__ == "__main__":
52 | asyncio.run(serve())
53 |
--------------------------------------------------------------------------------
/python/cornserve/services/resource_manager/grpc.py:
--------------------------------------------------------------------------------
1 | """Resource Manager gRPC server."""
2 |
3 | from __future__ import annotations
4 |
5 | import grpc
6 |
7 | from cornserve.logging import get_logger
8 | from cornserve.services.pb import common_pb2, resource_manager_pb2, resource_manager_pb2_grpc
9 | from cornserve.services.resource_manager.manager import ResourceManager
10 | from cornserve.task.base import UnitTask
11 |
12 | logger = get_logger(__name__)
13 |
14 |
class ResourceManagerServicer(resource_manager_pb2_grpc.ResourceManagerServicer):
    """Resource Manager gRPC service implementation."""

    def __init__(self, manager: ResourceManager) -> None:
        """Initialize the ResourceManagerServicer."""
        self.manager = manager

    async def DeployUnitTask(
        self,
        request: resource_manager_pb2.DeployUnitTaskRequest,
        context: grpc.aio.ServicerContext,
    ) -> resource_manager_pb2.DeployUnitTaskResponse:
        """Deploy a unit task in the cluster."""
        task = UnitTask.from_pb(request.task)
        await self.manager.deploy_unit_task(task)
        return resource_manager_pb2.DeployUnitTaskResponse(status=common_pb2.Status.STATUS_OK)

    async def TeardownUnitTask(
        self,
        request: resource_manager_pb2.TeardownUnitTaskRequest,
        context: grpc.aio.ServicerContext,
    ) -> resource_manager_pb2.TeardownUnitTaskResponse:
        """Reconcile a removed app by shutting down task managers if needed."""
        task = UnitTask.from_pb(request.task)
        await self.manager.teardown_unit_task(task)
        return resource_manager_pb2.TeardownUnitTaskResponse(status=common_pb2.Status.STATUS_OK)

    async def Healthcheck(
        self,
        request: resource_manager_pb2.HealthcheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> resource_manager_pb2.HealthcheckResponse:
        """Recursively check and report the health of task managers."""

        def to_status(healthy: bool) -> common_pb2.Status:
            # Map a boolean health flag onto the shared proto status enum.
            return common_pb2.Status.STATUS_OK if healthy else common_pb2.Status.STATUS_ERROR

        try:
            overall, per_task = await self.manager.healthcheck()

            # Convert the statuses into proto message format.
            proto_statuses = []
            for task, healthy in per_task:
                proto_statuses.append(
                    resource_manager_pb2.TaskManagerStatus(task=task.to_pb(), status=to_status(healthy))
                )

            return resource_manager_pb2.HealthcheckResponse(
                status=to_status(overall), task_manager_statuses=proto_statuses
            )
        except Exception as e:
            logger.exception("Healthcheck failed: %s", e)
            return resource_manager_pb2.HealthcheckResponse(
                status=common_pb2.Status.STATUS_ERROR, task_manager_statuses=[]
            )
68 |
69 |
def create_server(resource_manager: ResourceManager) -> grpc.aio.Server:
    """Create the gRPC server for the Resource Manager."""
    server = grpc.aio.server()
    resource_manager_pb2_grpc.add_ResourceManagerServicer_to_server(
        ResourceManagerServicer(resource_manager), server
    )
    bind_addr = "[::]:50051"
    server.add_insecure_port(bind_addr)
    logger.info("Starting server on %s", bind_addr)
    return server
79 |
--------------------------------------------------------------------------------
/python/cornserve/services/sidecar/__init__.py:
--------------------------------------------------------------------------------
1 | """Sidecar servers that forward data between peers."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/services/sidecar/scheduler.py:
--------------------------------------------------------------------------------
1 | """Template Scheduler for Sidecar."""
2 |
3 | from __future__ import annotations
4 |
5 | import asyncio
6 | import contextlib
7 | from collections.abc import Callable
8 | from dataclasses import dataclass
9 | from types import TracebackType
10 | from typing import Any
11 |
12 |
@dataclass
class _Job:
    """A queued unit of work: a callable plus the future to resolve with its result."""

    fn: Callable[..., Any]  # sync or async callable to invoke
    args: tuple  # positional arguments for fn
    kwargs: dict  # keyword arguments for fn
    fut: asyncio.Future  # completed by the runner with fn's result or exception
19 |
20 |
21 | # ── async no‑op context manager ─────────────────────────────────────
22 | class _AsyncNullCM:
23 | """`async with _ASYNC_NULL:` does nothing (like contextlib.nullcontext)."""
24 |
25 | async def __aenter__(self) -> None: # noqa: D401
26 | return None
27 |
28 | async def __aexit__(
29 | self,
30 | exc_type: type[BaseException] | None,
31 | exc: BaseException | None,
32 | tb: TracebackType | None,
33 | ) -> bool:
34 | return False
35 |
36 |
37 | _ASYNC_NULL = _AsyncNullCM()
38 |
39 |
class Scheduler:
    """Central launch-controller.

    Jobs submitted via `submit()` are queued and executed by a background
    runner task, optionally limited by a concurrency semaphore.
    """

    def __init__(self, max_concurrency: int | None = None) -> None:
        """Initialize the scheduler.

        Args:
            max_concurrency: Maximum number of jobs allowed to run at once.
                `None` (or 0) means unbounded.
        """
        self._q: asyncio.Queue[_Job] = asyncio.Queue()
        self._runner_task: asyncio.Task | None = None
        self._sema = asyncio.Semaphore(max_concurrency) if max_concurrency else None
        # Strong references to in-flight job tasks. The event loop holds only
        # weak references to tasks, so without this set a fire-and-forget job
        # task could be garbage-collected before it finishes.
        self._inflight: set[asyncio.Task] = set()

    async def submit(self, fn: Callable[..., Any], *args, **kwargs) -> Any:
        """Submit a job to the queue and wait for its result.

        Args:
            fn: Sync or async callable to run.
            *args: Positional arguments for `fn`.
            **kwargs: Keyword arguments for `fn`.

        Returns:
            Whatever `fn` returns (awaited first if it returns a coroutine).
        """
        loop = asyncio.get_running_loop()
        fut: asyncio.Future = loop.create_future()
        await self._q.put(_Job(fn, args, kwargs, fut))
        return await fut

    async def schedule(self) -> _Job:
        """Schedule the next job in the queue."""
        return await self._q.get()

    async def _runner(self) -> None:
        """Infinite loop to process jobs in the queue."""
        while True:
            job = await self.schedule()

            async def _execute(j: _Job) -> None:
                # Respect the concurrency limit if one was configured.
                cm = self._sema or _ASYNC_NULL
                async with cm:
                    try:
                        res = j.fn(*j.args, **j.kwargs)
                        if asyncio.iscoroutine(res):
                            res = await res
                        j.fut.set_result(res)
                    except Exception as exc:
                        j.fut.set_exception(exc)

            # Retain a strong reference until completion (see asyncio docs on
            # create_task), then drop it automatically when the task is done.
            task = asyncio.create_task(_execute(job))
            self._inflight.add(task)
            task.add_done_callback(self._inflight.discard)

    def start(self) -> None:
        """Start the scheduler and begin processing jobs."""
        if self._runner_task is None:
            self._runner_task = asyncio.create_task(self._runner())

    async def stop(self) -> None:
        """Stop the scheduler and cancel all jobs in flight."""
        if self._runner_task:
            self._runner_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._runner_task
        # Previously only the runner was cancelled, leaving in-flight jobs
        # running despite this method's contract; cancel and await them too.
        for task in list(self._inflight):
            task.cancel()
        for task in list(self._inflight):
            with contextlib.suppress(asyncio.CancelledError):
                await task
89 |
--------------------------------------------------------------------------------
/python/cornserve/services/task_dispatcher/__init__.py:
--------------------------------------------------------------------------------
1 | """The Task Dispatcher is the interface between App Drivers and the data plane.
2 |
3 | It receives requests from App Drivers and dispatches them to the appropriate
4 | task manager. It also receives updates from the resource manager about changes
5 | in Task Managers.
6 |
7 | The Task Dispatcher requires both a gRPC server and a REST HTTP server.
8 | The gRPC server is used by the Resource Manager to send updates about
Task Managers. On the other hand, the REST API server is used by App Drivers
10 | to invoke tasks.
11 | """
12 |
--------------------------------------------------------------------------------
/python/cornserve/services/task_dispatcher/entrypoint.py:
--------------------------------------------------------------------------------
1 | """Spins up the Task Dispatcher service with gRPC and FastAPI."""
2 |
3 | from __future__ import annotations
4 |
5 | import asyncio
6 | import signal
7 | from typing import TYPE_CHECKING
8 |
9 | import uvicorn
10 | from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
11 | from opentelemetry.instrumentation.grpc import GrpcInstrumentorClient, GrpcInstrumentorServer
12 | from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
13 |
14 | from cornserve.logging import get_logger
15 | from cornserve.services.task_dispatcher.grpc import create_server
16 | from cornserve.services.task_dispatcher.router import create_app
17 | from cornserve.tracing import configure_otel
18 |
19 | if TYPE_CHECKING:
20 | from cornserve.services.task_dispatcher.dispatcher import TaskDispatcher
21 |
22 | logger = get_logger("cornserve.services.task_dispatcher.entrypoint")
23 |
24 |
async def serve() -> None:
    """Serve the Task Dispatcher service.

    Starts a FastAPI (uvicorn) server used by App Drivers to invoke tasks
    and a gRPC server used by the Resource Manager to push Task Manager
    updates, then blocks until SIGINT/SIGTERM triggers a graceful shutdown.
    """
    # Fix: this previously logged "Starting Gateway service" (copy-paste).
    logger.info("Starting Task Dispatcher service")

    configure_otel("task_dispatcher")

    # FastAPI server
    app = create_app()

    FastAPIInstrumentor.instrument_app(app)
    HTTPXClientInstrumentor().instrument()
    GrpcInstrumentorClient().instrument()
    GrpcInstrumentorServer().instrument()

    logger.info("Available HTTP routes are:")
    for route in app.routes:
        methods = getattr(route, "methods", None)
        path = getattr(route, "path", None)

        # Skip routes without HTTP methods/path (e.g., websocket/mount routes).
        if methods is None or path is None:
            continue

        logger.info(
            "%s %s",
            list(methods)[0] if len(methods) == 1 else "{" + ",".join(methods) + "}",
            path,
        )

    config = uvicorn.Config(app, host="0.0.0.0", port=8000)
    uvicorn_server = uvicorn.Server(config)
    dispatcher: TaskDispatcher = app.state.dispatcher

    # gRPC server
    grpc_server = create_server(dispatcher)

    loop = asyncio.get_running_loop()
    uvicorn_server_task = loop.create_task(uvicorn_server.serve())
    await grpc_server.start()

    def shutdown() -> None:
        # Cancelling the uvicorn task unblocks the await below and starts
        # the graceful shutdown sequence.
        uvicorn_server_task.cancel()

    loop.add_signal_handler(signal.SIGINT, shutdown)
    loop.add_signal_handler(signal.SIGTERM, shutdown)

    try:
        await uvicorn_server_task
    except asyncio.CancelledError:
        # Shut down in dependency order: dispatcher, HTTP server, then gRPC
        # with a 5-second grace period.
        logger.info("Shutting down Task Dispatcher service")
        await dispatcher.shutdown()
        await uvicorn_server.shutdown()
        await grpc_server.stop(5)
        logger.info("Task Dispatcher service shutdown complete")
78 |
79 |
if __name__ == "__main__":
    # Run the service until terminated by SIGINT/SIGTERM.
    asyncio.run(serve())
82 |
--------------------------------------------------------------------------------
/python/cornserve/services/task_dispatcher/grpc.py:
--------------------------------------------------------------------------------
1 | """Task Dispatcher gRPC server."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | import grpc
8 |
9 | from cornserve.logging import get_logger
10 | from cornserve.services.pb import common_pb2, task_dispatcher_pb2, task_dispatcher_pb2_grpc
11 | from cornserve.task.base import UnitTask
12 |
13 | if TYPE_CHECKING:
14 | from cornserve.services.task_dispatcher.dispatcher import TaskDispatcher
15 |
16 | logger = get_logger(__name__)
17 |
18 |
class TaskDispatcherServicer(task_dispatcher_pb2_grpc.TaskDispatcherServicer):
    """Task Dispatcher gRPC service implementation.

    Called by the Resource Manager to notify the dispatcher of unit task
    deployments and teardowns.
    """

    def __init__(self, task_dispatcher: TaskDispatcher) -> None:
        """Initialize the TaskDispatcherServicer.

        Args:
            task_dispatcher: The dispatcher that holds task manager state.
        """
        self.task_dispatcher = task_dispatcher

    async def NotifyUnitTaskDeployment(
        self,
        request: task_dispatcher_pb2.NotifyUnitTaskDeploymentRequest,
        context: grpc.aio.ServicerContext,
    ) -> task_dispatcher_pb2.NotifyUnitTaskDeploymentResponse:
        """Register new task managers with the task dispatcher.

        Converts the protobuf task to a `UnitTask` and records the task
        manager's URL, then always reports STATUS_OK (errors propagate as
        gRPC exceptions).
        """
        await self.task_dispatcher.notify_task_deployment(
            task=UnitTask.from_pb(request.task),
            task_manager_url=request.task_manager.url,
        )
        return task_dispatcher_pb2.NotifyUnitTaskDeploymentResponse(status=common_pb2.Status.STATUS_OK)

    async def NotifyUnitTaskTeardown(
        self,
        request: task_dispatcher_pb2.NotifyUnitTaskTeardownRequest,
        context: grpc.aio.ServicerContext,
    ) -> task_dispatcher_pb2.NotifyUnitTaskTeardownResponse:
        """Remove task managers from the task dispatcher."""
        await self.task_dispatcher.notify_task_teardown(task=UnitTask.from_pb(request.task))
        return task_dispatcher_pb2.NotifyUnitTaskTeardownResponse(status=common_pb2.Status.STATUS_OK)
46 |
47 |
def create_server(task_dispatcher: TaskDispatcher) -> grpc.aio.Server:
    """Create the gRPC server for the Task Dispatcher.

    The server is configured but not started; callers must `await server.start()`.
    """
    server = grpc.aio.server()
    task_dispatcher_pb2_grpc.add_TaskDispatcherServicer_to_server(
        TaskDispatcherServicer(task_dispatcher),
        server,
    )
    listen_addr = "[::]:50051"
    server.add_insecure_port(listen_addr)
    logger.info("gRPC server listening on %s", listen_addr)
    return server
57 |
--------------------------------------------------------------------------------
/python/cornserve/services/task_dispatcher/router.py:
--------------------------------------------------------------------------------
1 | """Task Dispatcher REST API server."""
2 |
3 | from __future__ import annotations
4 |
5 | from fastapi import APIRouter, FastAPI, Request, Response, status
6 | from opentelemetry import trace
7 |
8 | from cornserve.logging import get_logger
9 | from cornserve.services.task_dispatcher.dispatcher import TaskDispatcher
10 | from cornserve.task.base import TaskGraphDispatch
11 |
12 | router = APIRouter()
13 | logger = get_logger(__name__)
14 | tracer = trace.get_tracer(__name__)
15 |
16 |
@router.post("/task")
async def invoke_task(request: TaskGraphDispatch, raw_request: Request):
    """Invoke a task with the given request data.

    Delegates to the app-level dispatcher; any failure is logged and
    surfaced as an HTTP 500 with the exception message in the body.
    """
    logger.info("Task dispatch received: %s", request)

    dispatcher: TaskDispatcher = raw_request.app.state.dispatcher
    try:
        return await dispatcher.invoke(request.invocations)
    except Exception as exc:
        logger.exception("Error while invoking task")
        return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=str(exc))
29 |
30 |
@router.get("/health")
async def health_check():
    """Health check endpoint.

    Returns an empty 200 response; used for liveness/readiness probes.
    """
    return Response(status_code=status.HTTP_200_OK)
35 |
36 |
def init_app_state(app: FastAPI) -> None:
    """Initialize the app state for the Task Dispatcher.

    Creates the singleton `TaskDispatcher` that route handlers retrieve
    via `request.app.state.dispatcher`.
    """
    app.state.dispatcher = TaskDispatcher()
40 |
41 |
def create_app() -> FastAPI:
    """Build the FastAPI app for the Task Dispatcher."""
    application = FastAPI(title="Cornserve Task Dispatcher")
    application.include_router(router)
    init_app_state(application)
    return application
48 |
--------------------------------------------------------------------------------
/python/cornserve/services/task_manager/__init__.py:
--------------------------------------------------------------------------------
1 | """The Task Manager manages task executors.
2 |
3 | A Task Manager handles exactly one type of task, for instance,
4 | multimodal data embedding (Eric) or LLM inference (vLLM).
5 | It spawns and kills task executors given the resource (GPUs) allocated to it by
6 | the resource manager.
7 |
8 | It's primarily responsible for
9 | 1. Spawning and killing task executors
10 | 2. Routing requests to the appropriate task executor
11 | """
12 |
--------------------------------------------------------------------------------
/python/cornserve/services/task_manager/entrypoint.py:
--------------------------------------------------------------------------------
1 | """Entrypoint for the Task Manager service."""
2 |
3 | from __future__ import annotations
4 |
5 | import asyncio
6 | import signal
7 |
8 | from cornserve.logging import get_logger
9 | from cornserve.services.task_manager.grpc import create_server
10 |
11 | logger = get_logger("cornserve.services.task_manager.entrypoint")
12 |
13 |
async def serve() -> None:
    """Serve the Task Manager service.

    Starts the gRPC server and blocks until SIGINT/SIGTERM, then stops
    the server and shuts down the task manager if one was created.
    """
    logger.info("Starting Task Manager service")

    server, servicer = create_server()
    await server.start()

    logger.info("gRPC server started")

    loop = asyncio.get_running_loop()
    # Wrap wait_for_termination in a task so a signal handler can cancel it.
    server_task = asyncio.create_task(server.wait_for_termination())

    def shutdown() -> None:
        server_task.cancel()

    loop.add_signal_handler(signal.SIGINT, shutdown)
    loop.add_signal_handler(signal.SIGTERM, shutdown)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("Shutting down Task Manager service")
        # Stop accepting RPCs (5-second grace period) before tearing down
        # the underlying manager.
        await server.stop(5)
        # NOTE(review): manager appears to be lazily created (None check) —
        # confirm against the servicer implementation.
        if servicer.manager is not None:
            logger.info("Shutting down task manager...")
            await servicer.manager.shutdown()
        logger.info("Task Manager service shutdown complete")
41 |
42 |
if __name__ == "__main__":
    # Run the service until terminated by SIGINT/SIGTERM.
    asyncio.run(serve())
45 |
--------------------------------------------------------------------------------
/python/cornserve/sidecar/__init__.py:
--------------------------------------------------------------------------------
1 | """Sidecar APIs for Task Executors."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/sidecar/constants.py:
--------------------------------------------------------------------------------
1 | """Sidecar constants."""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 |
7 | from cornserve import constants
8 |
# Offsets used to pack (rank, chunk, shard) into a single integer tag.
RANK_OFFSET = 1000000
CHUNK_OFFSET = 1000
GRPC_BASE_PORT = 10000
UCX_BASE_PORT = 12000


def chunk_tag(id: str, rank: int, chunk_id: int, shard_rank: int) -> int:
    """Generate a tag for the chunk.

    The tag is a unique id for a chunk during transmission: the hex UUID
    interpreted as an integer, plus offsets encoding the sender rank,
    chunk index, and shard rank.
    """
    uuid_as_int = int(id, 16)
    return uuid_as_int + rank * RANK_OFFSET + chunk_id * CHUNK_OFFSET + shard_rank
23 |
24 |
def shm_filename() -> str:
    """Shared memory filename in each node.

    A single /dev/shm-backed file shared by all sidecars on the node
    (see `init_shmem`, which maps slices of it per sidecar).
    """
    return "/dev/shm/sc_shm"
28 |
29 |
def grpc_url_from_rank(rank: int) -> str:
    """GRPC channel url from rank.

    In local mode (SIDECAR_IS_LOCAL=true) targets localhost; otherwise
    builds the Kubernetes headless-service DNS name for the sidecar pod.
    Ports are allocated contiguously from GRPC_BASE_PORT.
    """
    assert rank >= 0, "Rank should be non-negative"
    port = GRPC_BASE_PORT + rank
    if os.environ.get("SIDECAR_IS_LOCAL", "false").lower() == "true":
        return f"localhost:{port}"
    host = ".".join(
        [
            f"sidecar-{rank}",
            constants.K8S_SIDECAR_SERVICE_NAME,
            constants.K8S_NAMESPACE,
            "svc.cluster.local",
        ]
    )
    return f"{host}:{port}"
43 |
44 |
def ucx_url_from_rank(rank: int) -> str:
    """UCX connection host url from rank.

    In local mode (SIDECAR_IS_LOCAL=true) this is always the loopback
    address; otherwise the Kubernetes service DNS name of the sidecar pod.
    """
    assert rank >= 0, "Rank should be non-negative"
    if os.environ.get("SIDECAR_IS_LOCAL", "false").lower() == "true":
        return "127.0.0.1"
    return ".".join(
        [
            f"sidecar-{rank}",
            constants.K8S_SIDECAR_SERVICE_NAME,
            constants.K8S_NAMESPACE,
            "svc.cluster.local",
        ]
    )
58 |
59 |
def ucx_port_from_rank(rank: int) -> int:
    """UCX connection host port from rank.

    Ports are allocated contiguously starting at UCX_BASE_PORT (12000).
    """
    assert rank >= 0, "Rank should be non-negative"
    return UCX_BASE_PORT + rank
64 |
--------------------------------------------------------------------------------
/python/cornserve/sidecar/utils.py:
--------------------------------------------------------------------------------
1 | """Sidecar utility functions and constants."""
2 |
3 | from __future__ import annotations
4 |
5 | import ctypes
6 |
7 | import torch
8 |
9 | from cornserve.logging import get_logger
10 |
11 | logger = get_logger(__name__)
12 |
13 |
def buffer_from_tensor(t: torch.Tensor) -> ctypes.Array:
    """Convert a torch tensor to a ctypes buffer for ucx-py.

    The buffer aliases the tensor's memory (zero-copy via `from_address`
    on the tensor's data pointer), so the tensor must outlive the buffer.
    """
    nbytes = t.numel() * t.element_size()
    return (ctypes.c_byte * nbytes).from_address(t.data_ptr())
20 |
21 |
def device_from_rank(rank: int) -> torch.device:
    """Torch device from rank: rank N maps to ``cuda:N``."""
    assert rank >= 0, "Rank should be non-negative"
    return torch.device("cuda", rank)
26 |
27 |
28 | def init_shmem(
29 | filename: str,
30 | local_ranks: list[int],
31 | num_local_sidecars: int,
32 | partition_numel: int,
33 | dtype: torch.dtype,
34 | ) -> tuple[torch.Tensor, torch.Tensor]:
35 | """Initialize a shared memory buffer between the sidecar client and server.
36 |
37 | All sidecars within the same node will share the same buffer but at different offsets.
38 | Each sidecar will only access its own slice of the buffer, and each slice has the same size.
39 |
40 | Args:
41 | filename: The filename of the shared memory buffer.
42 | local_ranks: The local ranks of the sidecars that will share the buffer, must be consecutive.
43 | num_local_sidecars: Total number of sidecars within the same node.
44 | partition_numel: Number of elements of given dtype in the shared memory buffer used by each device/sidecar.
45 | dtype: Data type of the shared memory buffer.
46 | """
47 | # sanity check device_ids
48 | for i in range(len(local_ranks) - 1):
49 | assert local_ranks[i] + 1 == local_ranks[i + 1], "Device IDs must be consecutive"
50 | total_element_count = partition_numel * num_local_sidecars
51 | full_tensor = torch.from_file(
52 | filename=filename,
53 | shared=True,
54 | size=total_element_count,
55 | dtype=dtype,
56 | )
57 | start = partition_numel * local_ranks[0]
58 | end = partition_numel * (local_ranks[-1] + 1)
59 | return full_tensor, full_tensor[start:end]
60 |
--------------------------------------------------------------------------------
/python/cornserve/task/__init__.py:
--------------------------------------------------------------------------------
1 | """The Task abstraction."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/task/builtins/__init__.py:
--------------------------------------------------------------------------------
1 | """Built-in tasks."""
2 |
3 | from cornserve.task.builtins import encoder, llm
4 |
--------------------------------------------------------------------------------
/python/cornserve/task/builtins/encoder.py:
--------------------------------------------------------------------------------
"""Built-in task for modality encoders."""
2 |
3 | from __future__ import annotations
4 |
5 | import enum
6 |
7 | from cornserve.task.base import TaskInput, TaskOutput, UnitTask
8 | from cornserve.task.forward import DataForward, Tensor
9 |
10 |
class Modality(enum.StrEnum):
    """Supported modalities for encoder tasks.

    As a ``StrEnum``, members compare equal to their lowercase string
    values (e.g., ``Modality.IMAGE == "image"``).
    """

    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
17 |
18 |
class EncoderInput(TaskInput):
    """Input model for encoder tasks.

    Attributes:
        data_urls: The URLs of the data to encode. One embedding is
            produced per URL (see `EncoderTask.make_record_output`).
    """

    data_urls: list[str]
27 |
28 |
class EncoderOutput(TaskOutput):
    """Output model for encoder tasks.

    Attributes:
        embeddings: The embeddings from the encoder, forwarded through
            the data plane rather than returned inline.
    """

    embeddings: list[DataForward[Tensor]]
37 |
38 |
class EncoderTask(UnitTask[EncoderInput, EncoderOutput]):
    """A task that invokes an encoder.

    Attributes:
        model_id: The ID of the model to use for the task.
        modality: Modality of data this encoder can embed.
    """

    model_id: str
    modality: Modality

    def make_record_output(self, task_input: EncoderInput) -> EncoderOutput:
        """Create a task output for task invocation recording.

        Produces one placeholder forward per input data URL.
        """
        placeholders = [DataForward[Tensor]() for _ in task_input.data_urls]
        return EncoderOutput(embeddings=placeholders)

    def make_name(self) -> str:
        """Create a concise string representation of the task."""
        model_short = self.model_id.split("/")[-1]
        return "-".join(("encoder", self.modality.lower(), model_short.lower()))
57 |
--------------------------------------------------------------------------------
/python/cornserve/task/builtins/llm.py:
--------------------------------------------------------------------------------
"""Built-in task for LLMs."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Generic, TypeVar
6 |
7 | from cornserve.task.base import TaskInput, TaskOutput, UnitTask
8 | from cornserve.task.forward import DataForward, Tensor
9 |
10 |
class LLMInput(TaskInput):
    """Input model for LLM tasks.

    Attributes:
        prompt: The prompt to send to the LLM.
        multimodal_data: List of tuples (modality, data URL).
            "image", "audio", "video", etc. for modality.
        embeddings: Multimodal embeddings to send to the LLM.
    """

    prompt: str
    # NOTE(review): mutable defaults are safe if TaskInput is a pydantic
    # model (defaults are copied per instance) — confirm against base class.
    multimodal_data: list[tuple[str, str]] = []
    embeddings: list[DataForward[Tensor]] = []
24 |
25 |
class LLMOutputBase(TaskOutput):
    """Base output model for LLM tasks.

    Subclasses in this module carry the response either inline
    (`LLMOutput`) or as a forwarded value (`LLMForwardOutput`).
    """
28 |
29 |
# Type parameters binding LLMBaseTask to concrete input/output models.
InputT = TypeVar("InputT", bound=TaskInput)
OutputT = TypeVar("OutputT", bound=TaskOutput)


class LLMBaseTask(UnitTask[InputT, OutputT], Generic[InputT, OutputT]):
    """A task that invokes an LLM.

    Attributes:
        model_id: The ID of the model to use for the task.
    """

    model_id: str

    def make_name(self) -> str:
        """Create a concise string representation of the task.

        E.g., model_id "org/Model-7B" yields "llm-model-7b".
        """
        return f"llm-{self.model_id.split('/')[-1].lower()}"
46 |
47 |
class LLMOutput(LLMOutputBase):
    """Output model for LLM tasks.

    Attributes:
        response: The response from the LLM, returned inline as text.
    """

    response: str
56 |
57 |
class LLMTask(LLMBaseTask[LLMInput, LLMOutput]):
    """A task that invokes an LLM and returns the response.

    Attributes:
        model_id: The ID of the model to use for the task
            (inherited from `LLMBaseTask`).
    """

    # Fix: removed redundant re-declaration of `model_id`, which is
    # already declared on `LLMBaseTask` with the same type.

    def make_record_output(self, task_input: LLMInput) -> LLMOutput:
        """Create a task output for task invocation recording.

        The empty response is a placeholder filled in at execution time.
        """
        return LLMOutput(response="")
70 |
71 |
class LLMForwardOutput(LLMOutputBase):
    """Output model for LLM tasks with the response forwarded.

    Attributes:
        response: The response from the LLM, forwarded through the data
            plane instead of returned inline.
    """

    response: DataForward[str]
80 |
81 |
class LLMForwardOutputTask(LLMBaseTask[LLMInput, LLMForwardOutput]):
    """A task that invokes an LLM and forwards the response.

    Attributes:
        model_id: The ID of the model to use for the task
            (inherited from `LLMBaseTask`).
    """

    # Fix: removed redundant re-declaration of `model_id`, which is
    # already declared on `LLMBaseTask` with the same type.

    def make_record_output(self, task_input: LLMInput) -> LLMForwardOutput:
        """Create a task output for task invocation recording."""
        return LLMForwardOutput(response=DataForward[str]())
94 |
--------------------------------------------------------------------------------
/python/cornserve/task/builtins/mllm.py:
--------------------------------------------------------------------------------
1 | """Built-in task for Multimodal LLMs."""
2 |
3 | from __future__ import annotations
4 |
5 | from cornserve.task.base import Task, TaskInput, TaskOutput
6 | from cornserve.task.builtins.encoder import EncoderInput, EncoderTask, Modality
7 | from cornserve.task.builtins.llm import LLMInput, LLMTask
8 | from cornserve.task.forward import DataForward, Tensor
9 |
10 |
class MLLMInput(TaskInput):
    """Input model for Multimodal LLM tasks.

    Attributes:
        prompt: The prompt to send to the LLM.
        multimodal_data: List of tuples (modality, data URL).
            "image", "audio", "video", etc. for modality.
            Only "image" and "video" are handled by `MLLMTask.invoke`.
    """

    prompt: str
    multimodal_data: list[tuple[str, str]] = []
22 |
23 |
class MLLMOutput(TaskOutput):
    """Output model for Multimodal LLM tasks.

    Attributes:
        response: The response from the LLM.
    """

    response: str
32 |
33 |
class MLLMTask(Task[MLLMInput, MLLMOutput]):
    """A task that invokes a Multimodal LLM.

    Composes per-modality encoder subtasks with an LLM subtask.

    Attributes:
        model_id: The ID of the model to use for the task.
        modalities: List of input modalities other than text.
    """

    model_id: str
    modalities: list[Modality] = []

    def post_init(self) -> None:
        """Initialize subtasks.

        Encoders are only created for the requested modalities, so
        attribute presence (checked via `hasattr` in `invoke`) encodes
        which modalities this task supports.
        NOTE(review): presumably called by the Task framework after model
        construction — confirm against `Task` base class.
        """
        if Modality.IMAGE in self.modalities:
            self.image_encoder = EncoderTask(model_id=self.model_id, modality=Modality.IMAGE)
        if Modality.VIDEO in self.modalities:
            self.video_encoder = EncoderTask(model_id=self.model_id, modality=Modality.VIDEO)
        self.llm = LLMTask(model_id=self.model_id)

    def invoke(self, task_input: MLLMInput) -> MLLMOutput:
        """Invoke the task.

        Given multimodal data and a text prompt, run the corresponding encoder
        for multimodal data and then pass the embeddings and text prompt to the LLM.

        Raises:
            ValueError: If a modality other than image/video is supplied, or
                if data is supplied for a modality this task was not
                configured to support.
        """
        # Bucket data URLs by modality, preserving within-modality order.
        image_data = []
        video_data = []
        for modality, data in task_input.multimodal_data:
            if modality == Modality.IMAGE:
                image_data.append(data)
            elif modality == Modality.VIDEO:
                video_data.append(data)
            else:
                raise ValueError(f"Unsupported modality: {modality}")

        if image_data:
            if not hasattr(self, "image_encoder"):
                raise ValueError("Image modality is not supported.")
            image_task_input = EncoderInput(data_urls=image_data)
            image_embeddings = self.image_encoder.invoke(image_task_input).embeddings
        else:
            image_embeddings = []

        if video_data:
            if not hasattr(self, "video_encoder"):
                raise ValueError("Video modality is not supported.")
            video_task_input = EncoderInput(data_urls=video_data)
            video_embeddings = self.video_encoder.invoke(video_task_input).embeddings
        else:
            video_embeddings = []

        # Retain the order of multimodal data: interleave the per-modality
        # embedding lists back into the original input order via pop(0).
        embeddings: list[DataForward[Tensor]] = []
        for modality, _ in task_input.multimodal_data:
            if modality == Modality.IMAGE:
                embeddings.append(image_embeddings.pop(0))
            elif modality == Modality.VIDEO:
                embeddings.append(video_embeddings.pop(0))

        llm_input = LLMInput(
            prompt=task_input.prompt,
            multimodal_data=task_input.multimodal_data,
            embeddings=embeddings,
        )
        llm_output = self.llm.invoke(llm_input)

        return MLLMOutput(response=llm_output.response)
101 |
--------------------------------------------------------------------------------
/python/cornserve/task/forward.py:
--------------------------------------------------------------------------------
1 | """Classes for representing data forwarding between tasks in the data plane."""
2 |
3 | from __future__ import annotations
4 |
5 | import enum
6 | import uuid
7 | from typing import Generic, Self, TypeVar
8 |
9 | from pydantic import BaseModel, Field, model_validator
10 |
11 |
class Tensor:
    """Marker type for tensors in data-forwarding annotations.

    Used only as a generic type argument (e.g., `DataForward[Tensor]`);
    `DataForward` maps its `__name__` to `ForwardableType.TENSOR`.
    """
14 |
15 |
class ForwardableType(enum.StrEnum):
    """Types of data that can be forwarded between tasks.

    Each value matches the `__name__` of the corresponding Python type so
    that `DataForward` can derive the member from its generic argument.
    """

    BYTES = "bytes"
    STR = "str"
    INT = "int"
    FLOAT = "float"
    BOOL = "bool"
    TENSOR = "Tensor"
25 |
26 |
27 | DataT = TypeVar("DataT")
28 |
29 |
class DataForward(BaseModel, Generic[DataT]):
    """Represents data that is forwarded between tasks in the data plane.

    Must be instantiated with an explicit type argument, e.g.
    `DataForward[Tensor]()`; the validator raises ValueError otherwise.
    Sidecar ranks are filled in later by the system (`init=False`).
    """

    # This ID identifies `DataForward` objects and ties them together in task input/outputs.
    id: str = Field(default_factory=lambda: uuid.uuid4().hex)

    # The data type automatically parsed out of the generic type argument.
    data_type: ForwardableType = Field(init=False, default=ForwardableType.TENSOR)

    # Producer (source) sidecar ranks.
    src_sidecar_ranks: list[int] | None = Field(init=False, default=None)

    # Consumer (destination) sidecar ranks. This is a list of lists because the data
    # can be forwarded to multiple tasks (i.e., broadcasted) to more than one task executor.
    dst_sidecar_ranks: list[list[int]] | None = Field(init=False, default=None)

    @model_validator(mode="after")
    def _data_type(self) -> Self:
        """Validate the generic type argument.

        1. The generic type argument must be present.
        2. It should be one of the forwardable types (`ForwardableType`).

        Raises:
            ValueError: If the generic type argument is missing or is not
                a `ForwardableType` member name.
        """
        metadata = self.__class__.__pydantic_generic_metadata__
        if metadata["origin"] is None:
            raise ValueError("Generic type argument is missing.")

        args = metadata["args"]
        assert len(args) == 1, "If origin was not None, there should be exactly one argument."
        self.data_type = ForwardableType(args[0].__name__)

        return self
62 |
--------------------------------------------------------------------------------
/python/cornserve/task/registry.py:
--------------------------------------------------------------------------------
1 | """Tasks registered and known to the system."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | if TYPE_CHECKING:
8 | from cornserve.task.base import TaskInput, TaskOutput, UnitTask
9 |
10 |
class TaskRegistry:
    """Registry of unit tasks.

    Composite tasks are not registered here; they can simply be invoked
    and will be decomposed into a series of unit task invocations.
    """

    def __init__(self) -> None:
        """Initialize the task registry."""
        # Maps unit task name -> (task class, input model, output model).
        self._tasks: dict[str, tuple[type[UnitTask], type[TaskInput], type[TaskOutput]]] = {}

    def register(
        self,
        task: type[UnitTask],
        task_input: type[TaskInput],
        task_output: type[TaskOutput],
        name: str | None = None,
    ) -> None:
        """Register a task with the given ID.

        Args:
            task: The unit task class.
            task_input: The task's input model class.
            task_output: The task's output model class.
            name: Optional registry key; defaults to the task class name.

        Raises:
            ValueError: If a task with the same name is already registered.
        """
        name = name or task.__name__
        if name in self._tasks:
            raise ValueError(f"Unit task with {name=} already exists. Unit task names must be unique.")
        self._tasks[name] = (task, task_input, task_output)

    def get(self, name: str) -> tuple[type[UnitTask], type[TaskInput], type[TaskOutput]]:
        """Get a task by its name.

        Raises:
            KeyError: If no task with the given name is registered.
        """
        # Lazy import builtin tasks to avoid circular import issues
        import cornserve.task.builtins  # noqa: F401

        if name not in self._tasks:
            raise KeyError(f"Unit task with {name=} not found")
        return self._tasks[name]

    def __contains__(self, name: str) -> bool:
        """Check if a task is registered.

        Fix: previously this skipped the lazy builtin-task import that
        `get` performs, so `name in registry` could return False for a
        builtin task that `get` would have found. Mirror `get` here.
        """
        # Lazy import builtin tasks to avoid circular import issues
        import cornserve.task.builtins  # noqa: F401

        return name in self._tasks


TASK_REGISTRY = TaskRegistry()
50 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/__init__.py:
--------------------------------------------------------------------------------
1 | """Built-in task executors in Cornserve.
2 |
3 | - Eric: A server that embeds modality data so that they can be fed into LLMs.
4 | """
5 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/descriptor/__init__.py:
--------------------------------------------------------------------------------
1 | """Task Execution Descriptor.
2 |
3 | The `TaskExecutionDescriptor` class describes *how* to execute a `Task`.
4 | The same `Task` may have multiple compatible `TaskExecutionDescriptor`s.
5 | For instance, the builtin `LLMTask` can be executed with a monolithic
6 | vLLM instance, but can also be executed with prefill-decode disaggregation.
7 |
8 | A descriptor is compatible with a `Task` when it inherits from the base
9 | descriptor class annotated in the `Task` class's `execution_descriptor` field.
10 |
11 | The descriptor exposes the following information:
12 | - Dominant resource type: GPU, CPU, memory, disk, etc.
13 | This is not implemented yet; all task executors consume GPU resources.
14 | - Chunking and pipelining semantics: Information about what kind of chunking
15 | and pipelining is supported by the task executor for its input and output.
16 | - Launch information: Information about how to launch the task executor.
17 | - Request and response transformation: How to transform TaskInput to the actual
18 | Task Executor request object, and the Task Executor response object to the
19 | TaskOutput object.
20 | """
21 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/descriptor/base.py:
--------------------------------------------------------------------------------
1 | """Base task execution descriptor class."""
2 |
3 | from __future__ import annotations
4 |
5 | from abc import ABC, abstractmethod
6 | from typing import Any, Generic, TypeVar
7 |
8 | from pydantic import BaseModel
9 |
10 | from cornserve import constants
11 | from cornserve.services.resource_manager.resource import GPU
12 | from cornserve.task.base import TaskInput, TaskOutput, UnitTask
13 |
# Type parameters binding a descriptor to its task and I/O model types.
TaskT = TypeVar("TaskT", bound=UnitTask)
InputT = TypeVar("InputT", bound=TaskInput)
OutputT = TypeVar("OutputT", bound=TaskOutput)


class TaskExecutionDescriptor(BaseModel, ABC, Generic[TaskT, InputT, OutputT]):
    """Base class for task execution descriptors.

    A descriptor describes *how* to execute a unit task: which container
    to launch, with what arguments, and how to translate between the task
    API models and the executor's request/response payloads.

    Attributes:
        task: The task to be executed.
    """

    task: TaskT

    @abstractmethod
    def create_executor_name(self) -> str:
        """Create a name for the task executor."""

    @abstractmethod
    def get_container_image(self) -> str:
        """Get the container image name for the task executor."""

    @abstractmethod
    def get_container_args(self, gpus: list[GPU], port: int) -> list[str]:
        """Get the container command for the task executor.

        Args:
            gpus: The GPUs allocated to this executor.
            port: The port the executor should serve on.
        """

    def get_container_volumes(self) -> list[tuple[str, str, str]]:
        """Get the container volumes for the task manager.

        Default mounts: the HuggingFace cache and /dev/shm (used for the
        sidecar shared-memory buffer).

        Returns:
            A list of tuples: name, host path, container path.
        """
        return [
            ("hf-cache", constants.VOLUME_HF_CACHE, "/root/.cache/huggingface"),
            ("shm", constants.VOLUME_SHM, "/dev/shm"),
        ]

    @abstractmethod
    def get_api_url(self, base: str) -> str:
        """Get the task executor's base URL for API calls.

        Args:
            base: The base URL of the task executor.
        """

    @abstractmethod
    def to_request(self, task_input: InputT, task_output: OutputT) -> dict[str, Any]:
        """Convert TaskInput to a request object for the task executor.

        The task output object is needed because this specific task executor may
        have to forward data to the next task executor, and for that, we need to
        know the destination sidecar ranks annotated in the task output.
        """

    @abstractmethod
    def from_response(self, task_output: OutputT, response: dict[str, Any]) -> OutputT:
        """Convert the task executor response to TaskOutput.

        In general, the `task_output` object will be deep-copied and concrete values
        will be filled in from the response.
        """
75 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/descriptor/builtins/__init__.py:
--------------------------------------------------------------------------------
1 | """Built-in task execution descriptors."""
2 |
3 | from cornserve.task_executors.descriptor.builtins import encoder, llm
4 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/descriptor/builtins/encoder.py:
--------------------------------------------------------------------------------
1 | """Built-in task execution descriptor for Encoder tasks."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Any
6 |
7 | from cornserve import constants
8 | from cornserve.services.resource_manager.resource import GPU
9 | from cornserve.task.builtins.encoder import EncoderInput, EncoderOutput, EncoderTask
10 | from cornserve.task_executors.descriptor.base import TaskExecutionDescriptor
11 | from cornserve.task_executors.descriptor.registry import DESCRIPTOR_REGISTRY
12 | from cornserve.task_executors.eric.api import EmbeddingData, EmbeddingRequest, EmbeddingResponse, Modality, Status
13 |
14 |
class EricDescriptor(TaskExecutionDescriptor[EncoderTask, EncoderInput, EncoderOutput]):
    """Task execution descriptor for Encoder tasks.

    Handles launching Eric (the multimodal encoder) and translating between
    the external task API types and Eric's internal request/response types.
    """

    def create_executor_name(self) -> str:
        """Create a name for the task executor, e.g. "eric-image-<model>"."""
        model_part = self.task.model_id.split("/")[-1].lower()
        return "-".join(["eric", self.task.modality, model_part]).lower()

    def get_container_image(self) -> str:
        """Get the container image name for the task executor."""
        return constants.CONTAINER_IMAGE_ERIC

    def get_container_args(self, gpus: list[GPU], port: int) -> list[str]:
        """Get the container command for the task executor."""
        args = ["--model.id", self.task.model_id]
        args += ["--model.tp-size", str(len(gpus))]
        args += ["--model.modality", self.task.modality.value.upper()]
        args += ["--server.port", str(port)]
        args.append("--sidecar.ranks")
        args.extend(str(gpu.global_rank) for gpu in gpus)
        return args

    def get_api_url(self, base: str) -> str:
        """Get the task executor's base URL for API calls."""
        return f"{base}/embeddings"

    def to_request(self, task_input: EncoderInput, task_output: EncoderOutput) -> dict[str, Any]:
        """Convert TaskInput to a request object for the task executor."""
        items: list[EmbeddingData] = []
        for url, forward in zip(task_input.data_urls, task_output.embeddings, strict=True):
            if forward.dst_sidecar_ranks is None:
                raise ValueError("Destination sidecar ranks must be specified for each forward.")
            items.append(
                EmbeddingData(
                    id=forward.id,
                    modality=Modality(self.task.modality.value),
                    url=url,
                    receiver_sidecar_ranks=forward.dst_sidecar_ranks,
                )
            )
        return EmbeddingRequest(data=items).model_dump()

    def from_response(self, task_output: EncoderOutput, response: dict[str, Any]) -> EncoderOutput:
        """Convert the task executor response to TaskOutput."""
        resp = EmbeddingResponse.model_validate(response)
        if resp.status != Status.SUCCESS:
            raise RuntimeError(f"Error in encoder task: {resp.error_message}")
        return EncoderOutput(embeddings=task_output.embeddings)
78 |
79 |
80 | DESCRIPTOR_REGISTRY.register(EncoderTask, EricDescriptor, default=True)
81 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/descriptor/builtins/llm.py:
--------------------------------------------------------------------------------
1 | """Built-in task execution descriptor for LLM tasks."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Any
6 |
7 | from cornserve import constants
8 | from cornserve.services.resource_manager.resource import GPU
9 | from cornserve.task.builtins.llm import LLMBaseTask, LLMForwardOutput, LLMInput, LLMOutput, LLMOutputBase
10 | from cornserve.task_executors.descriptor.base import TaskExecutionDescriptor
11 | from cornserve.task_executors.descriptor.registry import DESCRIPTOR_REGISTRY
12 |
13 |
class VLLMDescriptor(TaskExecutionDescriptor[LLMBaseTask, LLMInput, LLMOutputBase]):
    """Task execution descriptor for LLM tasks.

    This descriptor handles launching vLLM servers for LLM tasks and converting
    between the external task API types and vLLM's OpenAI-compatible
    Chat Completion API.
    """

    def create_executor_name(self) -> str:
        """Create a name for the task executor, e.g. "vllm-<model-name>"."""
        return "-".join(["vllm", self.task.model_id.split("/")[-1]]).lower()

    def get_container_image(self) -> str:
        """Get the container image name for the task executor."""
        return constants.CONTAINER_IMAGE_VLLM

    def get_container_args(self, gpus: list[GPU], port: int) -> list[str]:
        """Get the container command for the task executor.

        Args:
            gpus: GPUs assigned to this executor; their count sets the tensor
                parallel degree and their global ranks are passed to the sidecar.
            port: Port the vLLM server should listen on.
        """
        # fmt: off
        cmd = [
            self.task.model_id,
            "--tensor-parallel-size", str(len(gpus)),
            "--port", str(port),
            "--limit-mm-per-prompt", "image=5", # TODO: Is this still needed?
            "--cornserve-sidecar-ranks", *[str(gpu.global_rank) for gpu in gpus],
        ]
        # fmt: on
        return cmd

    def get_api_url(self, base: str) -> str:
        """Get the task executor's base URL for API calls (OpenAI Chat Completion route)."""
        return f"{base}/v1/chat/completions"

    def to_request(self, task_input: LLMInput, task_output: LLMOutputBase) -> dict[str, Any]:
        """Convert TaskInput to a request object for the task executor.

        Builds an OpenAI Chat Completion request: the text prompt followed by one
        data URI per multimodal item. Each data URI embeds the forward ID and the
        source URL (presumably resolved by vLLM against the sidecar — confirm).
        """
        # XXX: `DataForward[str]` not supported yet.
        # Compatibility with OpenAI Chat Completion API is kept.
        content: list[dict[str, Any]] = [dict(type="text", text=task_input.prompt)]
        for (modality, data_url), forward in zip(task_input.multimodal_data, task_input.embeddings, strict=True):
            data_uri = f"data:{modality}/uuid;data_id={forward.id};url={data_url},"
            content.append({"type": f"{modality}_url", f"{modality}_url": {"url": data_uri}})

        request = dict(
            model=self.task.model_id,
            messages=[dict(role="user", content=content)],
            # NOTE(review): hard-coded completion token budget; consider making configurable.
            max_completion_tokens=512,
        )

        return request

    def from_response(self, task_output: LLMOutputBase, response: dict[str, Any]) -> LLMOutputBase:
        """Convert the task executor response to TaskOutput."""
        if isinstance(task_output, LLMOutput):
            return LLMOutput(response=response["choices"][0]["message"]["content"])
        if isinstance(task_output, LLMForwardOutput):
            # NOTE(review): the server response body is intentionally unused here;
            # the response value is carried over from the task output template.
            return LLMForwardOutput(
                response=task_output.response,
            )
        raise ValueError(f"Unexpected task output type: {type(task_output)}")
72 |
73 |
74 | DESCRIPTOR_REGISTRY.register(LLMBaseTask, VLLMDescriptor, default=True)
75 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/descriptor/registry.py:
--------------------------------------------------------------------------------
1 | """Task execution descriptor registry.
2 |
3 | Task execution descriptor classes register themselves to the registry
4 | specifying a task they can execute. Exactly one descriptor class per
5 | task is marked as the default descriptor class. The registry is used
6 | to look up the descriptor class for a task when executing it.
7 |
8 | `DESCRIPTOR_REGISTRY` is a singleton instance of the registry.
9 | """
10 |
11 | from __future__ import annotations
12 |
13 | from collections import defaultdict
14 | from typing import TYPE_CHECKING
15 |
16 | if TYPE_CHECKING:
17 | from cornserve.task.base import UnitTask
18 | from cornserve.task_executors.descriptor.base import TaskExecutionDescriptor
19 |
# Sentinel name for the default descriptor slot.
# NOTE(review): not referenced in this module's visible code — confirm usage elsewhere.
DEFAULT = "__default_descriptor__"
21 |
22 |
class TaskExecutionDescriptorRegistry:
    """Registry for task execution descriptors."""

    def __init__(self) -> None:
        """Initialize empty lookup tables."""
        # Task class -> {descriptor name -> descriptor class}.
        self.registry: dict[type[UnitTask], dict[str, type[TaskExecutionDescriptor]]] = defaultdict(dict)
        # Task class -> its default descriptor class.
        self.default_registry: dict[type[UnitTask], type[TaskExecutionDescriptor]] = {}

    def register(
        self,
        task: type[UnitTask],
        descriptor: type[TaskExecutionDescriptor],
        name: str | None = None,
        default: bool = False,
    ) -> None:
        """Register a task execution descriptor.

        Args:
            task: The task class to register the descriptor for.
            descriptor: The task execution descriptor class.
            name: The name of the descriptor. If None, use the class name.
            default: Whether this is the default descriptor for the task.

        Raises:
            ValueError: The name (or the default slot) is already taken for this task.
        """
        descriptor_name = descriptor.__name__ if name is None else name

        bucket = self.registry[task]
        if descriptor_name in bucket:
            raise ValueError(f"Descriptor {descriptor_name} already registered for task {task.__name__}")
        bucket[descriptor_name] = descriptor

        if not default:
            return
        if task in self.default_registry:
            raise ValueError(f"Default descriptor already registered for task {task.__name__}")
        self.default_registry[task] = descriptor

    def get(self, task: type[UnitTask], name: str | None = None) -> type[TaskExecutionDescriptor]:
        """Get the task execution descriptor for a task.

        Args:
            task: The task class to get the descriptor for.
            name: The name of the descriptor. If None, use the default descriptor.

        Raises:
            ValueError: No matching descriptor is registered.
        """
        # Lazily import built-in descriptors to avoid circular imports
        import cornserve.task_executors.descriptor.builtins  # noqa: F401

        if task not in self.registry:
            raise ValueError(f"No descriptors registered for task {task.__name__}")

        if name is not None:
            if name not in self.registry[task]:
                raise ValueError(f"Descriptor {name} not registered for task {task.__name__}")
            return self.registry[task][name]

        if task not in self.default_registry:
            raise ValueError(f"No default descriptor registered for task {task.__name__}")
        return self.default_registry[task]
81 |
82 |
# Module-level singleton; import this instead of instantiating the registry.
DESCRIPTOR_REGISTRY = TaskExecutionDescriptorRegistry()
84 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/__init__.py:
--------------------------------------------------------------------------------
1 | """Eric embeds modality data so that they can be fed into LLMs."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/api.py:
--------------------------------------------------------------------------------
1 | """API schema for Eric."""
2 |
3 | from __future__ import annotations
4 |
5 | import enum
6 |
7 | from pydantic import BaseModel
8 |
# Type alias for modality data identifiers (unique within a request).
ID = str
10 |
11 |
class Modality(enum.StrEnum):
    """Modality of the data to be embedded.

    As a `StrEnum`, members compare equal to their lowercase string values
    (e.g. `Modality.IMAGE == "image"`).
    """

    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
18 |
19 |
class EmbeddingData(BaseModel):
    """The data to be embedded.

    Attributes:
        id: Modality data ID unique within the request.
        modality: The modality of the data.
        url: The URL where the data can be downloaded from.
        receiver_sidecar_ranks: List of sidecar ranks to send the embeddings to.
            If omitted, tensors will not be sent to any sidecar.
            NOTE(review): nested list — presumably one inner rank list per
            destination executor; confirm against the sidecar API.
    """

    id: ID
    modality: Modality
    url: str
    receiver_sidecar_ranks: list[list[int]] | None = None
35 |
36 |
class EmbeddingRequest(BaseModel):
    """Request to embed data.

    Attributes:
        data: List of data items to be embedded in this request.
    """

    data: list[EmbeddingData]
45 |
46 |
class Status(enum.IntEnum):
    """Status of various operations.

    Attributes:
        SUCCESS: The operation completed successfully (0).
        ERROR: The operation failed (1).
    """

    SUCCESS = 0
    ERROR = 1
52 |
53 |
class EmbeddingResponse(BaseModel):
    """Response to an embedding request.

    Note that the model carries only the operation status; it has no embedding
    field (embeddings are presumably delivered via sidecars — see
    `EmbeddingData.receiver_sidecar_ranks`).

    Attributes:
        status: Whether the embedding operation succeeded.
        error_message: Human-readable error details when `status` is ERROR.
    """

    status: Status
    error_message: str | None = None
59 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/engine/__init__.py:
--------------------------------------------------------------------------------
1 | """The engine runs in a separate process and handles model inference."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/entrypoint.py:
--------------------------------------------------------------------------------
1 | """Spins up Eric."""
2 |
3 | from __future__ import annotations
4 |
5 | import asyncio
6 | import signal
7 |
8 | import tyro
9 | import uvicorn
10 | from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
11 | from opentelemetry.instrumentation.threading import ThreadingInstrumentor
12 |
13 | from cornserve.logging import get_logger
14 | from cornserve.task_executors.eric.config import EricConfig
15 | from cornserve.task_executors.eric.engine.client import EngineClient
16 | from cornserve.task_executors.eric.router.app import create_app
17 | from cornserve.tracing import configure_otel
18 |
19 | logger = get_logger("cornserve.task_executors.eric.entrypoint")
20 |
21 |
async def serve(eric_config: EricConfig) -> None:
    """Serve the Eric model as a FastAPI app.

    Configures OpenTelemetry, builds and instruments the FastAPI app, logs the
    available routes, then runs a uvicorn server until SIGINT/SIGTERM triggers
    a shutdown of the engine client and cancellation of the server task.

    Args:
        eric_config: Full Eric configuration (model, server, sidecar settings).
    """
    logger.info("Starting Eric with %s", eric_config)

    # Service name embeds the sidecar ranks (e.g. "eric[0,1]") so traces from
    # different Eric instances can be distinguished.
    configure_otel(f"eric{str(eric_config.sidecar.ranks).replace(' ', '')}")

    app = create_app(eric_config)

    FastAPIInstrumentor().instrument_app(app)
    ThreadingInstrumentor().instrument()

    logger.info("Available routes are:")
    for route in app.routes:
        methods = getattr(route, "methods", None)
        path = getattr(route, "path", None)

        # Skip route types (e.g. mounts) that have no methods/path.
        if methods is None or path is None:
            continue

        logger.info(
            "%s %s",
            list(methods)[0] if len(methods) == 1 else "{" + ",".join(methods) + "}",
            path,
        )

    config = uvicorn.Config(app, host=eric_config.server.host, port=eric_config.server.port)
    server = uvicorn.Server(config)

    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.serve())

    def shutdown() -> None:
        # Stop the engine first, then cancel the server task; the resulting
        # CancelledError is caught below to shut uvicorn down cleanly.
        engine_client: EngineClient = app.state.engine_client
        engine_client.shutdown()
        server_task.cancel()

    loop.add_signal_handler(signal.SIGINT, shutdown)
    loop.add_signal_handler(signal.SIGTERM, shutdown)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("Shutting down FastAPI server.")
        await server.shutdown()
66 |
67 |
# Parse EricConfig from the command line with tyro and run the server.
if __name__ == "__main__":
    asyncio.run(serve(tyro.cli(EricConfig)))
70 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/executor/__init__.py:
--------------------------------------------------------------------------------
1 | """The executor manages workers that execute inference."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/models/__init__.py:
--------------------------------------------------------------------------------
1 | """The Eric model zoo.
2 |
3 | Steps for adding a new model:
4 |
5 | 1. Create a module inside `models` named after the model type (`hf_config.model_type`).
6 | 2. Implement the model class inheriting from `models.base.EricModel`.
7 | 3. Implement a class exactly called `ModalityProcessor` in the module, inheriting from
8 | `models.base.BaseModalityProcessor`. For each supported modality, implement
9 | the corresponding method (`get_image_processor`, `get_video_processor`, etc.).
10 | 4. Add an entry in `models.registry.MODEL_REGISTRY` for the model type.
11 | """
12 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/models/base.py:
--------------------------------------------------------------------------------
1 | """Base class for all models in Eric."""
2 |
3 | from __future__ import annotations
4 |
5 | from abc import ABC, abstractmethod
6 | from typing import Callable
7 |
8 | import torch
9 | import torch.nn as nn
10 | import numpy.typing as npt
11 |
12 | from cornserve.task_executors.eric.schema import Modality
13 |
14 |
class EricModel(nn.Module, ABC):
    """Base class for all models in Eric."""

    @abstractmethod
    def forward(self, modality: Modality, batch: dict[str, list[torch.Tensor]]) -> list[torch.Tensor]:
        """Forward pass for the model.

        Args:
            modality: The modality of the data.
            batch: The input data. Presumably keyed by processor output name
                with one tensor per batch item — confirm in concrete models.

        Returns:
            A list of output tensors.
        """

    @property
    @abstractmethod
    def dtype(self) -> torch.dtype:
        """Return the data type of the model's embeddings."""

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """Return the device where inputs should be in."""

    @property
    @abstractmethod
    def chunk_shape(self) -> tuple[int, ...]:
        """Return the shape of the chunks to be sent to the sidecar."""
44 |
45 |
class BaseModalityProcessor:
    """Base class for modality processors.

    Model definition modules provide a `ModalityProcessor` class inheriting
    from this one. Subclasses override `get_image_processor`,
    `get_video_processor`, etc. to return a processor for each supported
    modality. A processor is a callable that takes the input modality data as
    a Numpy array and returns the processed data as a dictionary of Numpy
    arrays.
    """

    def __init__(self, model_id: str) -> None:
        """Initialize the processor with the model ID."""
        self.model_id = model_id

    def get_image_processor(self) -> Callable | None:
        """Get the image processor, or None if images are unsupported.

        The callable should take a single image numpy array.
        """
        return None

    def get_audio_processor(self) -> Callable | None:
        """Get the audio processor, or None if audio is unsupported.

        The callable should take a tuple of (audio data numpy array, sample rate).
        """
        return None

    def get_video_processor(self) -> Callable | None:
        """Get the video processor, or None if videos are unsupported.

        The callable should take a single video numpy array.
        """
        return None

    def process(self, modality: Modality, data: npt.NDArray) -> dict[str, npt.NDArray]:
        """Process the input data for the given modality.

        Raises:
            ValueError: The modality is unsupported or its processor is missing.
        """
        if modality == Modality.IMAGE:
            processor = self.get_image_processor()
            if processor is None:
                raise ValueError("Image processor not available.")
        elif modality == Modality.AUDIO:
            processor = self.get_audio_processor()
            if processor is None:
                raise ValueError("Audio processor not available.")
        elif modality == Modality.VIDEO:
            processor = self.get_video_processor()
            if processor is None:
                raise ValueError("Video processor not available.")
        else:
            raise ValueError(f"Unsupported modality: {modality}")
        return processor(data)
102 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | """A collection of reusable layers."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/models/layers/activations.py:
--------------------------------------------------------------------------------
1 | """Activation functions."""
2 |
3 | from __future__ import annotations
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation elementwise."""
        return x * torch.sigmoid(1.702 * x)
13 |
14 |
def get_act_fn(name: str) -> nn.Module:
    """Get an activation function by name.

    Args:
        name: Case-insensitive activation name, e.g. "gelu" or "silu".

    Raises:
        NotImplementedError: The activation name is not recognized.
    """
    key = name.lower()
    if key == "gelu":
        return nn.GELU()
    if key == "quick_gelu":
        return QuickGELU()
    if key == "gelu_pytorch_tanh":
        return nn.GELU(approximate="tanh")
    if key == "relu":
        return nn.ReLU()
    if key == "silu":
        return nn.SiLU()
    raise NotImplementedError(f"Activation function {name!r} is not implemented.")
30 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/models/layers/attention.py:
--------------------------------------------------------------------------------
1 | """Generic attention implementation for ViTs."""
2 |
3 | from __future__ import annotations
4 |
5 | import torch
6 | import torch.nn as nn
7 | from xformers.ops import memory_efficient_attention_forward # type: ignore
8 |
9 |
class Attention(nn.Module):
    """Full attention implementation for ViTs.

    Wraps xformers' memory-efficient attention and supports MQA/GQA by
    repeating key/value heads to match the number of query heads.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
    ) -> None:
        """Initialize attention.

        Args:
            num_heads: Number of query heads.
            head_size: Dimension of each attention head.
            scale: Softmax scale passed to the attention kernel.
            num_kv_heads: Number of key/value heads; defaults to `num_heads`
                (standard multi-head attention). Must divide `num_heads`.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        # How many query heads share each KV head (MQA/GQA grouping factor).
        self.num_queries_per_kv, rem = divmod(self.num_heads, self.num_kv_heads)
        assert rem == 0, f"{num_heads=} must be divisible by {num_kv_heads=}"

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Run forward.

        Args:
            query, key, value: [batch_size, seq_len, hidden_size]

        Returns:
            Attention output of shape [batch_size, q_len, num_heads * head_size].
        """
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        # Split the hidden dimension into heads: [bsz, len, heads, head_size].
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        out = memory_efficient_attention_forward(query, key, value, scale=self.scale)

        # Merge heads back into the hidden dimension.
        return out.reshape(bsz, q_len, -1)
54 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/models/layers/vit.py:
--------------------------------------------------------------------------------
1 | """Vision Transformer."""
2 |
3 | from __future__ import annotations
4 |
5 | import torch.nn as nn
6 | from transformers.models.siglip import SiglipVisionConfig
7 |
8 | from cornserve.task_executors.eric.models.layers.siglip import SiglipVisionModel
9 |
10 |
11 | def _get_num_hidden_layers(hf_config) -> int:
12 | """Determine the number of hidden layers to initialize up to in the
13 | visual encoder.
14 |
15 | Args:
16 | hf_config: Model config with vision feature layer(s).
17 | """
18 | feature_layers = hf_config.vision_feature_layer
19 | num_hidden_layers = hf_config.vision_config.num_hidden_layers
20 | # If we have one feature layer, initialize up to that layer
21 | if isinstance(feature_layers, int):
22 | return _get_layer_index(feature_layers, num_hidden_layers)
23 | # If we have multiple feature layers, initialize up to the deepest one
24 | elif isinstance(feature_layers, (list, tuple)):
25 | return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
26 | raise TypeError(f"vision_layer_feature type: {type(feature_layers)} is not supported")
27 |
28 |
29 | def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
30 | """Given a signed vision feature layer, get the number of hidden layers
31 | needed to leverage it.
32 |
33 | Args:
34 | feature_layer_index: Index of a required layer in the visual encoder.
35 | num_hidden_layers: The total number of hidden layers in the visual
36 | encoder.
37 | """
38 | if feature_layer_index < 0:
39 | return num_hidden_layers + feature_layer_index + 1
40 | return feature_layer_index
41 |
42 |
def init_vision_tower_for_llava(
    hf_config,
    *,
    require_post_norm: bool | None = None,
) -> nn.Module:
    """Initialize the vision tower for Llava-family models.

    Builds the encoder only up to the deepest required feature layer.
    Currently only SigLIP vision backbones are supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    depth = _get_num_hidden_layers(hf_config)

    if not isinstance(vision_config, SiglipVisionConfig):
        raise NotImplementedError(f"Unsupported vision config: {type(vision_config)}")

    return SiglipVisionModel(
        vision_config,
        num_hidden_layers_override=depth,
        require_post_norm=require_post_norm,
    )
63 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/router/__init__.py:
--------------------------------------------------------------------------------
1 | """The router receives embedding requests and invokes the engine."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/router/app.py:
--------------------------------------------------------------------------------
1 | """Eric FastAPI app definition."""
2 |
3 | from __future__ import annotations
4 |
5 | import uuid
6 |
7 | from fastapi import APIRouter, FastAPI, Request, Response, status
8 | from opentelemetry import trace
9 |
10 | from cornserve.logging import get_logger
11 | from cornserve.task_executors.eric.api import EmbeddingRequest, EmbeddingResponse, Modality, Status
12 | from cornserve.task_executors.eric.config import EricConfig
13 | from cornserve.task_executors.eric.engine.client import EngineClient
14 | from cornserve.task_executors.eric.models.registry import MODEL_REGISTRY
15 | from cornserve.task_executors.eric.router.processor import Processor
16 |
# Module-level FastAPI router, logger, and tracer shared by all handlers below.
router = APIRouter()
logger = get_logger(__name__)
tracer = trace.get_tracer(__name__)
20 |
21 |
@router.get("/health")
async def health_check(request: Request) -> Response:
    """Checks whether the router and the engine are alive.

    Responds 200 when the engine client reports healthy, 503 otherwise.
    """
    engine_client: EngineClient = request.app.state.engine_client
    healthy = engine_client.health_check()
    code = status.HTTP_200_OK if healthy else status.HTTP_503_SERVICE_UNAVAILABLE
    return Response(status_code=code)
31 |
32 |
@router.get("/info")
async def info(raw_request: Request) -> EricConfig:
    """Returns Eric's configuration information.

    Serializes the full `EricConfig` stored on the app state back to the caller.
    """
    return raw_request.app.state.config
37 |
38 |
@router.get("/modalities")
async def modalities(raw_request: Request) -> list[Modality]:
    """Return the list of modalities supported by this model."""
    config: EricConfig = raw_request.app.state.config
    model_entry = MODEL_REGISTRY[config.model.hf_config.model_type]
    return list(model_entry.modality.keys())
44 |
45 |
@router.post("/embeddings")
async def embeddings(
    request: EmbeddingRequest,
    raw_request: Request,
    raw_response: Response,
) -> EmbeddingResponse:
    """Handler for embedding requests."""
    # Record each data item's source URL on the current trace span.
    span = trace.get_current_span()
    for item in request.data:
        span.set_attribute(f"eric.embeddings.data.{item.id}.url", item.url)

    processor: Processor = raw_request.app.state.processor
    engine_client: EngineClient = raw_request.app.state.engine_client

    # Load data from URLs and apply processing
    processed = await processor.process(request.data)

    # Send to engine process (embedding + transmission via Tensor Sidecar)
    response = await engine_client.embed(uuid.uuid4().hex, processed)

    if response.status == Status.SUCCESS:
        raw_response.status_code = status.HTTP_200_OK
    else:
        if response.status != Status.ERROR:
            logger.error("Unexpected status: %s", response.status)
        raw_response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
    return response
77 |
78 |
def init_app_state(app: FastAPI, config: EricConfig) -> None:
    """Initialize the app state with the configuration and engine client.

    Stores the config, a data `Processor`, and an `EngineClient` on
    `app.state` so the request handlers above can reach them.
    """
    app.state.config = config
    app.state.processor = Processor(config.model.id, config.modality)
    app.state.engine_client = EngineClient(config)
84 |
85 |
def create_app(config: EricConfig) -> FastAPI:
    """Build the Eric FastAPI application with routes and state wired up."""
    application = FastAPI()
    application.include_router(router)
    init_app_state(application, config)
    return application
92 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utility functions and classes for Eric."""
2 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/utils/distributed.py:
--------------------------------------------------------------------------------
1 | """Utilities for distributed inference."""
2 |
3 | from __future__ import annotations
4 |
5 | from collections.abc import Sequence
6 |
7 | import torch
8 |
9 |
def divide(numerator, denominator):
    """Ensure numerator is divisible by the denominator and return the quotient."""
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator
15 |
16 |
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor into equal parts along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous.

    Returns:
        A sequence of tensors.
    """
    last_dim = tensor.dim() - 1
    chunk_size = divide(tensor.size()[last_dim], num_partitions)
    chunks = torch.split(tensor, chunk_size, dim=last_dim)
    # torch.split returns non-contiguous views; copy them only when requested.
    if contiguous_split_chunks:
        return tuple(piece.contiguous() for piece in chunks)
    return chunks
42 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/utils/network.py:
--------------------------------------------------------------------------------
1 | """Utilities for networking."""
2 |
3 | from __future__ import annotations
4 |
5 | import socket
6 |
7 |
def get_open_port() -> int:
    """Get an open port on the local machine."""
    # Prefer IPv4; bind to port 0 to let the OS pick a free port.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]
    except OSError:
        pass
    # Fall back to IPv6; let failures here propagate to the caller.
    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
20 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/utils/package.py:
--------------------------------------------------------------------------------
1 | """Utilities for handling optional package imports."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Any
6 |
7 |
8 | class PlaceholderModule:
9 | """A placeholder class that replaces optional packages when they are not installed."""
10 |
11 | def __init__(self, name: str, optional_dependency_name: str) -> None:
12 | """Instantiate a placeholder module with the package's name."""
13 | self.__name = name
14 | self.__optional_dependency_name = optional_dependency_name
15 |
16 | def __getattr__(self, name: str) -> Any:
17 | """Raise an error when any attribute is accessed."""
18 | if name == "__name":
19 | return self.__name
20 |
21 | if name == "__optional_dependency_name":
22 | return self.__optional_dependency_name
23 |
24 | raise RuntimeError(
25 | f"Optional package '{self.__name}' is not installed. Please install "
26 | f"optional dependencies with `pip install cornserve[{self.__optional_dependency_name}]`."
27 | )
28 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/utils/process.py:
--------------------------------------------------------------------------------
1 | """Utilities for process management."""
2 |
3 | from __future__ import annotations
4 |
5 | import contextlib
6 | import os
7 | import signal
8 |
9 | import psutil
10 |
11 |
def kill_process_tree(pid: int | None) -> None:
    """Kill all descendant processes of the given pid by sending SIGKILL.

    Children are killed before the parent. Processes that have already
    exited are silently skipped.

    Args:
        pid: Process ID of the parent process.
    """
    # None might be passed in if mp.Process hasn't been spawned yet
    if pid is None:
        return

    try:
        parent = psutil.Process(pid)
        # Snapshot all descendants. The parent can exit between the lookup
        # above and this call, so that race is treated as "nothing to kill".
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        return

    # Send SIGKILL to all children first
    for child in children:
        # A child may have exited on its own already.
        with contextlib.suppress(ProcessLookupError):
            os.kill(child.pid, signal.SIGKILL)

    # Finally kill the parent
    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGKILL)
38 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/utils/serde.py:
--------------------------------------------------------------------------------
1 | """Utilities for serializing and deserializing objects."""
2 |
3 | from __future__ import annotations
4 |
5 | import pickle
6 | from typing import Any
7 |
8 | import numpy as np
9 | import torch
10 | from msgspec import msgpack
11 |
# Msgpack extension type codes shared by `enc_hook` and `ext_hook` below.
CUSTOM_TYPE_NUMPY = 1  # numpy.ndarray, pickled
CUSTOM_TYPE_TORCH = 2  # torch.Tensor, converted to a numpy array then pickled
CUSTOM_TYPE_PICKLE = 3  # anything else, pickled as-is
15 |
16 |
class MsgpackEncoder:
    """Msgpack encoder that implements custom serialization.

    Objects msgpack cannot natively encode (numpy arrays, torch tensors,
    arbitrary Python objects) are handled by `enc_hook` in this module.
    """

    def __init__(self) -> None:
        """Initialize the encoder."""
        self.encoder = msgpack.Encoder(enc_hook=enc_hook)

    def encode(self, obj: Any) -> bytes:
        """Encode an object to bytes."""
        return self.encoder.encode(obj)

    def encode_into(self, obj: Any, buffer: bytearray) -> None:
        """Encode an object into a caller-provided buffer.

        In-place variant of `encode` that avoids allocating a fresh bytes object.
        """
        self.encoder.encode_into(obj, buffer)
31 |
32 |
class MsgpackDecoder:
    """Msgpack decoder that implements custom deserialization.

    Extension types produced by `MsgpackEncoder`/`enc_hook` are decoded
    back by `ext_hook` in this module.
    """

    def __init__(self, ty: type | None = None) -> None:
        """Initialize the decoder.

        Args:
            ty: Optional type to decode into. When None, values are decoded
                into plain Python objects.
        """
        self.decoder = msgpack.Decoder(type=ty, ext_hook=ext_hook)

    def decode(self, data: bytes) -> Any:
        """Decode bytes to an object."""
        return self.decoder.decode(data)
43 |
44 |
def enc_hook(obj: Any) -> Any:
    """Use pickle to serialize Numpy arrays, torch tensors, and other objects.

    Each supported kind gets its own msgpack extension code so `ext_hook`
    can reverse the transformation.

    https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103
    """
    if isinstance(obj, np.ndarray):
        return msgpack.Ext(CUSTOM_TYPE_NUMPY, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
    if isinstance(obj, torch.Tensor):
        # Torch tensors are serialized as Numpy arrays. Detach and move to CPU
        # first so tensors on GPU or with autograd history also serialize
        # (plain `.numpy()` raises for those).
        return msgpack.Ext(
            CUSTOM_TYPE_TORCH,
            pickle.dumps(obj.detach().cpu().numpy(), protocol=pickle.HIGHEST_PROTOCOL),
        )

    # Fallback: anything else goes through pickle unchanged.
    return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
60 |
61 |
def ext_hook(code: int, data: memoryview) -> Any:
    """Deserialize msgpack extension payloads produced by `enc_hook`."""
    if code == CUSTOM_TYPE_TORCH:
        # The payload is a pickled Numpy array; rebuild the tensor from it.
        return torch.from_numpy(pickle.loads(data))
    if code in (CUSTOM_TYPE_NUMPY, CUSTOM_TYPE_PICKLE):
        # Both codes carry a plain pickle payload.
        return pickle.loads(data)

    raise ValueError(f"Unknown custom serialization code: {code}")
73 |
--------------------------------------------------------------------------------
/python/cornserve/task_executors/eric/utils/zmq.py:
--------------------------------------------------------------------------------
1 | """Utilities for creating and managing ZMQ sockets."""
2 |
3 | from __future__ import annotations
4 |
5 | import contextlib
6 | import tempfile
7 | from collections.abc import Iterator
8 | from typing import overload
9 | from uuid import uuid4
10 |
11 | import zmq
12 | import zmq.asyncio
13 |
14 | from cornserve.logging import get_logger
15 |
16 | logger = get_logger(__name__)
17 |
# IPC socket files are created under the system temporary directory.
TMP_DIR = tempfile.gettempdir()
19 |
20 |
@overload
def make_zmq_socket(
    ctx: zmq.asyncio.Context,
    path: str,
    sock_type: int,
) -> zmq.asyncio.Socket: ...


@overload
def make_zmq_socket(
    ctx: zmq.Context,
    path: str,
    sock_type: int,
) -> zmq.Socket: ...


def make_zmq_socket(
    ctx: zmq.Context | zmq.asyncio.Context,
    path: str,
    sock_type: int,
) -> zmq.Socket | zmq.asyncio.Socket:
    """Create and wire up a ZMQ socket.

    PULL sockets connect to `path`; PUSH sockets bind to it.

    Args:
        ctx: The ZMQ context. Can be either a sync or async context.
        path: Socket path prefixed with protocol.
        sock_type: Socket type, like `zmq.PULL` or `zmq.PUSH`.

    Raises:
        ValueError: If `sock_type` is neither `zmq.PULL` nor `zmq.PUSH`.
    """
    sock = ctx.socket(sock_type)

    # 0.5 * 2^30 bytes == 512 MiB of OS-level buffer.
    buffer_bytes = int(0.5 * 1024**3)

    if sock_type == zmq.PULL:
        # HWM of 0 means no limit on queued incoming messages.
        sock.setsockopt(zmq.RCVHWM, 0)
        sock.setsockopt(zmq.RCVBUF, buffer_bytes)
        sock.connect(path)
        return sock

    if sock_type == zmq.PUSH:
        # HWM of 0 means no limit on queued outgoing messages.
        sock.setsockopt(zmq.SNDHWM, 0)
        sock.setsockopt(zmq.SNDBUF, buffer_bytes)
        sock.bind(path)
        return sock

    raise ValueError(f"Unsupported socket type: {sock_type}")
65 |
66 |
def get_open_zmq_ipc_path(description: str | None = None) -> str:
    """Get an open IPC path for ZMQ sockets.

    Args:
        description: An optional string description for where the socket is used.

    Returns:
        An `ipc://` path under the temp directory, randomized with a UUID
        so concurrent callers never collide.
    """
    filename = f"{description}-{uuid4()}" if description is not None else str(uuid4())
    # Bug fix: the computed filename was previously ignored and a constant
    # string was returned, so every caller got the same (colliding) path.
    return f"ipc://{TMP_DIR}/{filename}"
75 |
76 |
@contextlib.contextmanager
def zmq_sync_socket(path: str, sock_type: int) -> Iterator[zmq.Socket]:
    """Context manager that creates and cleans up a ZMQ socket.

    The backing context is destroyed when the `with` block exits,
    even if the body raises.
    """
    context = zmq.Context(io_threads=2)
    try:
        yield make_zmq_socket(context, path, sock_type)
    finally:
        # linger=0: drop pending messages instead of blocking on teardown.
        context.destroy(linger=0)
86 |
--------------------------------------------------------------------------------
/python/cornserve/tracing.py:
--------------------------------------------------------------------------------
1 | """OpenTelemetry configuration for the cornserve services."""
2 |
3 | from __future__ import annotations
4 |
5 | from opentelemetry import trace
6 | from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
7 | from opentelemetry.sdk.resources import Resource
8 | from opentelemetry.sdk.trace import TracerProvider
9 | from opentelemetry.sdk.trace.export import BatchSpanProcessor
10 |
11 | from cornserve.constants import K8S_OTEL_GRPC_URL
12 |
13 |
def configure_otel(name: str) -> None:
    """Configure OpenTelemetry for the given service name.

    Installs a global tracer provider that batches spans and exports them
    over OTLP gRPC to the in-cluster collector endpoint.
    """
    # Tag every span emitted by this process with the service's name.
    tracer_provider = TracerProvider(resource=Resource.create({"service.name": name}))
    span_exporter = OTLPSpanExporter(endpoint=K8S_OTEL_GRPC_URL)
    tracer_provider.add_span_processor(BatchSpanProcessor(span_exporter))
    trace.set_tracer_provider(tracer_provider)
22 |
--------------------------------------------------------------------------------
/python/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "cornserve"
7 | description = "Easy, fast, and scalable multimodal agentic AI"
8 | authors = [
9 | { name = "Cornserve Team" },
10 | ]
11 | readme = "README.md"
12 | license = { file = "LICENSE" }
13 | classifiers = [
14 | "Environment :: GPU :: NVIDIA CUDA",
15 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
16 | "Programming Language :: Python :: 3.11",
17 | "Programming Language :: Python :: 3.12",
18 | "Programming Language :: Python :: 3.13",
19 | ]
20 | requires-python = ">=3.11"
21 | dependencies = [
22 | "grpcio-tools==1.71.0",
23 | "rich",
24 | "requests",
25 | "tyro",
26 | "kubernetes_asyncio",
27 | "httpx",
28 | "pydantic>=2.11",
29 | "opentelemetry-api",
30 | "opentelemetry-sdk",
31 | "opentelemetry-exporter-otlp-proto-grpc",
32 | "websocket-client",
33 | ]
34 | dynamic = ["version"]
35 |
36 | [project.scripts]
37 | cornserve = "cornserve.cli:main"
38 |
39 | [project.optional-dependencies]
40 | sidecar-api = [
41 | "torch>=2.5.0",
42 | "opentelemetry-instrumentation-grpc",
43 | "opentelemetry-instrumentation-threading",
44 | ]
45 | sidecar = [
46 | "torch>=2.5.0",
47 | "ucxx-cu12",
48 | "msgspec",
49 | "opentelemetry-instrumentation-grpc",
50 | ]
51 | gateway = [
52 | "fastapi",
53 | "uvicorn[standard]",
54 | "opentelemetry-instrumentation-fastapi",
55 | "opentelemetry-instrumentation-grpc",
56 | "opentelemetry-instrumentation-httpx",
57 | "websocket-client",
58 | ]
59 | resource-manager = [
60 | "opentelemetry-instrumentation-grpc",
61 | ]
62 | task-manager = []
63 | task-dispatcher = [
64 | "fastapi",
65 | "uvicorn[standard]",
66 | "opentelemetry-instrumentation-fastapi",
67 | "opentelemetry-instrumentation-httpx",
68 | "opentelemetry-instrumentation-grpc",
69 | ]
70 | audio = ["librosa"]
71 | eric-no-gpu = [
72 | "fastapi",
73 | "uvicorn[standard]",
74 | "pyzmq",
75 | "msgspec",
76 | "psutil",
77 | "torch>=2.5.0",
78 | "transformers",
79 | "huggingface_hub",
80 | "pillow",
81 | "opencv-python-headless",
82 | "einops",
83 | "cornserve[sidecar-api,audio]",
84 | "opentelemetry-instrumentation-fastapi",
85 | "opentelemetry-instrumentation-threading",
86 | ]
87 | eric = ["flash-attn", "xformers", "cornserve[eric-no-gpu]"]
88 | eric-audio = ["cornserve[eric,audio]"]
89 | dev-common = [
90 | "grpcio-tools",
91 | "pyright!=1.1.401",
92 | "ruff",
93 | "pytest",
94 | "pytest-asyncio",
95 | "pytest-dependency",
96 | "cornserve[sidecar-api,gateway,resource-manager,task-manager,task-dispatcher]",
97 | ]
98 | dev = ["cornserve[dev-common,sidecar,eric]"]
99 | dev-no-gpu = ["cornserve[dev-common,eric-no-gpu]"]
100 |
101 | [tool.setuptools.dynamic]
102 | version = { attr = "cornserve.__version__" }
103 |
104 | [tool.ruff]
105 | line-length = 120
106 |
107 | [tool.ruff.lint]
108 | pydocstyle.convention = "google"
109 | select = [
110 | "E", # pycodestyle error
111 | "F", # pyflakes
112 | "D", # pydocstyle
113 | "PL", # pylint
114 | "N", # pep8-naming
115 | "UP", # pyupgrade
116 | "B", # flake8-bugbear (detects likely bugs)
117 | "G", # flake8-logging-format (complains about logging)
118 | "SIM", # flake8-simplify (suggests code simplifications)
119 | ]
120 | exclude = [
121 | "cornserve/task_executors/eric/models/*.py",
122 | ]
123 | ignore = [
124 | "PLW0603", # Global statement
125 | "PLR2004", # Magic value
126 | "PLR0912", # Too many branches
127 | "PLR0913", # Too many arguments
128 | "PLR0915", # Too many statements
129 | "PLR0402", # `import torch.nn as nn` is fine
130 | ]
131 |
132 | [tool.ruff.lint.per-file-ignores]
133 | "**/__init__.py" = ["F401", "F403"]
134 | "cornserve/services/**/server.py" = ["N802"]
135 | "cornserve/services/**/grpc.py" = ["N802"]
136 |
137 | [tool.pyright]
138 | exclude = [
139 | "**/*_pb2.py",
140 | "**/*_pb2_grpc.py",
141 | "cornserve/task_executors/eric/models/*.py",
142 | ]
143 |
144 | [tool.pytest.ini_options]
145 | addopts = "-v"
146 | asyncio_mode = "strict"
147 | asyncio_default_fixture_loop_scope = "function"
148 |
149 | [tool.ruff.lint.isort]
150 | known-first-party = ["cornserve"]
151 |
--------------------------------------------------------------------------------
/python/scripts/lint.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Format, lint, and type-check the Python tree.
# Locally, formatting and import sorting are applied in place; inside
# GitHub Actions ($GITHUB_ACTION set), they run in check-only mode.

set -ev

# Quoted to survive paths containing spaces.
echo "${BASH_SOURCE[0]}"

# Run from the directory above this script regardless of invocation location.
cd "$(dirname "${BASH_SOURCE[0]}")/.."

if [[ -z $GITHUB_ACTION ]]; then
  ruff format --target-version py311 cornserve tests
  ruff check --fix-only --select I cornserve tests
else
  ruff format --target-version py311 --check cornserve tests
  ruff check --select I cornserve tests
fi

ruff check --target-version py311 cornserve
pyright cornserve
19 |
--------------------------------------------------------------------------------
/python/tests/services/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/python/tests/services/__init__.py
--------------------------------------------------------------------------------
/python/tests/services/sidecar/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/python/tests/services/sidecar/__init__.py
--------------------------------------------------------------------------------
/python/tests/services/sidecar/utils.py:
--------------------------------------------------------------------------------
import asyncio
import contextlib
import multiprocessing
import os
import signal
import time

import grpc
import pytest
import torch

from cornserve.services.pb import sidecar_pb2, sidecar_pb2_grpc
from cornserve.sidecar.constants import grpc_url_from_rank
13 |
14 |
def run_server(rank: int, world_size: int, local_peer_ranks: list[int], shm_size: int) -> None:
    """Sidecar server entrypoint that will run in a subprocess.

    Args:
        rank: This sidecar's global rank.
        world_size: Total number of sidecar servers.
        local_peer_ranks: Ranks co-located with this sidecar.
        shm_size: Shared memory size in bytes.
    """
    mock_device()

    # Set environment variables before importing the server module below,
    # since its configuration is presumably read at import time — TODO confirm.
    os.environ["SIDECAR_RANK"] = str(rank)
    os.environ["SIDECAR_WORLD_SIZE"] = str(world_size)
    os.environ["SIDECAR_LOCAL_PEER_RANKS"] = ",".join(map(str, local_peer_ranks))
    os.environ["SIDECAR_SHM_SIZE"] = str(shm_size)
    os.environ["SIDECAR_IS_LOCAL"] = "true"

    from cornserve.services.sidecar.server import cleanup_coroutines, main

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        # SIGINT from `terminate_processes` is the expected shutdown path.
        pass
    finally:
        # Drain async generators and registered cleanup tasks before closing.
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.run_until_complete(asyncio.gather(*cleanup_coroutines))
        loop.close()
39 |
40 |
def device_from_rank(rank: int) -> torch.device:
    """Get the device for a given rank."""
    if not torch.cuda.is_available():
        # No GPUs visible: everything lands on the CPU.
        return torch.device("cpu")
    # Round-robin ranks across the visible CUDA devices.
    return torch.device(f"cuda:{rank % torch.cuda.device_count()}")
46 |
47 |
def mock_device() -> None:
    """Patch sidecar device selection with the test-friendly `device_from_rank` above."""
    # MonkeyPatch is constructed directly (not via a pytest fixture) because
    # this runs inside a spawned server subprocess; the patch is never undone here.
    mocker = pytest.MonkeyPatch()
    mocker.setattr(
        "cornserve.sidecar.utils.device_from_rank",
        device_from_rank,
    )
54 |
55 |
def start_sidecar_servers(
    n: int = 4,
    cluster_size: int = 2,
    shm_size: int = 2 << 28,
) -> list[multiprocessing.Process]:
    """Start n sidecar servers in n processes."""
    spawn_ctx = multiprocessing.get_context("spawn")
    servers: list[multiprocessing.Process] = []
    for rank in range(n):
        # Ranks are grouped into contiguous clusters of `cluster_size`.
        first_rank = rank - rank % cluster_size
        peer_ranks = list(range(first_rank, first_rank + cluster_size))
        print("Starting sidecar server of rank", rank, "with cluster ranks", peer_ranks)
        server = spawn_ctx.Process(target=run_server, args=(rank, n, peer_ranks, shm_size))
        server.start()
        servers.append(server)
    return servers
75 |
76 |
def server_is_online(stub: sidecar_pb2_grpc.SidecarStub) -> bool:
    """Check if the server is running."""
    try:
        response = stub.CheckHealth(sidecar_pb2.CheckHealthRequest())
    except grpc.RpcError:
        # Unreachable / not yet listening counts as offline.
        return False
    return response.status == sidecar_pb2.HealthStatus.HEALTH_ALL_GOOD
85 |
86 |
def wait_for_servers_to_start(rank: int) -> None:
    """Block until the sidecar server at `rank` reports healthy.

    NOTE(review): polls forever with no timeout; a server that never comes
    up will hang the caller.
    """
    while True:
        with grpc.insecure_channel(grpc_url_from_rank(rank)) as channel:
            stub = sidecar_pb2_grpc.SidecarStub(channel)
            if server_is_online(stub):
                break
            else:
                time.sleep(10.2)
95 |
96 |
def terminate_processes(processes: list[multiprocessing.Process]) -> None:
    """Terminate all processes.

    Sends SIGINT first for graceful shutdown, then escalates to SIGTERM and
    finally SIGKILL for any process that does not exit in time.

    Args:
        processes: The processes to shut down. Unstarted or already-exited
            processes are handled gracefully.
    """
    for process in processes:
        if process.pid:
            # The process may have exited already; don't crash on that.
            with contextlib.suppress(ProcessLookupError):
                os.kill(process.pid, signal.SIGINT)

    for process in processes:
        process.join(timeout=5)
        if not process.is_alive():
            continue
        # Escalate: SIGTERM first, SIGKILL if it still won't die.
        process.terminate()
        process.join(timeout=2)
        if process.is_alive():
            process.kill()
            process.join()
113 |
--------------------------------------------------------------------------------
/python/tests/task/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/python/tests/task/__init__.py
--------------------------------------------------------------------------------
/python/tests/task/builtins/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/python/tests/task/builtins/__init__.py
--------------------------------------------------------------------------------
/python/tests/task/builtins/test_mllm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import asyncio
4 |
5 | import pytest
6 |
7 | from cornserve.task.base import TaskContext, TaskInvocation, task_context
8 | from cornserve.task.builtins.mllm import MLLMInput, MLLMTask, Modality
9 | from cornserve.task.forward import DataForward, ForwardableType, Tensor
10 |
11 |
def test_mllm_record():
    """Test MLLM task invocation recording."""
    task = MLLMTask(model_id="llava", modalities=[Modality.IMAGE])
    mllm_input = MLLMInput(prompt="Hello, world!", multimodal_data=[("image", "http://example.com/image.jpg")])

    context = TaskContext(task_id="mllm-test")
    task_context.set(context)
    with context.record():
        output = task.invoke(mllm_input)

    # Recording mode yields a placeholder (empty string) response.
    assert isinstance(output.response, str)
    assert output.response == ""

    # One invocation for the image encoder, one for the LLM.
    assert len(context.invocations) == 2
    encoder_call, llm_call = context.invocations

    assert encoder_call.task == task.image_encoder
    assert encoder_call.task_input.data_urls == ["http://example.com/image.jpg"]
    assert len(encoder_call.task_output.embeddings) == 1
    embedding = encoder_call.task_output.embeddings[0]
    assert embedding.data_type == DataForward[Tensor]().data_type == ForwardableType.TENSOR

    assert llm_call.task_input.prompt == "Hello, world!"
    # The encoder's forwarded embedding is wired into the LLM's input.
    assert embedding == llm_call.task_input.embeddings[0]
36 |
37 |
@pytest.mark.asyncio
async def test_mllm_record_concurrent():
    """Test multiple concurrent MLLM task invocations.

    Verifies that two concurrent invocations each record their own
    TaskContext without cross-contamination.
    """

    task = MLLMTask(model_id="llava", modalities=[Modality.IMAGE, Modality.VIDEO])
    task_input = MLLMInput(
        prompt="Hello, world!",
        multimodal_data=[("image", "http://example.com/image.jpg"), ("video", "http://example.com/video.mp4")],
    )

    async def call(task: MLLMTask, task_input: MLLMInput) -> list[TaskInvocation]:
        # Set the contextvar, then invoke inside a child task. asyncio tasks
        # run in a copy of the current context, so each `call` keeps its own
        # TaskContext isolated from the other concurrent call.
        task_context.set(TaskContext(task_id=task.id))
        return await asyncio.create_task(call_impl(task, task_input))

    async def call_impl(task: MLLMTask, task_input: MLLMInput) -> list[TaskInvocation]:
        ctx = task_context.get()

        with ctx.record():
            _ = task.invoke(task_input)

        return ctx.invocations

    invocations1, invocations2 = await asyncio.gather(
        call(task, task_input),
        call(task, task_input),
    )

    # Each call records exactly its own three invocations:
    # image encoder, video encoder, then the LLM.
    assert len(invocations1) == 3
    assert len(invocations2) == 3

    assert invocations1[0].task == task.image_encoder
    assert invocations1[0].task_input.data_urls == ["http://example.com/image.jpg"]
    assert invocations1[1].task == task.video_encoder
    assert invocations1[1].task_input.data_urls == ["http://example.com/video.mp4"]
    assert invocations1[2].task == task.llm

    assert invocations2[0].task == task.image_encoder
    assert invocations2[0].task_input.data_urls == ["http://example.com/image.jpg"]
    assert invocations2[1].task == task.video_encoder
    assert invocations2[1].task_input.data_urls == ["http://example.com/video.mp4"]
    assert invocations2[2].task == task.llm
79 |
--------------------------------------------------------------------------------
/python/tests/task/test_base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from cornserve.task.base import TaskGraphDispatch, TaskInvocation
4 | from cornserve.task.builtins.encoder import EncoderInput, EncoderOutput, EncoderTask, Modality
5 | from cornserve.task.builtins.llm import LLMBaseTask, LLMForwardOutputTask, LLMInput, LLMOutput, LLMTask
6 | from cornserve.task.forward import DataForward, Tensor
7 |
8 |
def test_root_unit_task_cls():
    """Tests whether the root unit task class is figured out correctly."""
    # Both LLM task variants share LLMBaseTask as their root unit task.
    for llm_cls in (LLMTask, LLMForwardOutputTask):
        assert llm_cls.root_unit_task_cls is LLMBaseTask
    # EncoderTask is its own root unit task.
    assert EncoderTask.root_unit_task_cls is EncoderTask
14 |
15 |
def test_serde_one():
    """Tests whether unit tasks can be serialized and deserialized."""
    original = TaskInvocation(
        task=LLMTask(model_id="llama"),
        task_input=LLMInput(prompt="Hello", multimodal_data=[]),
        task_output=LLMOutput(response="Hello"),
    )

    # A JSON round trip must reproduce an equal invocation.
    round_tripped = TaskInvocation.model_validate_json(original.model_dump_json())
    assert original == round_tripped
27 |
28 |
def test_serde_graph():
    """Tests whether task graph invocations can be serialized and deserialized."""
    invocations = [
        TaskInvocation(
            task=EncoderTask(model_id="clip", modality=Modality.IMAGE),
            task_input=EncoderInput(data_urls=["https://example.com/image.jpg"]),
            task_output=EncoderOutput(embeddings=[DataForward[Tensor]()]),
        ),
        TaskInvocation(
            task=LLMTask(model_id="llama"),
            task_input=LLMInput(prompt="Hello", multimodal_data=[("image", "https://example.com/image.jpg")]),
            task_output=LLMOutput(response="Hello"),
        ),
    ]
    dispatch = TaskGraphDispatch(task_id="test-graph", invocations=invocations)

    # A JSON round trip must reproduce an equal dispatch.
    assert dispatch == TaskGraphDispatch.model_validate_json(dispatch.model_dump_json())
49 |
50 |
def test_task_equivalence():
    """Tests whether unit task equivalence is determined correctly."""
    llama = LLMTask(model_id="llama")
    # Same model -> equivalent; different model -> not equivalent.
    assert llama.is_equivalent_to(LLMTask(model_id="llama"))
    assert not llama.is_equivalent_to(LLMTask(model_id="mistral"))

    clip_image = EncoderTask(model_id="clip", modality=Modality.IMAGE)
    # Same model and modality -> equivalent; different modality -> not equivalent.
    assert clip_image.is_equivalent_to(EncoderTask(model_id="clip", modality=Modality.IMAGE))
    assert not clip_image.is_equivalent_to(EncoderTask(model_id="clip", modality=Modality.VIDEO))
61 |
--------------------------------------------------------------------------------
/python/tests/task/test_registry.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 |
5 | from cornserve.task.registry import TASK_REGISTRY
6 |
7 |
def test_task_registry():
    """Tests whether the task registry is initialized correctly."""
    llm_task = TASK_REGISTRY.get("LLMTask")
    llm_forward_output_task = TASK_REGISTRY.get("LLMForwardOutputTask")
    encoder_task = TASK_REGISTRY.get("EncoderTask")

    # NOTE(review): builtins are imported after the registry lookups above —
    # presumably the registry itself triggers their registration; confirm.
    from cornserve.task.builtins.encoder import EncoderInput, EncoderOutput, EncoderTask
    from cornserve.task.builtins.llm import LLMForwardOutput, LLMForwardOutputTask, LLMInput, LLMOutput, LLMTask

    assert llm_task == (LLMTask, LLMInput, LLMOutput)
    assert llm_forward_output_task == (LLMForwardOutputTask, LLMInput, LLMForwardOutput)
    assert encoder_task == (EncoderTask, EncoderInput, EncoderOutput)

    # Unknown task names are absent, and `get` raises KeyError for them.
    # (Fixed inconsistency: the `get` call previously used the misspelled
    # "_NonEistentTask" while the membership check used "_NonExistentTask".)
    assert "_NonExistentTask" not in TASK_REGISTRY
    with pytest.raises(KeyError):
        TASK_REGISTRY.get("_NonExistentTask")
24 |
--------------------------------------------------------------------------------
/python/tests/task_executors/eric/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/python/tests/task_executors/eric/__init__.py
--------------------------------------------------------------------------------
/python/tests/task_executors/eric/engine/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/python/tests/task_executors/eric/engine/__init__.py
--------------------------------------------------------------------------------
/python/tests/task_executors/eric/engine/test_scheduler.py:
--------------------------------------------------------------------------------
1 | from cornserve.task_executors.eric.engine.scheduler import Scheduler
2 | from cornserve.task_executors.eric.schema import EngineEnqueueRequest, Modality, ProcessedEmbeddingData
3 |
4 |
def test_mixed_modality():
    """Batches should only have a single modality."""
    scheduler = Scheduler()

    # Request 1: two images. Request 2: a video followed by an image.
    # Request 3: two videos.
    scheduler.enqueue(
        EngineEnqueueRequest(
            request_id="1",
            data=[
                ProcessedEmbeddingData(id="im1", modality=Modality.IMAGE, data={}),
                ProcessedEmbeddingData(id="im2", modality=Modality.IMAGE, data={}),
            ],
        )
    )
    scheduler.enqueue(
        EngineEnqueueRequest(
            request_id="2",
            data=[
                ProcessedEmbeddingData(id="vid1", modality=Modality.VIDEO, data={}),
                ProcessedEmbeddingData(id="im3", modality=Modality.IMAGE, data={}),
            ],
        )
    )
    scheduler.enqueue(
        EngineEnqueueRequest(
            request_id="3",
            data=[
                ProcessedEmbeddingData(id="vid2", modality=Modality.VIDEO, data={}),
                ProcessedEmbeddingData(id="vid3", modality=Modality.VIDEO, data={}),
            ],
        )
    )

    # Batch 1: both images of request 1, but not request 2's video.
    assert scheduler.has_waiting_requests()
    batch = scheduler.schedule()
    assert batch.modality == Modality.IMAGE
    assert len(batch.request_ids) == 2
    scheduler.process_batch_result(request_ids=["1", "1"], data_ids=["im1", "im2"])

    # Batch 2: request 2's video alone — its image must not ride along.
    assert scheduler.has_waiting_requests()
    batch = scheduler.schedule()
    assert batch.modality == Modality.VIDEO
    assert len(batch.request_ids) == 1
    scheduler.process_batch_result(request_ids=["2"], data_ids=["vid1"])

    # Batch 3: request 2's remaining image.
    assert scheduler.has_waiting_requests()
    batch = scheduler.schedule()
    assert batch.modality == Modality.IMAGE
    assert len(batch.request_ids) == 1
    scheduler.process_batch_result(request_ids=["2"], data_ids=["im3"])

    # Batch 4: request 3's two videos together.
    assert scheduler.has_waiting_requests()
    batch = scheduler.schedule()
    assert batch.modality == Modality.VIDEO
    assert len(batch.request_ids) == 2
    scheduler.process_batch_result(request_ids=["3", "3"], data_ids=["vid2", "vid3"])

    # Everything scheduled and completed.
    assert not scheduler.has_waiting_requests()
62 |
--------------------------------------------------------------------------------
/python/tests/task_executors/eric/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/python/tests/task_executors/eric/models/__init__.py
--------------------------------------------------------------------------------
/python/tests/task_executors/eric/models/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tempfile
4 | from collections.abc import Generator
5 |
6 | import pytest
7 |
8 | from cornserve.task_executors.eric.schema import Modality
9 |
10 | from ..utils import ModalityData
11 |
12 | TEST_IMAGE_URLS = [
13 | "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
14 | "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
15 | ]
16 |
17 |
@pytest.fixture(scope="session")
def test_images() -> list[ModalityData]:
    """Fixture to provide test images."""
    images: list[ModalityData] = []
    for url in TEST_IMAGE_URLS:
        images.append(ModalityData(url=url, modality=Modality.IMAGE))
    return images
22 |
23 |
24 | TEST_VIDEO_URLS = [
25 | "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4",
26 | "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ElephantsDream.mp4",
27 | "https://www.sample-videos.com/video321/mp4/360/big_buck_bunny_360p_2mb.mp4",
28 | "https://www.sample-videos.com/video321/mp4/720/big_buck_bunny_720p_2mb.mp4",
29 | ]
30 |
31 |
@pytest.fixture(scope="session")
def test_videos() -> list[ModalityData]:
    """Fixture to provide test videos."""
    videos: list[ModalityData] = []
    for url in TEST_VIDEO_URLS:
        videos.append(ModalityData(url=url, modality=Modality.VIDEO))
    return videos
36 |
37 |
38 | TEST_AUDIO_URLS = [
39 | "https://s3.amazonaws.com/citizen-dj-assets.labs.loc.gov/audio/samplepacks/loc-fma/The-Call-of-the-Polar-Star_fma-115766_001_00-00-01.wav",
40 | "https://s3.amazonaws.com/citizen-dj-assets.labs.loc.gov/audio/samplepacks/loc-fma/Frog-In-The-Well_fma-39182_001_00-00-06.wav",
41 | "https://s3.amazonaws.com/citizen-dj-assets.labs.loc.gov/audio/samplepacks/loc-fma/Free-To-Use-13_fma-152622_004_00-02-55.wav",
42 | ]
43 |
44 |
@pytest.fixture(scope="session")
def test_audios() -> list[ModalityData]:
    """Fixture to provide test audios."""
    audios: list[ModalityData] = []
    for url in TEST_AUDIO_URLS:
        audios.append(ModalityData(url=url, modality=Modality.AUDIO))
    return audios
49 |
50 |
@pytest.fixture(scope="session")
def dump_tensors() -> Generator[str, None, None]:
    """Fixture to set `CORNSERVE_TEST_DUMP_TENSOR_DIR` environment variable.

    Yields:
        The directory path tensors are dumped into.
    """
    # `dump_dir` instead of `dir` to avoid shadowing the builtin.
    dump_dir = os.getenv("CORNSERVE_TEST_DUMP_TENSOR_DIR")

    # If unset, do it in a tempdir and clean up after the test session.
    # try/finally ensures cleanup runs even if the session raises.
    if dump_dir is None:
        tmp = os.environ["CORNSERVE_TEST_DUMP_TENSOR_DIR"] = tempfile.mkdtemp()
        try:
            yield tmp
        finally:
            del os.environ["CORNSERVE_TEST_DUMP_TENSOR_DIR"]
            shutil.rmtree(tmp)

    # If explicitly set, use it and leave the stuff because the user would have intended to keep it.
    else:
        yield dump_dir
66 |
--------------------------------------------------------------------------------
/python/tests/task_executors/eric/models/test_gemma3.py:
--------------------------------------------------------------------------------
1 | """Tests for the Gemma3 model's vision encoder."""
2 |
3 | import pytest
4 | import torch
5 | from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
6 |
7 | from cornserve.task_executors.eric.distributed.parallel import destroy_distributed, init_distributed
8 | from cornserve.task_executors.eric.executor.executor import ModelExecutor
9 | from cornserve.task_executors.eric.executor.loader import load_model
10 | from cornserve.task_executors.eric.models.registry import MODEL_REGISTRY
11 | from cornserve.task_executors.eric.schema import Status
12 |
13 | from ..utils import (
14 | TP_SIZES,
15 | ModalityData,
16 | assert_same_weights,
17 | assert_similar,
18 | batch_builder,
19 | depends_on,
20 | param_tp_size,
21 | )
22 |
23 | model_id = "google/gemma-3-4b-it"
24 | model_shorthand = "gemma3"
25 |
26 |
def test_weight_loading() -> None:
    """Check if weights are loaded correctly."""
    # Hugging Face reference model, loaded at its native dtype.
    hf_model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto").model

    # Load our model on CPU with a single-rank (no tensor parallelism) setup.
    init_distributed(world_size=1, rank=0)
    our_model = load_model(model_id, torch_device=torch.device("cpu"))
    destroy_distributed()

    def check_qkv_proj_weight(
        our_name: str,
        our_param: torch.Tensor,
        hf_params: dict[str, torch.Tensor],
    ):
        """Check if the qkv_proj weights are the same.

        Our model fuses HF's separate q/k/v projections into one `qkv_proj`
        tensor, so the fused parameter must equal the dim-0 concatenation of
        the three corresponding HF parameters.
        """
        separate_weights = []
        for key in ["q_proj", "k_proj", "v_proj"]:
            separate_weights.append(hf_params[our_name.replace("qkv_proj", key)])
        assert torch.allclose(our_param, torch.cat(separate_weights, dim=0))

    # Check if parameters are the same, with fused qkv handled specially.
    assert_same_weights(
        hf_model,
        our_model,
        required_prefixes=MODEL_REGISTRY[hf_model.config.model_type].weight.required_prefixes,
        ignored_prefixes=MODEL_REGISTRY[hf_model.config.model_type].weight.ignored_prefixes,
        transformed_weights={
            "*qkv_proj.weight": check_qkv_proj_weight,
            "*qkv_proj.bias": check_qkv_proj_weight,
        },
    )
59 |
60 |
@param_tp_size
def test_image_inference(test_images: list[ModalityData], tp_size: int, dump_tensors: str) -> None:
    """Run image inference end-to-end through the model executor."""
    batch = batch_builder(model_id, model_shorthand, test_images)
    executor = ModelExecutor(model_id=model_id, tp_size=tp_size, sender_sidecar_ranks=None)

    result = executor.execute_model(batch=batch)
    assert result.status == Status.SUCCESS

    executor.shutdown()
71 |
72 |
@depends_on("test_image_inference")
def test_hf_reference(test_images: list[ModalityData], dump_tensors: str) -> None:
    """Generate reference outputs from the Hugging Face model.

    Runs the first two test images through the Hugging Face model's image
    feature path and compares the results against the tensors dumped by
    `test_image_inference` for every tensor-parallel degree.
    """
    hf_model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype="auto",
        attn_implementation="flash_attention_2",
    )
    model = hf_model.model.cuda().eval()

    # Disable gradients only within this scope (the previous module-wide
    # `torch.set_grad_enabled(False)` leaked into subsequently run tests).
    references = []
    with torch.no_grad():
        for image in test_images[:2]:
            processed = image.processed(model_id)
            pixel_values = torch.asarray(processed["pixel_values"]).cuda()
            references.append(model.get_image_features(pixel_values=pixel_values).cpu())

    # Compare against the tensors dumped by `test_image_inference` for each TP degree.
    for tp_degree in TP_SIZES:
        output = torch.load(f"{dump_tensors}/{model_shorthand}-image-tp{tp_degree}.pt")
        assert_similar(references, output)

    del references
98 |
99 |
# Applied to every test in this module: suppress DeprecationWarning noise.
pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning")
101 |
--------------------------------------------------------------------------------
/python/tests/task_executors/eric/models/test_qwen2_vl.py:
--------------------------------------------------------------------------------
1 | """Tests for the Qwen2-VL model's vision encoder."""
2 |
3 | import torch
4 | from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
5 |
6 | from cornserve.task_executors.eric.distributed.parallel import destroy_distributed, init_distributed
7 | from cornserve.task_executors.eric.executor.executor import ModelExecutor
8 | from cornserve.task_executors.eric.executor.loader import load_model
9 | from cornserve.task_executors.eric.schema import Status
10 |
11 | from ..utils import (
12 | TP_SIZES,
13 | ModalityData,
14 | assert_same_weights,
15 | assert_similar,
16 | batch_builder,
17 | depends_on,
18 | param_tp_size,
19 | )
20 |
# Hugging Face model ID under test, and the shorthand used in dump-tensor file names
# (e.g. "{shorthand}-image-tp{N}.pt", "{shorthand}-video-tp{N}.pt") used by the tests below.
model_id = "Qwen/Qwen2-VL-7B-Instruct"
model_shorthand = "qwen2"
23 |
24 |
def test_weight_loading() -> None:
    """Check if weights are loaded correctly."""
    # Reference weights: the vision tower of the Hugging Face implementation.
    hf_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto").visual

    # Load the same checkpoint through our loader, single rank on CPU.
    init_distributed(world_size=1, rank=0)
    cpu = torch.device("cpu")
    our_model = load_model(model_id, torch_device=cpu)
    destroy_distributed()

    # Every parameter must match the reference exactly.
    assert_same_weights(hf_model, our_model)
37 |
38 |
@param_tp_size
def test_image_inference(test_images: list[ModalityData], tp_size: int, dump_tensors: str) -> None:
    """Run image inference end-to-end through the model executor."""
    batch = batch_builder(model_id, model_shorthand, test_images)
    executor = ModelExecutor(model_id=model_id, tp_size=tp_size, sender_sidecar_ranks=None)

    result = executor.execute_model(batch=batch)
    assert result.status == Status.SUCCESS

    executor.shutdown()
49 |
50 |
@param_tp_size
def test_video_inference(test_videos: list[ModalityData], tp_size: int, dump_tensors: str) -> None:
    """Run video inference end-to-end through the model executor."""
    # Only the first two videos are used, matching the reference test below.
    batch = batch_builder(model_id, model_shorthand, test_videos[:2])
    executor = ModelExecutor(model_id=model_id, tp_size=tp_size, sender_sidecar_ranks=None)

    result = executor.execute_model(batch=batch)
    assert result.status == Status.SUCCESS

    executor.shutdown()
61 |
62 |
@depends_on("test_image_inference", "test_video_inference")
def test_hf_reference(test_images: list[ModalityData], test_videos: list[ModalityData], dump_tensors: str) -> None:
    """Generate reference outputs from the Hugging Face model.

    Runs the first two test images and the first two test videos through the
    Hugging Face model and compares the results against the tensors dumped by
    `test_image_inference` and `test_video_inference` for every tensor-parallel
    degree.
    """
    hf_model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype="auto",
        attn_implementation="flash_attention_2",
    )
    model = hf_model.model.cuda().eval()

    def compare_references(references: list[torch.Tensor], modality: str) -> None:
        """Compare reference outputs against the dumped tensors for each TP degree."""
        for tp_degree in TP_SIZES:
            output = torch.load(f"{dump_tensors}/{model_shorthand}-{modality}-tp{tp_degree}.pt")
            assert_similar(references, output)

    # Disable gradients only within this scope (the previous module-wide
    # `torch.set_grad_enabled(False)` leaked into subsequently run tests).
    with torch.no_grad():
        # Image references.
        image_refs = []
        for image in test_images[:2]:
            processed = image.processed(model_id)
            pixel_values = torch.asarray(processed["pixel_values"]).cuda()
            image_grid_thw = torch.asarray(processed["image_grid_thw"]).cuda()
            image_refs.append(model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw).cpu())

        compare_references(image_refs, "image")
        del image_refs

        # Video references.
        video_refs = []
        for video in test_videos[:2]:
            processed = video.processed(model_id)
            pixel_values_videos = torch.asarray(processed["pixel_values_videos"]).cuda()
            video_grid_thw = torch.asarray(processed["video_grid_thw"]).cuda()
            video_refs.append(
                model.get_video_features(pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw).cpu()
            )

        compare_references(video_refs, "video")
        del video_refs
106 |
--------------------------------------------------------------------------------
/scripts/generate_pb.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Generate Python gRPC/protobuf code from the .proto files in proto/v1.
# Run from the repository root. Requires `grpcio-tools` to be installed.
# NOTE: the shebang previously pointed at /usr/local/env, which does not
# exist on standard systems; env lives at /usr/bin/env.

set -evo pipefail

PROTO_DIR="proto/v1"
PROTO_FILES=$(find "$PROTO_DIR" -name "*.proto")

PYTHON_OUTPUT_DIR="python/cornserve/services/pb"
mkdir -p "$PYTHON_OUTPUT_DIR"
# -f: succeed silently when no previously generated files exist.
rm -f "$PYTHON_OUTPUT_DIR"/*pb2.py "$PYTHON_OUTPUT_DIR"/*pb2.pyi "$PYTHON_OUTPUT_DIR"/*pb2_grpc.py
for proto in $PROTO_FILES; do
    echo "Generating code for $proto in $PYTHON_OUTPUT_DIR"
    python -m grpc_tools.protoc \
        -I"$PROTO_DIR" \
        --python_out="$PYTHON_OUTPUT_DIR" \
        --grpc_python_out="$PYTHON_OUTPUT_DIR" \
        --pyi_out="$PYTHON_OUTPUT_DIR" \
        "$proto"
done

# The generated `import common_pb2`, for example, doesn't work.
# We need to change it manually to `from . import common_pb2`.
# `-i.bak` is done for GNU and BSD sed compatibility.
find "$PYTHON_OUTPUT_DIR" -type f -name "*.py" -exec sed -i.bak -e 's/^\(import.*pb2\)/from . \1/g' {} \;
find "$PYTHON_OUTPUT_DIR" -type f -name "*.bak" -exec rm {} \;
26 |
--------------------------------------------------------------------------------