├── .dockerignore ├── .github └── workflows │ ├── check_homepage_build.yaml │ ├── deploy_homepage.yaml │ ├── publish_pypi.yaml │ ├── push_docker.yaml │ └── python_lint.yaml ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── docker ├── dev.Dockerfile ├── eric.Dockerfile ├── gateway.Dockerfile ├── resource-manager.Dockerfile ├── sidecar.Dockerfile ├── task-dispatcher.Dockerfile ├── task-manager.Dockerfile └── vllm.Dockerfile ├── docs ├── CNAME ├── architecture │ ├── eric.md │ ├── index.md │ ├── sidecar.md │ └── task.md ├── assets │ ├── css │ │ └── extra.css │ ├── img │ │ └── favicon.png │ ├── js │ │ └── mathjax.js │ └── video │ │ └── cornserve.mp4 ├── contributor_guide │ ├── eric.md │ ├── index.md │ ├── kubernetes.md │ ├── sidecar.md │ └── tracing.md ├── getting_started │ ├── building_apps.md │ ├── cornserve.md │ ├── index.md │ ├── jupyter.ipynb │ └── registering_apps.md ├── index.md └── requirements.txt ├── examples ├── mllm │ ├── app.py │ └── requirements.txt └── notebook.ipynb ├── kubernetes ├── README.md ├── k3s │ ├── agent-config.yaml │ ├── registries.yaml │ └── server-config.yaml ├── kustomize │ ├── cornserve-system │ │ ├── base │ │ │ ├── jaeger │ │ │ │ ├── configmap.yaml │ │ │ │ ├── deployment.yaml │ │ │ │ ├── kustomization.yaml │ │ │ │ └── service.yaml │ │ │ ├── kustomization.yaml │ │ │ └── namespace.yaml │ │ └── overlays │ │ │ ├── dev │ │ │ ├── jaeger │ │ │ │ ├── kustomization.yaml │ │ │ │ ├── node-port-service.yaml │ │ │ │ └── volume-patch.yaml │ │ │ └── kustomization.yaml │ │ │ ├── local │ │ │ ├── jaeger │ │ │ │ ├── kustomization.yaml │ │ │ │ ├── node-port-service.yaml │ │ │ │ └── volume-patch.yaml │ │ │ └── kustomization.yaml │ │ │ └── minikube │ │ │ ├── jaeger │ │ │ ├── kustomization.yaml │ │ │ └── node-port-service.yaml │ │ │ └── kustomization.yaml │ └── cornserve │ │ ├── base │ │ ├── gateway │ │ │ ├── deployment.yaml │ │ │ ├── kustomization.yaml │ │ │ └── service.yaml │ │ ├── kustomization.yaml │ │ ├── namespace.yaml │ │ ├── resource-manager 
│ │ │ ├── deployment.yaml │ │ │ ├── kustomization.yaml │ │ │ ├── role-binding.yaml │ │ │ ├── role.yaml │ │ │ ├── service-account.yaml │ │ │ └── service.yaml │ │ ├── sidecar │ │ │ ├── kustomization.yaml │ │ │ ├── role-binding.yaml │ │ │ ├── role.yaml │ │ │ ├── service-account.yaml │ │ │ ├── service.yaml │ │ │ └── statefulset.yaml │ │ ├── task-dispatcher │ │ │ ├── deployment.yaml │ │ │ ├── kustomization.yaml │ │ │ └── service.yaml │ │ └── task-manager │ │ │ ├── kustomization.yaml │ │ │ ├── role-binding.yaml │ │ │ ├── role.yaml │ │ │ └── service-account.yaml │ │ └── overlays │ │ ├── dev │ │ ├── gateway │ │ │ ├── kustomization.yaml │ │ │ └── node-port-service.yaml │ │ ├── image-pull-policy.yaml │ │ └── kustomization.yaml │ │ ├── local │ │ ├── gateway │ │ │ ├── kustomization.yaml │ │ │ └── node-port-service.yaml │ │ ├── image-pull-policy.yaml │ │ └── kustomization.yaml │ │ ├── minikube │ │ ├── gateway │ │ │ ├── kustomization.yaml │ │ │ └── node-port-service.yaml │ │ └── kustomization.yaml │ │ └── prod │ │ └── kustomization.yaml ├── registry.sh └── set_registry.sh ├── mkdocs.yml ├── proto └── v1 │ ├── common.proto │ ├── resource_manager.proto │ ├── sidecar.proto │ ├── task_dispatcher.proto │ └── task_manager.proto ├── python ├── cornserve │ ├── __init__.py │ ├── app │ │ ├── __init__.py │ │ └── base.py │ ├── cli.py │ ├── constants.py │ ├── frontend.py │ ├── logging.py │ ├── services │ │ ├── __init__.py │ │ ├── gateway │ │ │ ├── __init__.py │ │ │ ├── app │ │ │ │ ├── manager.py │ │ │ │ └── models.py │ │ │ ├── entrypoint.py │ │ │ ├── models.py │ │ │ ├── router.py │ │ │ ├── session.py │ │ │ └── task_manager.py │ │ ├── resource_manager │ │ │ ├── __init__.py │ │ │ ├── entrypoint.py │ │ │ ├── grpc.py │ │ │ ├── manager.py │ │ │ └── resource.py │ │ ├── sidecar │ │ │ ├── __init__.py │ │ │ ├── launch.py │ │ │ ├── receiver.py │ │ │ ├── scheduler.py │ │ │ ├── schema.py │ │ │ ├── sender.py │ │ │ ├── server.py │ │ │ └── shm_manager.py │ │ ├── task_dispatcher │ │ │ ├── __init__.py │ │ │ 
├── dispatcher.py │ │ │ ├── entrypoint.py │ │ │ ├── grpc.py │ │ │ └── router.py │ │ └── task_manager │ │ │ ├── __init__.py │ │ │ ├── entrypoint.py │ │ │ ├── grpc.py │ │ │ └── manager.py │ ├── sidecar │ │ ├── __init__.py │ │ ├── api.py │ │ ├── constants.py │ │ ├── schema.py │ │ ├── serde.py │ │ └── utils.py │ ├── task │ │ ├── __init__.py │ │ ├── base.py │ │ ├── builtins │ │ │ ├── __init__.py │ │ │ ├── encoder.py │ │ │ ├── llm.py │ │ │ └── mllm.py │ │ ├── forward.py │ │ └── registry.py │ ├── task_executors │ │ ├── __init__.py │ │ ├── descriptor │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── builtins │ │ │ │ ├── __init__.py │ │ │ │ ├── encoder.py │ │ │ │ └── llm.py │ │ │ └── registry.py │ │ └── eric │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── config.py │ │ │ ├── distributed │ │ │ ├── parallel.py │ │ │ └── shm_broadcast.py │ │ │ ├── engine │ │ │ ├── __init__.py │ │ │ ├── client.py │ │ │ ├── core.py │ │ │ └── scheduler.py │ │ │ ├── entrypoint.py │ │ │ ├── executor │ │ │ ├── __init__.py │ │ │ ├── executor.py │ │ │ ├── loader.py │ │ │ └── worker.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── gemma3.py │ │ │ ├── layers │ │ │ │ ├── __init__.py │ │ │ │ ├── activations.py │ │ │ │ ├── attention.py │ │ │ │ ├── linear.py │ │ │ │ ├── norm.py │ │ │ │ ├── siglip.py │ │ │ │ └── vit.py │ │ │ ├── llava_onevision.py │ │ │ ├── qwen2_5_omni.py │ │ │ ├── qwen2_5_vl.py │ │ │ ├── qwen2_vl.py │ │ │ └── registry.py │ │ │ ├── router │ │ │ ├── __init__.py │ │ │ ├── app.py │ │ │ └── processor.py │ │ │ ├── schema.py │ │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── distributed.py │ │ │ ├── network.py │ │ │ ├── package.py │ │ │ ├── process.py │ │ │ ├── serde.py │ │ │ └── zmq.py │ └── tracing.py ├── pyproject.toml ├── scripts │ └── lint.sh └── tests │ ├── services │ ├── __init__.py │ └── sidecar │ │ ├── __init__.py │ │ ├── test_sidecar.py │ │ └── utils.py │ ├── task │ ├── __init__.py │ ├── builtins │ │ ├── __init__.py │ │ └── test_mllm.py │ ├── test_base.py │ └── 
test_registry.py │ └── task_executors │ └── eric │ ├── __init__.py │ ├── engine │ ├── __init__.py │ └── test_scheduler.py │ ├── models │ ├── __init__.py │ ├── conftest.py │ ├── test_gemma3.py │ ├── test_llava_onevision.py │ ├── test_qwen2_5_omni.py │ ├── test_qwen2_5_vl.py │ └── test_qwen2_vl.py │ └── utils.py └── scripts ├── build_export_images.sh └── generate_pb.sh /.dockerignore: -------------------------------------------------------------------------------- 1 | # Cornserve repo-specific 2 | docs/ 3 | examples/ 4 | scratchpad/ 5 | scripts/ 6 | 7 | # Git 8 | .git 9 | .gitignore 10 | .gitattributes 11 | 12 | 13 | # CI 14 | .codeclimate.yml 15 | .travis.yml 16 | .taskcluster.yml 17 | 18 | # Docker and Kubernetes 19 | docker-compose.yml 20 | Dockerfile 21 | *.Dockerfile 22 | .docker 23 | .dockerignore 24 | kubernetes/ 25 | docker/ 26 | 27 | # Byte-compiled / optimized / DLL / type stub files 28 | **/__pycache__/ 29 | **/*.py[codi] 30 | 31 | # C extensions 32 | *.so 33 | 34 | # Distribution / packaging 35 | .Python 36 | env/ 37 | build/ 38 | develop-eggs/ 39 | dist/ 40 | downloads/ 41 | eggs/ 42 | lib/ 43 | lib64/ 44 | parts/ 45 | sdist/ 46 | var/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | 51 | # PyInstaller 52 | # Usually these files are written by a python script from a template 53 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
54 | *.manifest 55 | *.spec 56 | 57 | # Installer logs 58 | pip-log.txt 59 | pip-delete-this-directory.txt 60 | 61 | # Unit test / coverage reports 62 | htmlcov/ 63 | .tox/ 64 | .coverage 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Virtual environment 83 | .env 84 | .envrc 85 | .venv/ 86 | venv/ 87 | 88 | # PyCharm 89 | .idea 90 | 91 | # Python mode for VIM 92 | .ropeproject 93 | **/.ropeproject 94 | 95 | # Vim swap files 96 | **/*.swp 97 | 98 | # VS Code 99 | .vscode/ 100 | 101 | # Protobuf 102 | proto/ 103 | *.proto 104 | 105 | # Text files 106 | *.md 107 | -------------------------------------------------------------------------------- /.github/workflows/check_homepage_build.yaml: -------------------------------------------------------------------------------- 1 | name: Check homepage build 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - 'examples/**' 7 | - 'docs/**' 8 | - 'mkdocs.yml' 9 | - '.github/workflows/check_homepage_build.yaml' 10 | 11 | concurrency: 12 | group: ${{ github.ref }}-check-homepage-build 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | check: 17 | runs-on: ubuntu-latest 18 | if: github.event.repository.fork == false 19 | steps: 20 | - name: Checkout repository 21 | uses: actions/checkout@v4 22 | - name: Setup Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: 3.11 26 | cache: 'pip' 27 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 28 | - uses: actions/cache@v4 29 | with: 30 | key: mkdocs-material-${{ env.cache_id }} 31 | path: .cache 32 | restore-keys: | 33 | mkdocs-material- 34 | - name: Install homepage build dependencies 35 | run: pip install -r docs/requirements.txt 36 | - name: Build homepage 37 | run: mkdocs build --verbose --strict 38 | env: 39 | BUILD_SOCIAL_CARD: true 40 | 
-------------------------------------------------------------------------------- /.github/workflows/deploy_homepage.yaml: -------------------------------------------------------------------------------- 1 | name: Deploy homepage 2 | on: 3 | push: 4 | branches: 5 | - master 6 | paths: 7 | - 'examples/**' 8 | - 'docs/**' 9 | - 'mkdocs.yml' 10 | - '.github/workflows/deploy_homepage.yaml' 11 | 12 | env: 13 | SITE_ANALYTICS: google 14 | 15 | jobs: 16 | deploy: 17 | runs-on: ubuntu-latest 18 | if: github.repository_owner == 'cornserve-ai' 19 | steps: 20 | - name: Checkout repository 21 | uses: actions/checkout@v4 22 | - name: Setup Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: 3.11 26 | cache: 'pip' 27 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 28 | - uses: actions/cache@v4 29 | with: 30 | key: mkdocs-material-${{ env.cache_id }} 31 | path: .cache 32 | restore-keys: | 33 | mkdocs-material- 34 | - name: Install homepage build dependencies 35 | run: pip install -r docs/requirements.txt 36 | - name: Build and deploy homepage 37 | run: mkdocs gh-deploy --force 38 | env: 39 | BUILD_SOCIAL_CARD: true 40 | -------------------------------------------------------------------------------- /.github/workflows/publish_pypi.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Python package to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | jobs: 9 | pypi-publish: 10 | runs-on: ubuntu-latest 11 | if: github.repository_owner == 'cornserve-ai' 12 | permissions: 13 | id-token: write 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v4 17 | with: 18 | submodules: recursive 19 | token: ${{ secrets.SUBMODULE_TOKEN }} 20 | - name: Setup Python 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: 3.11 24 | cache: 'pip' 25 | - name: Install protobuf dependencies 26 | run: pip install grpcio-tools 27 | - name: Generate protobuf files 28 | run: bash 
scripts/generate_pb.sh 29 | - name: Build source distribution 30 | run: cd python && pip install build && python -m build 31 | - name: Publish to PyPI 32 | uses: pypa/gh-action-pypi-publish@release/v1 33 | with: 34 | packages-dir: python/dist/ 35 | 36 | -------------------------------------------------------------------------------- /.github/workflows/push_docker.yaml: -------------------------------------------------------------------------------- 1 | name: Push Docker image 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | tags: 8 | - v* 9 | paths: 10 | - '.github/workflows/push_docker.yaml' 11 | - '.gitmodules' 12 | - 'docker/**' 13 | - 'python/**' 14 | - 'proto/**' 15 | - 'third_party/**' 16 | - 'scripts/generate_pb.sh' 17 | - '.dockerignore' 18 | - 'LICENSE' 19 | - 'setup.py' 20 | - 'pyproject.toml' 21 | 22 | env: 23 | NAMESPACE: cornserve 24 | 25 | jobs: 26 | build_and_push: 27 | if: github.repository_owner == 'cornserve-ai' 28 | runs-on: ${{ (matrix.component == 'vllm' || matrix.component == 'eric') && 'blacksmith-2vcpu-ubuntu-2404' || 'ubuntu-latest' }} 29 | strategy: 30 | matrix: 31 | component: 32 | - sidecar 33 | - task-dispatcher 34 | - task-manager 35 | - resource-manager 36 | - gateway 37 | - eric 38 | - vllm 39 | include: 40 | - component: eric 41 | build_args: | 42 | max_jobs=4 43 | build_target: eric 44 | - component: vllm 45 | build_target: vllm 46 | 47 | steps: 48 | - name: Remove unnecessary files 49 | run: | 50 | sudo rm -rf /usr/share/dotnet 51 | sudo rm -rf /opt/ghc 52 | sudo rm -rf "/usr/local/share/boost" 53 | sudo rm -rf "$AGENT_TOOLSDIRECTORY" 54 | 55 | - name: Checkout repository 56 | uses: actions/checkout@v4 57 | with: 58 | submodules: recursive 59 | 60 | - name: Setup Python for protobuf generation 61 | uses: actions/setup-python@v5 62 | with: 63 | python-version: 3.11 64 | 65 | - name: Install protobuf dependencies 66 | run: pip install grpcio-tools==1.71.0 67 | 68 | - name: Generate protobuf files 69 | run: bash 
scripts/generate_pb.sh 70 | 71 | - name: Docker Hub login 72 | uses: docker/login-action@v3 73 | with: 74 | username: ${{ secrets.DOCKER_HUB_USERNAME }} 75 | password: ${{ secrets.DOCKER_HUB_TOKEN }} 76 | 77 | - name: Set up Docker Buildx 78 | uses: docker/setup-buildx-action@v3 79 | 80 | - name: Generate image metadata 81 | id: meta 82 | uses: docker/metadata-action@v5 83 | with: 84 | images: | 85 | ${{ env.NAMESPACE }}/${{ matrix.component }} 86 | tags: | 87 | type=raw,value=latest,enable={{is_default_branch}} 88 | type=match,pattern=v.* 89 | 90 | - name: Build and push Docker image 91 | uses: docker/build-push-action@v6 92 | with: 93 | context: . 94 | file: docker/${{ matrix.component }}.Dockerfile 95 | push: true 96 | tags: ${{ steps.meta.outputs.tags }} 97 | labels: ${{ steps.meta.outputs.labels }} 98 | cache-from: type=registry,ref=${{ env.NAMESPACE }}/${{ matrix.component }}:buildcache 99 | cache-to: type=registry,ref=${{ env.NAMESPACE }}/${{ matrix.component }}:buildcache,mode=max 100 | platforms: linux/amd64 101 | build-args: ${{ matrix.build_args }} 102 | target: ${{ matrix.build_target || '' }} 103 | -------------------------------------------------------------------------------- /.github/workflows/python_lint.yaml: -------------------------------------------------------------------------------- 1 | name: Python format and lint check 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - '.github/workflows/python_lint.yaml' 7 | - 'python/**' 8 | - 'proto/**' 9 | - 'pyproject.toml' 10 | push: 11 | paths: 12 | - '.github/workflows/python_lint.yaml' 13 | - 'python/**' 14 | - 'proto/**' 15 | - 'pyproject.toml' 16 | 17 | # Jobs initiated by previous pushes get cancelled by a new push. 
18 | concurrency: 19 | group: ${{ github.ref }}-python-format-and-lint 20 | cancel-in-progress: true 21 | 22 | jobs: 23 | format_lint: 24 | if: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name != github.repository }} 25 | runs-on: ubuntu-latest 26 | steps: 27 | - name: Checkout repository 28 | uses: actions/checkout@v4 29 | - name: Setup Python 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: 3.11 33 | cache: 'pip' 34 | - name: Install library and lint tools 35 | run: pip install -U pip && pip install "./python[dev-no-gpu]" 36 | - name: Generate protobuf files 37 | run: bash scripts/generate_pb.sh 38 | - name: Check format and lint 39 | run: bash python/scripts/lint.sh 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # UV 99 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | #uv.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 117 | .pdm.toml 118 | .pdm-python 119 | .pdm-build/ 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .envrc 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
170 | #.idea/ 171 | 172 | # Ruff stuff: 173 | .ruff_cache/ 174 | 175 | # PyPI configuration file 176 | .pypirc 177 | 178 | # Protobuf 179 | *pb2.py 180 | *pb2.pyi 181 | *pb2_grpc.py 182 | 183 | # Workspace 184 | /old 185 | /tmp 186 | /dev 187 | 188 | # MacOS 189 | .DS_Store 190 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/vllm"] 2 | path = third_party/vllm 3 | url = https://github.com/cornserve-ai/vllm.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

Cornserve: Easy, Fast, and Scalable Multimodal AI

3 | 4 | [![Docker Hub](https://custom-icon-badges.demolab.com/badge/Docker-cornserve-1D63ED.svg?logo=docker&logoColor=white)](https://hub.docker.com/r/cornserve/gateway) 5 | [![Homepage](https://custom-icon-badges.demolab.com/badge/Docs-cornserve.ai-dddddd.svg?logo=home&logoColor=white&logoSource=feather)](https://cornserve.ai/) 6 | [![Apache-2.0 License](https://custom-icon-badges.herokuapp.com/github/license/cornserve-ai/cornserve?logo=law)](/LICENSE) 7 |
8 | 9 | https://github.com/user-attachments/assets/6dd12ad6-2307-4457-ae70-96e1cecc5ece 10 | 11 | Cornserve is an execution platform for multimodal AI. 12 | Split complex models into smaller separately scalable components (**model fission**) and share common components across multiple applications (**sharing**), all on your own infrastructure. 13 | 14 | 15 | ## Getting Started 16 | 17 | You can quickly try out Cornserve on top of Minikube. Check out our [getting started guide](https://cornserve.ai/getting_started/)! 18 | 19 | Cornserve can be deployed on Kubernetes with a single command. More on our [docs](https://cornserve.ai/getting_started/). 20 | 21 | 22 | ## Contributing 23 | 24 | Cornserve is an open-source project, and we welcome contributions! 25 | Please check out our [contributor guide](https://cornserve.ai/contributor_guide/) for more information on how to get started. 26 | -------------------------------------------------------------------------------- /docker/dev.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.7.0-cuda12.6-cudnn9-devel 2 | 3 | RUN apt-get update && apt-get upgrade -y 4 | RUN apt-get install wget build-essential librdmacm-dev net-tools -y 5 | 6 | ########### Install UCX 1.18.0 ########### 7 | RUN wget https://github.com/openucx/ucx/releases/download/v1.18.0/ucx-1.18.0.tar.gz 8 | RUN tar xzf ucx-1.18.0.tar.gz 9 | WORKDIR /workspace/ucx-1.18.0 10 | RUN mkdir build 11 | RUN cd build && \ 12 | ../configure --build=x86_64-unknown-linux-gnu --host=x86_64-unknown-linux-gnu --program-prefix= --disable-dependency-tracking \ 13 | --prefix=/usr --exec-prefix=/usr --bindir=/usr/bin --sbindir=/usr/sbin --sysconfdir=/etc --datadir=/usr/share --includedir=/usr/include \ 14 | --libdir=/usr/lib64 --libexecdir=/usr/libexec --localstatedir=/var --sharedstatedir=/var/lib --mandir=/usr/share/man --infodir=/usr/share/info \ 15 | --disable-logging --disable-debug --disable-assertions 
--enable-mt --disable-params-check --without-go --without-java --enable-cma \ 16 | --with-verbs --with-mlx5 --with-rdmacm --without-rocm --with-xpmem --without-fuse3 --without-ugni --without-mad --without-ze && \ 17 | make -j$(nproc) && make install 18 | 19 | ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true 20 | ENV LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH 21 | 22 | # UCX logging 23 | ENV UCX_LOG_LEVEL=trace 24 | # UCX_LOG_LEVEL to be one of: fatal, error, warn, info, debug, trace, req, data, async, func, poll 25 | 26 | # UCX transports 27 | ENV UCX_TLS=rc,ib,tcp 28 | ########### End Install UCX ########### 29 | 30 | ADD . /workspace/cornserve 31 | WORKDIR /workspace/cornserve/python 32 | 33 | RUN pip install -e '.[dev]' 34 | 35 | # UCXX logging 36 | ENV UCXPY_LOG_LEVEL=DEBUG 37 | # python log level syntax: DEBUG, INFO, WARNING, ERROR, CRITICAL 38 | 39 | # Disable OpenTelemetry 40 | ENV OTEL_SDK_DISABLED=true 41 | 42 | CMD ["bash"] 43 | -------------------------------------------------------------------------------- /docker/eric.Dockerfile: -------------------------------------------------------------------------------- 1 | # Build flash-attn wheel inside the `devel` image which has `nvcc`. 2 | FROM pytorch/pytorch:2.7.0-cuda12.6-cudnn9-devel AS builder 3 | 4 | ARG max_jobs=64 5 | ENV MAX_JOBS=${max_jobs} 6 | ENV NVCC_THREADS=8 7 | RUN pip wheel -w /tmp/wheels --no-build-isolation --no-deps --verbose flash-attn 8 | 9 | # Actual Eric runs inside the `runtime` image. Just copy over the flash-attn wheel. 10 | FROM pytorch/pytorch:2.7.0-cuda12.6-cudnn9-runtime AS eric 11 | 12 | COPY --from=builder /tmp/wheels/*.whl /tmp/wheels/ 13 | RUN pip install --no-cache-dir /tmp/wheels/*.whl && rm -rf /tmp/wheels 14 | 15 | RUN apt-get update \ 16 | && apt-get install -y --no-install-recommends build-essential \ 17 | && rm -rf /var/lib/apt/lists/* 18 | 19 | ADD . 
/workspace/cornserve 20 | 21 | WORKDIR /workspace/cornserve/python 22 | RUN pip install -e '.[eric]' 23 | 24 | ENTRYPOINT ["python", "-u", "-m", "cornserve.task_executors.eric.entrypoint"] 25 | 26 | # Eric that has audio support. 27 | FROM eric AS eric-audio 28 | 29 | RUN pip install -e '.[audio]' 30 | -------------------------------------------------------------------------------- /docker/gateway.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11.11 2 | 3 | ADD . /workspace/cornserve 4 | 5 | WORKDIR /workspace/cornserve/python 6 | RUN pip install -e .[gateway] 7 | 8 | ENTRYPOINT ["python", "-m", "cornserve.services.gateway.entrypoint"] 9 | -------------------------------------------------------------------------------- /docker/resource-manager.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11.11 2 | 3 | ADD . /workspace/cornserve 4 | 5 | WORKDIR /workspace/cornserve/python 6 | RUN pip install -e .[resource-manager] 7 | 8 | ENTRYPOINT ["python", "-m", "cornserve.services.resource_manager.entrypoint"] 9 | -------------------------------------------------------------------------------- /docker/sidecar.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.7.0-cuda12.6-cudnn9-runtime 2 | 3 | RUN apt-get update && apt-get upgrade -y 4 | RUN apt-get install wget build-essential librdmacm-dev net-tools -y 5 | 6 | ########### Install UCX 1.18.0 ########### 7 | RUN wget https://github.com/openucx/ucx/releases/download/v1.18.0/ucx-1.18.0.tar.gz 8 | RUN tar xzf ucx-1.18.0.tar.gz 9 | WORKDIR /workspace/ucx-1.18.0 10 | RUN mkdir build 11 | RUN cd build && \ 12 | ../configure --build=x86_64-unknown-linux-gnu --host=x86_64-unknown-linux-gnu --program-prefix= --disable-dependency-tracking \ 13 | --prefix=/usr --exec-prefix=/usr --bindir=/usr/bin --sbindir=/usr/sbin --sysconfdir=/etc 
--datadir=/usr/share --includedir=/usr/include \ 14 | --libdir=/usr/lib64 --libexecdir=/usr/libexec --localstatedir=/var --sharedstatedir=/var/lib --mandir=/usr/share/man --infodir=/usr/share/info \ 15 | --disable-logging --disable-debug --disable-assertions --enable-mt --disable-params-check --without-go --without-java --enable-cma \ 16 | --with-verbs --with-mlx5 --with-rdmacm --without-rocm --with-xpmem --without-fuse3 --without-ugni --without-mad --without-ze && \ 17 | make -j$(nproc) && make install 18 | 19 | ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true 20 | ENV LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH 21 | ########### End Install UCX ########### 22 | 23 | ADD . /workspace/cornserve 24 | 25 | WORKDIR /workspace/cornserve/python 26 | RUN pip install -e '.[sidecar]' 27 | 28 | ENTRYPOINT ["python", "-u", "-m", "cornserve.services.sidecar.server"] 29 | -------------------------------------------------------------------------------- /docker/task-dispatcher.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11.11 2 | 3 | ADD . /workspace/cornserve 4 | 5 | WORKDIR /workspace/cornserve/python 6 | RUN pip install -e .[task-dispatcher] 7 | 8 | ENTRYPOINT ["python", "-m", "cornserve.services.task_dispatcher.entrypoint"] 9 | -------------------------------------------------------------------------------- /docker/task-manager.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11.11 2 | 3 | ADD . 
/workspace/cornserve 4 | 5 | WORKDIR /workspace/cornserve/python 6 | RUN pip install -e .[task-manager] 7 | 8 | ENTRYPOINT ["python", "-m", "cornserve.services.task_manager.entrypoint"] 9 | -------------------------------------------------------------------------------- /docker/vllm.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.8.1-devel-ubuntu22.04 AS base 2 | 3 | RUN apt-get update -y \ 4 | && apt-get install -y git curl wget \ 5 | && curl -LsSf https://astral.sh/uv/install.sh | sh 6 | 7 | ENV PATH="/root/.local/bin:$PATH" 8 | ENV VIRTUAL_ENV="/opt/venv" 9 | RUN uv venv --python 3.11 --seed ${VIRTUAL_ENV} 10 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 11 | 12 | ADD . /workspace/cornserve 13 | WORKDIR /workspace/cornserve/third_party/vllm 14 | 15 | # Install CORNSERVE sidecars 16 | RUN cd ../.. && uv pip install './python[sidecar-api]' 17 | 18 | RUN uv pip install -r requirements/common.txt 19 | RUN uv pip install -r requirements/cuda.txt 20 | ENV SETUPTOOLS_SCM_PRETEND_VERSION=0.0.1.dev 21 | 22 | ENV VLLM_USE_PRECOMPILED=1 23 | ENV VLLM_COMMIT=6b6d4961147220fb80f9cc7dcb74db478f9c9a23 24 | ENV VLLM_PRECOMPILED_WHEEL_LOCATION=https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl 25 | 26 | # Intermediate vllm stage without audio 27 | FROM base AS vllm 28 | RUN uv pip install -e . 
29 | 30 | ENV VLLM_USE_V1=1 31 | ENV HF_HOME="/root/.cache/huggingface" 32 | ENTRYPOINT ["vllm", "serve"] 33 | 34 | # Default final stage with audio support 35 | FROM vllm AS vllm-audio 36 | RUN uv pip install -e .[audio] 37 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | cornserve.ai 2 | -------------------------------------------------------------------------------- /docs/architecture/eric.md: -------------------------------------------------------------------------------- 1 | # Eric: Multimodal Data Embedding Server 2 | 3 | > **Mosharaf**: Hey, what is this "Eric" thing in the architecture diagram? 4 | **Jae-Won**: Oh, uh no, it says "Enc." For Encoder. 5 | **Mosharaf**: Oh. 6 | **Jae-Won**: Now it's Eric. 7 | 8 | Package: `cornserve.task_executors.eric` 9 | 10 | Eric is a multimodal data embedding server that takes in a list of multimodal data (e.g., images, videos) and computes the multimodal embedding of the input data. 11 | 12 | ## Architecture 13 | 14 | Below, components are divided at the process boundary. 15 | 16 | ### Router 17 | 18 | The gateway router is an async FastAPI server that (1) receives modality encoding requests and (2) preprocesses modality data before running the encoder. 19 | Preprocessing is done asynchronously in a thread pool by the `eric.router.processor.Processor` class. 20 | 21 | Each model processes different modality data differently, so the router must instantiate the correct model-specific preprocessor. 22 | Instantiating and invoking these model- and modality-specific preprocessors are implemented in the class `eric.models.[model_module].ModalityProcessor`, which is a subclass of `eric.models.base.BaseModalityProcessor`. 23 | 24 | When modality preprocessing is complete, the router submits the embedding request to the engine. 25 | The router and the engine communicate through ZMQ sockets. 
Especially, the router holds an instance of the engine client (`eric.engine.client.EngineClient`), which is used to send requests to the engine and receive responses. 26 | 27 | ### Engine 28 | 29 | From the engine and below, everything is synchronous Python (i.e., not `asyncio`). 30 | 31 | The Engine constantly receives embedding requests from the router, runs the request scheduler to create a `eric.schema.Batch`, and invokes the model executor (`eric.executor.executor.ModelExecutor`) to compute the multimodal embedding. 32 | The model executor provides the `execute_model` method, which broadcasts input batch data to all Workers via shared memory. 33 | 34 | The engine currently only batches data of the same modality together. This is because there are models that have different code paths for different modalities. Furthermore, due to the compute-intensive nature of multimodal encoders, it is unlikely we will scale to large batch sizes. 35 | 36 | ### Workers 37 | 38 | There is one worker (`eric.executor.worker.Worker`) process per GPU. The number of workers is the tensor parallelism degree. 39 | When spawned, the workers initialize PyTorch distributed and instantiate the model from weights downloaded from the Hugging Face Hub. 40 | It then waits for the model executor to dispatch a batch to it, runs tensor parallel inference, and dispatches tensor communication to the destination Task Executor via the [sidecar](sidecar.md). 
41 | -------------------------------------------------------------------------------- /docs/assets/css/extra.css: -------------------------------------------------------------------------------- 1 | /* Override mkdocs-video plugin styles for responsive margins */ 2 | .video-container { 3 | margin: 0 auto !important; 4 | } 5 | 6 | /* Desktop screens */ 7 | @media (min-width: 1024px) { 8 | .video-container { 9 | padding: 0 12% !important; 10 | } 11 | } 12 | 13 | /* Mobile - use plugin default */ 14 | @media (max-width: 767px) { 15 | .video-container { 16 | padding: 0 !important; 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /docs/assets/img/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/docs/assets/img/favicon.png -------------------------------------------------------------------------------- /docs/assets/js/mathjax.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; 13 | 14 | document$.subscribe(() => { 15 | MathJax.typesetPromise() 16 | }) 17 | -------------------------------------------------------------------------------- /docs/assets/video/cornserve.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornserve-ai/cornserve/ff7948a3da7d9f6b8c36dbc6a773f36c9b0b1707/docs/assets/video/cornserve.mp4 -------------------------------------------------------------------------------- /docs/contributor_guide/eric.md: -------------------------------------------------------------------------------- 1 | # Eric developer guide 
2 | 3 | ## Docker container 4 | 5 | All code is to be run inside a Docker container, including tests. 6 | 7 | ```bash 8 | docker build -t cornserve/eric:latest -f docker/eric.Dockerfile . 9 | docker run -it --gpus all --entrypoint bash --ipc host --rm --name eric-dev -v $PWD:/workspace/cornserve -v $HF_CACHE:/root/.cache/huggingface cornserve/eric:latest 10 | ``` 11 | 12 | ## Editable installation 13 | 14 | ```bash 15 | pip install -e 'python[dev]' 16 | ``` 17 | 18 | ## Testing 19 | 20 | We use `pytest`. Tests use GPUs. 21 | 22 | ```bash 23 | pytest 24 | ``` 25 | 26 | Set the `CORNSERVE_TEST_DUMP_TENSOR_DIR` to an existing directory when running `pytest`. 27 | This will dump output embedding tensors to the specified directory. 28 | Refer to `build_batch` in `tests/task_executors/eric/utils.py`. 29 | 30 | ```bash 31 | export CORNSERVE_TEST_DUMP_TENSOR_DIR=/path/to/dump 32 | pytest python/tests/task_executors/eric/models/test_llava_onevision.py::test_image_inference 33 | ``` 34 | -------------------------------------------------------------------------------- /docs/contributor_guide/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | description: Cornserve contributor guide 3 | --- 4 | 5 | # Contributor Guide 6 | 7 | Here, we provide more info for contributors. 8 | General principles are here, and child pages discuss specific topics in more detail. 9 | 10 | We have a few principles for developing Cornserve: 11 | 12 | 1. **Strict type annotations**: We enforce strict type annotation everywhere in the Python codebase, which leads to numerous benefits including better reliability, readability, and editor support. We use `pyright` for type checking. 13 | 1. **Automated testing**: We don't aim for 100% test coverage, but non-trivial and/or critical features should be tested with `pytest`. 14 | 15 | ## Contributing process 16 | 17 | !!!
Important 18 | By contributing to Cornserve, you agree that your code will be licensed with Apache 2.0. 19 | 20 | If the feature is not small or requires broad changes over the codebase, please **open an issue** at our GitHub repository to discuss with us. 21 | 22 | 1. Fork our GitHub repository. Make sure you clone with `--recurse-submodules` to get the submodules. 23 | 1. Create a new Conda environment with something along the lines of `conda create -n cornserve python=3.11` and activate it with something like `conda activate cornserve`. 24 | 1. Install Cornserve in editable mode with `pip install -e 'python[dev]'`. If your environment does not have GPUs, you can use `pip install -e 'python[dev-no-gpu]'`. 25 | 1. Generate Python bindings for Protobuf files with `bash scripts/generate_pb.sh`. 26 | 1. Implement changes in your branch and add tests as needed. 27 | 1. Ensure `bash python/scripts/lint.sh` and `pytest` passes. Note that many of our tests require GPU. 28 | 1. Submit a PR to the main repository. Please ensure that CI (GitHub Actions) passes. 29 | 30 | ## Developing on Kubernetes 31 | 32 | Cornserve runs on top of Kubernetes, which introduces some complexity in development. 33 | Please refer to the guide on [Local and Distributed Development on Kubernetes](kubernetes.md) for more details. 34 | 35 | ## Documentation 36 | 37 | The documentation is written in Markdown and is located in the `docs` folder. 38 | We use MkDocs to build the documentation and use the `mkdocs-material` theme. 
39 | 40 | To install documentation build dependencies: 41 | 42 | ```bash 43 | pip install -r docs/requirements.txt 44 | ``` 45 | 46 | To build and preview the documentation: 47 | 48 | ```bash 49 | mkdocs serve 50 | ``` 51 | -------------------------------------------------------------------------------- /docs/contributor_guide/kubernetes.md: -------------------------------------------------------------------------------- 1 | ## Local and Distributed Development on Kubernetes 2 | 3 | ### Local development 4 | 5 | You are developing on a single node. 6 | In this case, we don't need a registry. 7 | Instead, we build containers directly within the containerd runtime of K3s. 8 | 9 | First, follow [this guide](https://blog.otvl.org/blog/k3s-loc-sp) (Section "Switching from Docker to Containerd") to set up Nerdctl and BuildKit on your local development machine. 10 | 11 | After that, you can use Nerdctl to build images directly within K3s containerd, and no pull is necessary whatsoever. 12 | Use the `build_export_images.sh` script with the `REGISTRY` environment variable set to `local` (a special case): 13 | 14 | ```bash 15 | REGISTRY=local bash scripts/build_export_images.sh 16 | ``` 17 | 18 | Use the `local` overlay to deploy Cornserve: 19 | 20 | ```bash 21 | kubectl apply -k kustomize/cornserve-system/overlays/local 22 | kubectl apply -k kustomize/cornserve/overlays/local 23 | ``` 24 | 25 | The `local` overlay specifies `imagePullPolicy: Never`, meaning that if the image was not found locally, it means that it was not built yet, correctly raising an error. 26 | 27 | !!! NOTE 28 | You can use the `local` overlay for the quick Minikube demo as well. 29 | 30 | ### Distributed development 31 | 32 | You are developing on a multi-node cluster. 33 | 34 | (1) Now, you do need a registry to push images to, so that remote nodes can pull them: 35 | 36 | ```bash 37 | bash kubernetes/registry.sh 38 | REGISTRY=myregistry.com:5000 bash kubernetes/set_registry.sh # (1)!
39 | ``` 40 | 41 | 1. Modifies `kustomization.yaml` and `k3s/registries.yaml` 42 | If you're on this dev workflow with a *single* node cluster, you can skip `kubernetes/set_registry.sh` because things default to `localhost:5000`. 43 | 44 | (2) For K3s to work with insecure (i.e., HTTP) registries, you need to set up the `registries.yaml` file in `/etc/rancher/k3s/` on **all** nodes (master and worker) before starting K3s: 45 | 46 | ```bash 47 | sudo cp kubernetes/k3s/registries.yaml /etc/rancher/k3s/registries.yaml 48 | sudo systemctl start k3s # or k3s-agent 49 | ``` 50 | 51 | (3) Build and push images to the registry using the `build_export_images.sh` script with the `REGISTRY` environment variable set to the registry address: 52 | 53 | ```bash 54 | REGISTRY=myregistry.com:5000 bash scripts/build_export_images.sh 55 | ``` 56 | 57 | (4) Use the `dev` overlay (which specifies `imagePullPolicy: Always`) to deploy Cornserve: 58 | 59 | ```bash 60 | kubectl apply -k kustomize/cornserve-system/overlays/dev 61 | kubectl apply -k kustomize/cornserve/overlays/dev 62 | ``` 63 | -------------------------------------------------------------------------------- /docs/contributor_guide/sidecar.md: -------------------------------------------------------------------------------- 1 | # Sidecar developer guide 2 | 3 | ## Docker container 4 | 5 | It is recommended to run everything inside docker. Sidecar uses `UCX` as backend, 6 | so you might find the `docker/dev.Dockerfile` helpful. Additionally, Sidecar has 7 | a dependency on `ucxx-cu12`, meaning you need to develop on an Nvidia 8 | GPU-enabled machine at the moment. 9 | 10 | Specifying `--shm-size` with at least 4 GB and `--ipc host` is required. 11 | 12 | ## Editable installation 13 | 14 | ```bash 15 | pip install -e 'python[dev]' 16 | ``` 17 | 18 | ## Testing 19 | 20 | We use pytest.
21 | 22 | ```bash 23 | pytest python/tests/services/sidecar/test_sidecar.py 24 | ``` 25 | 26 | When testing locally with task executors, you can `export SIDECAR_IS_LOCAL=true` to 27 | route all communications through `localhost` instead of k8s network. 28 | 29 | 30 | ## Debugging 31 | 32 | To debug UCX related error, you can set `UCX_LOG_LEVEL=trace` and `UCXPY_LOG_LEVEL=DEBUG` 33 | -------------------------------------------------------------------------------- /docs/contributor_guide/tracing.md: -------------------------------------------------------------------------------- 1 | # Tracing Developer guide 2 | 3 | We employ OpenTelemetry for observability. Below are some of the conventions we use. 4 | 5 | Generally, we use auto-instrumentation provided by OpenTelemetry, e.g., FastAPI, gRPC, HTTPX. 6 | 7 | ## Spans 8 | 9 | Usually named with `ClassName.function_name`. 10 | 11 | ## Attributes 12 | 13 | Usually named with `namespace.subroutine.attribute_name`. 14 | `namespace` is typically the name of the service, like `gateway`. 15 | 16 | ## Events 17 | 18 | Usually named with `action.event_name`. 19 | Use spans for things that happen over time (e.g., a subroutine), where tracking the start and end is important. 20 | On the other hand, use events for singular occurrences that happen at a specific moment in time. 21 | 22 | ## Test 23 | When testing locally, you can disable OTEL tracing through `OTEL_SDK_DISABLED=true`. 24 | -------------------------------------------------------------------------------- /docs/getting_started/cornserve.md: -------------------------------------------------------------------------------- 1 | # Deploying Cornserve 2 | 3 | Cornserve can be deployed on a GPU cluster managed by Kubernetes. 4 | 5 | !!! Note 6 | The `cornserve` namespace is used for most of our control plane and data plane objects. 
7 | On the other hand, the `cornserve-system` namespace is used for components that look over and manage the Cornserve system itself (under `cornserve`), like Jaeger and Prometheus. 8 | If you already have a Kubernetes cluster running, you can deploy Cornserve on it with the `prod` overlay: 9 | 10 | ## Deploying K3s 11 | 12 | !!! Tip 13 | If you have a Kubernetes cluster running, you can skip this section. 14 | 15 | If you don't have a Kubernetes cluster running, you can deploy Cornserve on a K3s cluster. 16 | We also use the [K3s](https://k3s.io/) distribution of Kubernetes for our development. 17 | Refer to their [Documentation](https://docs.k3s.io/quick-start/) for more details. 18 | 19 | !!! Tip 20 | If you're deploying on-premise with k3s, make sure you have plenty of disk space under `/var/lib/rancher` because `containerd` stores images there. 21 | If not, you can create a directory in a secondary storage (e.g., `/mnt/data/rancher`) and symlink it to `/var/lib/rancher` prior to starting k3s. 22 | 23 | ### NVIDIA Device Plugin 24 | 25 | The [NVIDIA GPU Device Plugin](https://github.com/NVIDIA/k8s-device-plugin) is required to expose GPUs to the Kubernetes cluster as resources. 
26 | You can deploy a specific version like this: 27 | 28 | ```bash 29 | kubectl create -f https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.17.2/deployments/static/nvidia-device-plugin.yml 30 | ``` 31 | 32 | ### Clone the Repository 33 | 34 | ```bash 35 | git clone https://github.com/cornserve-ai/cornserve.git 36 | cd cornserve/kubernetes 37 | ``` 38 | 39 | ### Master Node 40 | 41 | Install and start K3s: 42 | 43 | ```bash 44 | curl -sfL https://get.k3s.io | INSTALL_K3S_SKIP_ENABLE=true sh - 45 | sudo mkdir -p /etc/rancher/k3s 46 | sudo cp k3s/server-config.yaml /etc/rancher/k3s/config.yaml 47 | sudo systemctl start k3s 48 | ``` 49 | 50 | Note the master node address (`$MASTER_ADDRESS`) and the node token (`$NODE_TOKEN`): 51 | 52 | ```bash 53 | NODE_TOKEN="$(sudo cat /var/lib/rancher/k3s/server/node-token)" 54 | ``` 55 | 56 | ### Worker Nodes 57 | 58 | Install and start K3s: 59 | 60 | ```bash 61 | curl -sfL https://get.k3s.io | K3S_URL=https://$MASTER_ADDRESS:6443 K3S_TOKEN=$NODE_TOKEN INSTALL_K3S_SKIP_ENABLE=true sh - 62 | sudo mkdir -p /etc/rancher/k3s 63 | sudo cp k3s/agent-config.yaml /etc/rancher/k3s/config.yaml 64 | sudo systemctl start k3s-agent 65 | ``` 66 | 67 | ## Deploying Cornserve 68 | 69 | If you haven't already, clone the Cornserve repository: 70 | 71 | ```bash 72 | git clone git@github.com:cornserve-ai/cornserve.git 73 | cd cornserve 74 | ``` 75 | 76 | On top of a Kubernetes cluster, you can deploy Cornserve with a single command: 77 | 78 | ```bash 79 | kubectl apply -k kubernetes/kustomize/cornserve-system/base 80 | kubectl apply -k kubernetes/kustomize/cornserve/overlays/prod 81 | ``` 82 | 83 | !!! Note 84 | The `cornserve` namespace is used for most of our control plane and data plane objects. 85 | On the other hand, the `cornserve-system` namespace is used for components that look over and manage the Cornserve system itself (under `cornserve`), like Jaeger and Prometheus. 
86 | -------------------------------------------------------------------------------- /docs/getting_started/registering_apps.md: -------------------------------------------------------------------------------- 1 | # Deploying Apps to Cornserve and Invoking Them 2 | 3 | Once you've written your app, you can deploy it to Cornserve. 4 | The current deployment process is as follows: 5 | 6 | 1. Save the app code in a single Python file (e.g., `image_chat.py`). 7 | 2. Register & deploy the app to the Cornserve Gateway for validation and deployment: 8 | ```bash 9 | export CORNSERVE_GATEWAY_URL=[...] 10 | cornserve register image_chat.py 11 | ``` 12 | 3. When validation succeeds, the Cornserve Gateway will deploy the app and all its subtasks on the Cornserve data plane, and the `cornserve` CLI invocation will return with the app's ID. 13 | 4. Finally, you can send requests to the Cornserve Gateway with your choice of HTTP client. 14 | ```python 15 | response = requests.post( 16 | f"{CORNSERVE_GATEWAY_URL}/app/invoke/{APP_ID}", 17 | json={ 18 | "request_data": { 19 | "image_url": "https://example.com/image.jpg", 20 | "prompt": "Describe the image.", 21 | } 22 | }, 23 | ) 24 | ``` 25 | Notice that what comes within the `"request_data"` key is the JSON representation of your `Request` class defined in our [previous example](building_apps.md#app). 26 | 27 | ## Next Steps 28 | 29 | To dive deeper into the architecture of Cornserve, check out our [architecture guide](../architecture/index.md). 30 | 31 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | description: "Easy, fast, and scalable multimodal AI" 3 | hide: 4 | - navigation 5 | - toc 6 | --- 7 | 8 |
9 |

Cornserve: Easy, Fast, and Scalable Multimodal AI

10 |
11 | 12 |
13 | 22 |
23 | 24 | 43 | 44 | Cornserve is an execution platform for multimodal AI. 45 | It performs **model fission** and **automatic sharing** of common components across applications on your infrastructure. 46 | 47 |
48 | 49 | - :material-vector-intersection:{ .lg .middle } **Model fission** 50 | 51 | --- 52 | 53 | Split up your complex models into smaller components and 54 | scale them independently. 55 | 56 | - :material-share-variant:{ .lg .middle } **Automatic sharing** 57 | 58 | --- 59 | 60 | Common model components are automatically shared across applications. 61 | 62 | - :material-hub:{ .lg .middle } **Multimodal-native** 63 | 64 | --- 65 | 66 | Cornserve is built multimodal-native from the ground up. Image, video, 67 | audio, and text are all first-class citizens. 68 | 69 | - :material-kubernetes:{ .lg .middle } **Deploy to K8s with one command** 70 | 71 | --- 72 | 73 | One-command deployment to Kubernetes with [Kustomize](https://kustomize.io/). 74 | 75 | - :simple-opentelemetry:{ .lg .middle } **Observability** 76 | 77 | --- 78 | 79 | Built-in support for [OpenTelemetry](https://opentelemetry.io/) 80 | to monitor your apps and requests. 81 | 82 | - :material-scale-balance:{ .lg .middle } **Open Source, Apache-2.0** 83 | 84 | --- 85 | 86 | Cornserve is open-source with the Apache 2.0 license and is available on 87 | [GitHub](https://github.com/cornserve-ai/cornserve). 88 | 89 |
90 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocs-material[imaging] 3 | mkdocs-video 4 | mkdocs-autorefs 5 | mkdocs-jupyter 6 | -------------------------------------------------------------------------------- /examples/mllm/app.py: -------------------------------------------------------------------------------- 1 | """An app that runs a Multimodal LLM task.""" 2 | 3 | from __future__ import annotations 4 | 5 | from cornserve.app.base import AppRequest, AppResponse, AppConfig 6 | from cornserve.task.builtins.mllm import MLLMInput, MLLMTask, Modality 7 | 8 | 9 | class Request(AppRequest): 10 | """App request model. 11 | 12 | Attributes: 13 | prompt: The prompt to send to the LLM. 14 | multimodal_data: List of tuples (modality, data URL). 15 | "image", "video", etc. for modality. 16 | """ 17 | 18 | prompt: str 19 | multimodal_data: list[tuple[str, str]] = [] 20 | 21 | 22 | class Response(AppResponse): 23 | """App response model. 24 | 25 | Attributes: 26 | response: The response from the LLM. 
27 | """ 28 | 29 | response: str 30 | 31 | 32 | mllm = MLLMTask( 33 | # model_id="Qwen/Qwen2-VL-7B-Instruct", 34 | model_id="google/gemma-3-4b-it", 35 | modalities=[Modality.IMAGE], 36 | ) 37 | 38 | 39 | class Config(AppConfig): 40 | """App configuration model.""" 41 | 42 | tasks = {"mllm": mllm} 43 | 44 | 45 | async def serve(request: Request) -> Response: 46 | """Main serve function for the app.""" 47 | mllm_input = MLLMInput(prompt=request.prompt, multimodal_data=request.multimodal_data) 48 | mllm_output = await mllm(mllm_input) 49 | return Response(response=mllm_output.response) 50 | -------------------------------------------------------------------------------- /examples/mllm/requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | cornserve 3 | -------------------------------------------------------------------------------- /kubernetes/README.md: -------------------------------------------------------------------------------- 1 | # Kubernetes 2 | 3 | Please refer to our documentation for guides on [deploying](https://cornserve.ai/getting_started/cornserve/) and [developing](https://cornserve.ai/contributor_guide/kubernetes/) on Kubernetes. 
4 | -------------------------------------------------------------------------------- /kubernetes/k3s/agent-config.yaml: -------------------------------------------------------------------------------- 1 | default-runtime: "nvidia" 2 | -------------------------------------------------------------------------------- /kubernetes/k3s/registries.yaml: -------------------------------------------------------------------------------- 1 | mirrors: 2 | "localhost:5000": 3 | endpoint: 4 | - "http://localhost:5000" 5 | -------------------------------------------------------------------------------- /kubernetes/k3s/server-config.yaml: -------------------------------------------------------------------------------- 1 | write-kubeconfig-mode: "0644" 2 | default-runtime: "nvidia" 3 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/base/jaeger/configmap.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ConfigMap 3 | metadata: 4 | name: jaeger-config 5 | namespace: cornserve-system 6 | data: 7 | config.yaml: | 8 | service: 9 | extensions: [jaeger_storage, jaeger_query, healthcheckv2] 10 | pipelines: 11 | traces: 12 | receivers: [otlp] 13 | processors: [batch] 14 | exporters: [jaeger_storage_exporter] 15 | telemetry: 16 | resource: 17 | service.name: jaeger 18 | metrics: 19 | level: detailed 20 | logs: 21 | level: info 22 | 23 | extensions: 24 | healthcheckv2: 25 | use_v2: true 26 | http: 27 | 28 | jaeger_query: 29 | max_clock_skew_adjust: 30s 30 | storage: 31 | traces: badger_store 32 | 33 | jaeger_storage: 34 | backends: 35 | badger_store: 36 | badger: 37 | directories: 38 | keys: "/badger/keys" 39 | values: "/badger/values" 40 | ephemeral: false 41 | 42 | receivers: 43 | otlp: 44 | protocols: 45 | grpc: 46 | endpoint: 0.0.0.0:4317 47 | http: 48 | endpoint: 0.0.0.0:4318 49 | 50 | processors: 51 | batch: 52 | 53 | exporters: 54 | 
jaeger_storage_exporter: 55 | trace_storage: badger_store 56 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/base/jaeger/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: jaeger 5 | namespace: cornserve-system 6 | labels: 7 | app: jaeger 8 | spec: 9 | selector: 10 | matchLabels: 11 | app: jaeger 12 | replicas: 1 13 | template: 14 | metadata: 15 | labels: 16 | app: jaeger 17 | spec: 18 | nodeSelector: 19 | node-role.kubernetes.io/control-plane: "true" 20 | containers: 21 | - name: jaeger 22 | image: jaegertracing/jaeger:2.4.0 23 | imagePullPolicy: IfNotPresent 24 | securityContext: 25 | runAsUser: 0 26 | runAsGroup: 0 27 | args: 28 | - "--config=/config/config.yaml" 29 | ports: 30 | - containerPort: 16686 31 | name: query 32 | - containerPort: 4317 33 | name: otlp-grpc 34 | - containerPort: 4318 35 | name: otlp-http 36 | volumeMounts: 37 | - name: config-volume 38 | mountPath: /config 39 | resources: 40 | limits: 41 | cpu: 500m 42 | memory: 1Gi 43 | requests: 44 | cpu: 100m 45 | memory: 200Mi 46 | volumes: 47 | - name: config-volume 48 | configMap: 49 | name: jaeger-config 50 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/base/jaeger/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - configmap.yaml 5 | - deployment.yaml 6 | - service.yaml 7 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/base/jaeger/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: jaeger-collector 5 | namespace: 
cornserve-system 6 | labels: 7 | app: jaeger 8 | spec: 9 | selector: 10 | app: jaeger 11 | ports: 12 | - port: 4317 13 | targetPort: 4317 14 | name: otlp-grpc 15 | - port: 4318 16 | targetPort: 4318 17 | name: otlp-http 18 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/base/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - namespace.yaml 5 | - jaeger 6 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/base/namespace.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Namespace 3 | metadata: 4 | creationTimestamp: null 5 | name: cornserve-system 6 | spec: {} 7 | status: {} 8 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/overlays/dev/jaeger/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1alpha1 2 | kind: Component 3 | resources: 4 | - node-port-service.yaml 5 | patches: 6 | - path: volume-patch.yaml 7 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/overlays/dev/jaeger/node-port-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: jaeger-query 5 | namespace: cornserve-system 6 | labels: 7 | app: jaeger 8 | spec: 9 | selector: 10 | app: jaeger 11 | type: NodePort 12 | ports: 13 | - port: 16686 14 | targetPort: 16686 15 | nodePort: 30686 16 | name: query 17 | -------------------------------------------------------------------------------- 
/kubernetes/kustomize/cornserve-system/overlays/dev/jaeger/volume-patch.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: jaeger 5 | namespace: cornserve-system 6 | spec: 7 | template: 8 | spec: 9 | containers: 10 | - name: jaeger 11 | volumeMounts: 12 | - name: host-storage 13 | mountPath: /badger 14 | volumes: 15 | - name: host-storage 16 | hostPath: 17 | path: /data/cornserve/jaeger-badger 18 | type: DirectoryOrCreate 19 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/overlays/dev/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - ../../base 5 | components: 6 | - jaeger 7 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/overlays/local/jaeger/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1alpha1 2 | kind: Component 3 | resources: 4 | - node-port-service.yaml 5 | patches: 6 | - path: volume-patch.yaml 7 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/overlays/local/jaeger/node-port-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: jaeger-query 5 | namespace: cornserve-system 6 | labels: 7 | app: jaeger 8 | spec: 9 | selector: 10 | app: jaeger 11 | type: NodePort 12 | ports: 13 | - port: 16686 14 | targetPort: 16686 15 | nodePort: 30686 16 | name: query 17 | -------------------------------------------------------------------------------- 
/kubernetes/kustomize/cornserve-system/overlays/local/jaeger/volume-patch.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: jaeger 5 | namespace: cornserve-system 6 | spec: 7 | template: 8 | spec: 9 | containers: 10 | - name: jaeger 11 | volumeMounts: 12 | - name: host-storage 13 | mountPath: /badger 14 | volumes: 15 | - name: host-storage 16 | hostPath: 17 | path: /data/cornserve/jaeger-badger 18 | type: DirectoryOrCreate 19 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/overlays/local/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - ../../base 5 | components: 6 | - jaeger 7 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/overlays/minikube/jaeger/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1alpha1 2 | kind: Component 3 | resources: 4 | - node-port-service.yaml 5 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/overlays/minikube/jaeger/node-port-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: jaeger-query 5 | namespace: cornserve-system 6 | labels: 7 | app: jaeger 8 | spec: 9 | selector: 10 | app: jaeger 11 | type: NodePort 12 | ports: 13 | - port: 16686 14 | targetPort: 16686 15 | nodePort: 30686 16 | name: query 17 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve-system/overlays/minikube/kustomization.yaml: 
-------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - ../../base 5 | patches: 6 | - target: 7 | kind: Deployment 8 | labelSelector: app=jaeger 9 | patch: |- 10 | - op: remove 11 | path: /spec/template/spec/nodeSelector 12 | components: 13 | - jaeger 14 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/gateway/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: gateway 5 | namespace: cornserve 6 | labels: 7 | app: gateway 8 | spec: 9 | selector: 10 | matchLabels: 11 | app: gateway 12 | replicas: 1 13 | template: 14 | metadata: 15 | labels: 16 | app: gateway 17 | spec: 18 | nodeSelector: 19 | node-role.kubernetes.io/control-plane: "true" 20 | containers: 21 | - name: gateway 22 | image: cornserve/gateway:latest 23 | imagePullPolicy: IfNotPresent 24 | ports: 25 | - containerPort: 8000 26 | name: http 27 | envFrom: 28 | - configMapRef: 29 | name: cornserve-config 30 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/gateway/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - deployment.yaml 5 | - service.yaml 6 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/gateway/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: gateway 5 | namespace: cornserve 6 | spec: 7 | selector: 8 | app: gateway 9 | ports: 10 | - name: gateway 11 | port: 8000 12 | targetPort: http 13 | 
-------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - gateway 5 | - resource-manager 6 | - sidecar 7 | - task-dispatcher 8 | - task-manager 9 | - namespace.yaml 10 | images: 11 | - name: cornserve/gateway 12 | newName: cornserve/gateway 13 | - name: cornserve/resource-manager 14 | newName: cornserve/resource-manager 15 | - name: cornserve/sidecar 16 | newName: cornserve/sidecar 17 | - name: cornserve/task-dispatcher 18 | newName: cornserve/task-dispatcher 19 | - name: cornserve/task-manager 20 | newName: cornserve/task-manager 21 | configMapGenerator: 22 | - name: cornserve-config 23 | namespace: cornserve 24 | literals: 25 | - CORNSERVE_IMAGE_PREFIX=docker.io/cornserve 26 | - CORNSERVE_IMAGE_TAG=latest 27 | - CORNSERVE_IMAGE_PULL_POLICY=IfNotPresent 28 | generatorOptions: 29 | disableNameSuffixHash: true 30 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/namespace.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Namespace 3 | metadata: 4 | creationTimestamp: null 5 | name: cornserve 6 | spec: {} 7 | status: {} 8 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/resource-manager/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: resource-manager 5 | namespace: cornserve 6 | labels: 7 | app: resource-manager 8 | spec: 9 | selector: 10 | matchLabels: 11 | app: resource-manager 12 | replicas: 1 13 | template: 14 | metadata: 15 | labels: 16 | app: resource-manager 17 | spec: 18 | nodeSelector: 19 
| node-role.kubernetes.io/control-plane: "true" 20 | serviceAccountName: resource-manager 21 | containers: 22 | - name: resource-manager 23 | image: cornserve/resource-manager:latest 24 | imagePullPolicy: IfNotPresent 25 | ports: 26 | - containerPort: 50051 27 | name: grpc 28 | envFrom: 29 | - configMapRef: 30 | name: cornserve-config 31 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/resource-manager/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - deployment.yaml 5 | - role-binding.yaml 6 | - role.yaml 7 | - service-account.yaml 8 | - service.yaml 9 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/resource-manager/role-binding.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRoleBinding 3 | metadata: 4 | name: resource-manager-binding 5 | subjects: 6 | - kind: ServiceAccount 7 | name: resource-manager 8 | namespace: cornserve 9 | roleRef: 10 | kind: ClusterRole 11 | name: resource-manager-role 12 | apiGroup: rbac.authorization.k8s.io 13 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/resource-manager/role.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRole 3 | metadata: 4 | name: resource-manager-role 5 | rules: 6 | - apiGroups: [""] 7 | resources: ["pods", "services"] 8 | verbs: ["*"] 9 | - apiGroups: [""] 10 | resources: ["nodes"] 11 | verbs: ["list"] 12 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/resource-manager/service-account.yaml: 
-------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ServiceAccount 3 | metadata: 4 | name: resource-manager 5 | namespace: cornserve 6 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/resource-manager/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: resource-manager 5 | namespace: cornserve 6 | spec: 7 | selector: 8 | app: resource-manager 9 | ports: 10 | - name: resource-manager 11 | port: 50051 12 | targetPort: grpc 13 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/sidecar/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - service.yaml 5 | - role-binding.yaml 6 | - role.yaml 7 | - service-account.yaml 8 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/sidecar/role-binding.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: RoleBinding 3 | metadata: 4 | name: sidecar-binding 5 | namespace: cornserve 6 | subjects: 7 | - kind: ServiceAccount 8 | name: sidecar 9 | roleRef: 10 | kind: Role 11 | name: sidecar-role 12 | apiGroup: rbac.authorization.k8s.io 13 | 14 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/sidecar/role.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: Role 3 | metadata: 4 | name: sidecar-role 5 | namespace: cornserve 6 | rules: 7 | - apiGroups: [""] 8 | resources: ["pods"] 9 | verbs: 
["list", "get"] 10 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/sidecar/service-account.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ServiceAccount 3 | metadata: 4 | name: sidecar 5 | namespace: cornserve 6 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/sidecar/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: sidecar 5 | namespace: cornserve 6 | labels: 7 | app: sidecar 8 | spec: 9 | clusterIP: None 10 | selector: 11 | app: sidecar 12 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/sidecar/statefulset.yaml: -------------------------------------------------------------------------------- 1 | # This file is not used by Kustomize; sidecars are deployed by the Resource Manager. 2 | # We keep it here for reference. We might want to deploy only the sidecars, for instance. 
3 | apiVersion: apps/v1 4 | kind: StatefulSet 5 | metadata: 6 | name: sidecar 7 | namespace: cornserve 8 | labels: 9 | app: sidecar 10 | spec: 11 | serviceName: sidecar 12 | selector: 13 | matchLabels: 14 | app: sidecar 15 | replicas: 4 16 | template: 17 | metadata: 18 | labels: 19 | app: sidecar 20 | spec: 21 | hostPID: true 22 | hostIPC: true 23 | topologySpreadConstraints: 24 | - maxSkew: 1 25 | topologyKey: "kubernetes.io/hostname" 26 | whenUnsatisfiable: DoNotSchedule 27 | labelSelector: 28 | matchLabels: 29 | app: sidecar 30 | runtimeClassName: nvidia 31 | serviceAccountName: sidecar 32 | containers: 33 | - name: sidecar 34 | image: cornserve/sidecar:latest 35 | imagePullPolicy: IfNotPresent 36 | securityContext: 37 | privileged: true 38 | env: 39 | - name: SIDECAR_WORLD_SIZE 40 | value: "4" 41 | - name: SIDECAR_POD_NAME 42 | valueFrom: 43 | fieldRef: 44 | fieldPath: metadata.name 45 | envFrom: 46 | - configMapRef: 47 | name: cornserve-config 48 | volumeMounts: 49 | - name: shm-volume 50 | mountPath: /dev/shm 51 | - name: infiniband-class 52 | mountPath: /sys/class/infiniband 53 | - name: infiniband-dev 54 | mountPath: /dev/infiniband 55 | volumes: 56 | - name: shm-volume 57 | hostPath: 58 | path: /dev/shm 59 | type: Directory 60 | - name: infiniband-class 61 | hostPath: 62 | path: /sys/class/infiniband 63 | type: Directory 64 | - name: infiniband-dev 65 | hostPath: 66 | path: /dev/infiniband 67 | type: Directory 68 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/task-dispatcher/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: task-dispatcher 5 | namespace: cornserve 6 | labels: 7 | app: task-dispatcher 8 | spec: 9 | selector: 10 | matchLabels: 11 | app: task-dispatcher 12 | replicas: 1 13 | template: 14 | metadata: 15 | labels: 16 | app: task-dispatcher 17 | spec: 18 | 
nodeSelector: 19 | node-role.kubernetes.io/control-plane: "true" 20 | containers: 21 | - name: task-dispatcher 22 | image: cornserve/task-dispatcher:latest 23 | imagePullPolicy: IfNotPresent 24 | ports: 25 | - containerPort: 50051 26 | name: grpc 27 | - containerPort: 8000 28 | name: http 29 | envFrom: 30 | - configMapRef: 31 | name: cornserve-config 32 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/task-dispatcher/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - deployment.yaml 5 | - service.yaml 6 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/task-dispatcher/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: task-dispatcher 5 | namespace: cornserve 6 | spec: 7 | selector: 8 | app: task-dispatcher 9 | ports: 10 | - name: grpc 11 | port: 50051 12 | targetPort: grpc 13 | - name: http 14 | port: 8000 15 | targetPort: http 16 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/task-manager/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - role-binding.yaml 5 | - role.yaml 6 | - service-account.yaml 7 | 8 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/task-manager/role-binding.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: RoleBinding 3 | metadata: 4 | name: task-manager-binding 5 | namespace: cornserve 6 | 
subjects: 7 | - kind: ServiceAccount 8 | name: task-manager 9 | roleRef: 10 | kind: Role 11 | name: task-manager-role 12 | apiGroup: rbac.authorization.k8s.io 13 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/task-manager/role.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: Role 3 | metadata: 4 | name: task-manager-role 5 | namespace: cornserve 6 | rules: 7 | - apiGroups: [""] 8 | resources: ["pods", "services"] 9 | verbs: ["*"] 10 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/base/task-manager/service-account.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ServiceAccount 3 | metadata: 4 | name: task-manager 5 | namespace: cornserve 6 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/dev/gateway/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - node-port-service.yaml 5 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/dev/gateway/node-port-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: gateway-node-port 5 | namespace: cornserve 6 | spec: 7 | type: NodePort 8 | selector: 9 | app: gateway 10 | ports: 11 | - name: gateway 12 | port: 8000 13 | targetPort: http 14 | nodePort: 30080 15 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/dev/image-pull-policy.yaml: 
-------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: gateway 5 | namespace: cornserve 6 | spec: 7 | template: 8 | spec: 9 | containers: 10 | - name: gateway 11 | imagePullPolicy: Always 12 | --- 13 | apiVersion: apps/v1 14 | kind: Deployment 15 | metadata: 16 | name: resource-manager 17 | namespace: cornserve 18 | spec: 19 | template: 20 | spec: 21 | containers: 22 | - name: resource-manager 23 | imagePullPolicy: Always 24 | --- 25 | apiVersion: apps/v1 26 | kind: Deployment 27 | metadata: 28 | name: task-dispatcher 29 | namespace: cornserve 30 | spec: 31 | template: 32 | spec: 33 | containers: 34 | - name: task-dispatcher 35 | imagePullPolicy: Always 36 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/dev/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - ../../base 5 | - gateway 6 | patches: 7 | - path: image-pull-policy.yaml 8 | images: 9 | - name: cornserve/gateway 10 | newName: localhost:5000/cornserve/gateway 11 | newTag: latest 12 | - name: cornserve/resource-manager 13 | newName: localhost:5000/cornserve/resource-manager 14 | newTag: latest 15 | - name: cornserve/sidecar 16 | newName: localhost:5000/cornserve/sidecar 17 | newTag: latest 18 | - name: cornserve/task-dispatcher 19 | newName: localhost:5000/cornserve/task-dispatcher 20 | newTag: latest 21 | - name: cornserve/task-manager 22 | newName: localhost:5000/cornserve/task-manager 23 | newTag: latest 24 | configMapGenerator: 25 | - name: cornserve-config 26 | namespace: cornserve 27 | behavior: merge 28 | literals: 29 | - CORNSERVE_IMAGE_PREFIX=localhost:5000/cornserve 30 | - CORNSERVE_IMAGE_PULL_POLICY=Always 31 | - CORNSERVE_IMAGE_TAG=latest 32 | 
-------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/local/gateway/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - node-port-service.yaml 5 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/local/gateway/node-port-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: gateway-node-port 5 | namespace: cornserve 6 | spec: 7 | type: NodePort 8 | selector: 9 | app: gateway 10 | ports: 11 | - name: gateway 12 | port: 8000 13 | targetPort: http 14 | nodePort: 30080 15 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/local/image-pull-policy.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: gateway 5 | namespace: cornserve 6 | spec: 7 | template: 8 | spec: 9 | containers: 10 | - name: gateway 11 | imagePullPolicy: Never 12 | --- 13 | apiVersion: apps/v1 14 | kind: Deployment 15 | metadata: 16 | name: resource-manager 17 | namespace: cornserve 18 | spec: 19 | template: 20 | spec: 21 | containers: 22 | - name: resource-manager 23 | imagePullPolicy: Never 24 | --- 25 | apiVersion: apps/v1 26 | kind: Deployment 27 | metadata: 28 | name: task-dispatcher 29 | namespace: cornserve 30 | spec: 31 | template: 32 | spec: 33 | containers: 34 | - name: task-dispatcher 35 | imagePullPolicy: Never 36 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/local/kustomization.yaml: -------------------------------------------------------------------------------- 
1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - ../../base 5 | - gateway 6 | patches: 7 | - path: image-pull-policy.yaml 8 | images: 9 | - name: cornserve/gateway 10 | newName: docker.io/cornserve/gateway 11 | newTag: latest 12 | - name: cornserve/resource-manager 13 | newName: docker.io/cornserve/resource-manager 14 | newTag: latest 15 | - name: cornserve/sidecar 16 | newName: docker.io/cornserve/sidecar 17 | newTag: latest 18 | - name: cornserve/task-dispatcher 19 | newName: docker.io/cornserve/task-dispatcher 20 | newTag: latest 21 | - name: cornserve/task-manager 22 | newName: docker.io/cornserve/task-manager 23 | newTag: latest 24 | configMapGenerator: 25 | - name: cornserve-config 26 | namespace: cornserve 27 | behavior: merge 28 | literals: 29 | - CORNSERVE_IMAGE_PREFIX=docker.io/cornserve 30 | - CORNSERVE_IMAGE_PULL_POLICY=Never 31 | - CORNSERVE_IMAGE_TAG=latest 32 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/minikube/gateway/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - node-port-service.yaml 5 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/minikube/gateway/node-port-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: gateway-node-port 5 | namespace: cornserve 6 | spec: 7 | type: NodePort 8 | selector: 9 | app: gateway 10 | ports: 11 | - name: gateway 12 | port: 8000 13 | targetPort: http 14 | nodePort: 30080 15 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/minikube/kustomization.yaml: 
-------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - ../../base 5 | - gateway 6 | patches: 7 | - target: 8 | kind: Deployment 9 | patch: |- 10 | - op: remove 11 | path: /spec/template/spec/nodeSelector 12 | images: 13 | - name: cornserve/gateway 14 | newName: cornserve/gateway 15 | newTag: v0.0.1.post2 16 | - name: cornserve/resource-manager 17 | newName: cornserve/resource-manager 18 | newTag: v0.0.1.post2 19 | - name: cornserve/sidecar 20 | newName: cornserve/sidecar 21 | newTag: v0.0.1.post2 22 | - name: cornserve/task-dispatcher 23 | newName: cornserve/task-dispatcher 24 | newTag: v0.0.1.post2 25 | - name: cornserve/task-manager 26 | newName: cornserve/task-manager 27 | newTag: v0.0.1.post2 28 | configMapGenerator: 29 | - name: cornserve-config 30 | namespace: cornserve 31 | behavior: merge 32 | literals: 33 | - CORNSERVE_IMAGE_PREFIX=cornserve 34 | - CORNSERVE_IMAGE_PULL_POLICY=IfNotPresent 35 | - CORNSERVE_IMAGE_TAG=v0.0.1.post2 36 | -------------------------------------------------------------------------------- /kubernetes/kustomize/cornserve/overlays/prod/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - gateway 5 | - resource-manager 6 | - sidecar 7 | - task-dispatcher 8 | - task-manager 9 | - namespace.yaml 10 | images: 11 | - name: cornserve/gateway 12 | newName: docker.io/cornserve/gateway 13 | newTag: v0.0.1.post2 14 | - name: cornserve/resource-manager 15 | newName: docker.io/cornserve/resource-manager 16 | newTag: v0.0.1.post2 17 | - name: cornserve/sidecar 18 | newName: docker.io/cornserve/sidecar 19 | newTag: v0.0.1.post2 20 | - name: cornserve/task-dispatcher 21 | newName: docker.io/cornserve/task-dispatcher 22 | newTag: v0.0.1.post2 23 | - name: cornserve/task-manager 24 | newName: 
docker.io/cornserve/task-manager 25 | newTag: v0.0.1.post2 26 | configMapGenerator: 27 | - name: cornserve-config 28 | namespace: cornserve 29 | behavior: merge 30 | literals: 31 | - CORNSERVE_IMAGE_PREFIX=docker.io/cornserve 32 | - CORNSERVE_IMAGE_PULL_POLICY=IfNotPresent 33 | - CORNSERVE_IMAGE_TAG=v0.0.1.post2 34 | -------------------------------------------------------------------------------- /kubernetes/registry.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script spins up a Docker registry container. 4 | # You potentially want to tweak the volume path to match your setup. 5 | 6 | set -evo pipefail 7 | 8 | docker run -d \ 9 | -p 5000:5000 \ 10 | -e REGISTRY_STORAGE_DELETE_ENABLED=true \ 11 | --restart=always \ 12 | --name cornserve-registry \ 13 | --volume /data/cornserve/registry:/var/lib/registry \ 14 | registry:2 15 | -------------------------------------------------------------------------------- /kubernetes/set_registry.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Usage: 4 | # 5 | # REGISTRY= bash kubernetes/set_registry.sh 6 | # 7 | # Update the private registry URL in the cornserve dev overlay kustomization file. 8 | # Override `KUSTOMIZATION_FILE` to change the file path. 9 | 10 | set -evo pipefail 11 | 12 | # Ensure the environment variable REGISTRY is set 13 | if [[ -z "${REGISTRY}" ]]; then 14 | echo "Error: REGISTRY environment variable is not set." 15 | exit 1 16 | fi 17 | 18 | KUSTOMIZATION_FILE=${KUSTOMIZATION_FILE:-"kubernetes/kustomize/cornserve/overlays/dev/kustomization.yaml"} 19 | if [[ ! -f "$KUSTOMIZATION_FILE" ]]; then 20 | echo "Error: Kustomization file '$KUSTOMIZATION_FILE' does not exist." 21 | exit 1 22 | fi 23 | 24 | K3S_REGISTRIES_FILE=${K3S_REGISTRIES_FILE:-"kubernetes/k3s/registries.yaml"} 25 | if [[ ! 
-f "$K3S_REGISTRIES_FILE" ]]; then 26 | echo "Error: K3s registries file '$K3S_REGISTRIES_FILE' does not exist." 27 | exit 1 28 | fi 29 | 30 | echo "Editing registry in $KUSTOMIZATION_FILE and $K3S_REGISTRIES_FILE to $REGISTRY" 31 | 32 | sed -i.bak -e "s#localhost:5000#$REGISTRY#g" "$KUSTOMIZATION_FILE" 33 | rm "$KUSTOMIZATION_FILE.bak" 34 | 35 | sed -i.bak -e "s#localhost:5000#$REGISTRY#g" "$K3S_REGISTRIES_FILE" 36 | rm "$K3S_REGISTRIES_FILE.bak" 37 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | # Project information 2 | site_name: Cornserve 3 | site_url: https://cornserve.ai 4 | site_author: Cornserve team 5 | site_description: Easy, fast, and scalable multimodal agentic AI 6 | edit_uri: "" 7 | 8 | # Repository 9 | repo_name: cornserve-ai/cornserve 10 | repo_url: https://github.com/cornserve-ai/cornserve 11 | 12 | # Copyright 13 | copyright: Copyright © 2025 Cornserve team 14 | 15 | # Theme configuration 16 | theme: 17 | name: material 18 | favicon: assets/img/favicon.png 19 | icon: 20 | repo: fontawesome/brands/github 21 | logo: material/rocket-launch-outline 22 | features: 23 | - content.code.copy 24 | - content.code.annotate 25 | - search.suggest 26 | - navigation.tabs 27 | - navigation.tabs.sticky 28 | - navigation.top 29 | - navigation.indexes 30 | - content.tooltips 31 | - announce.dismiss 32 | palette: 33 | - scheme: light 34 | primary: black 35 | accent: amber 36 | 37 | # MkDocs plugins 38 | plugins: 39 | - search 40 | - autorefs 41 | - social: 42 | enabled: !ENV [BUILD_SOCIAL_CARD, false] 43 | cards_dir: assets/img/social 44 | - mkdocs-video: 45 | is_video: True 46 | video_autoplay: True 47 | css_style: 48 | width: 60% 49 | - mkdocs-jupyter 50 | 51 | # Extensions 52 | markdown_extensions: 53 | - meta 54 | - abbr 55 | - admonition 56 | - attr_list 57 | - footnotes 58 | - md_in_html 59 | - pymdownx.superfences 60 | - 
pymdownx.snippets 61 | - pymdownx.details 62 | - pymdownx.critic 63 | - pymdownx.arithmatex: 64 | generic: true 65 | - pymdownx.emoji: 66 | emoji_index: !!python/name:material.extensions.emoji.twemoji 67 | emoji_generator: !!python/name:material.extensions.emoji.to_svg 68 | - pymdownx.superfences: 69 | custom_fences: 70 | - name: mermaid 71 | class: mermaid 72 | format: !!python/name:pymdownx.superfences.fence_code_format 73 | - pymdownx.highlight 74 | - pymdownx.inlinehilite 75 | 76 | # Page tree 77 | nav: 78 | - Cornserve: index.md 79 | - Getting Started: 80 | - getting_started/index.md 81 | - Deploying Cornserve: getting_started/cornserve.md 82 | - Building Apps: getting_started/building_apps.md 83 | - Using Jupyter Notebook: getting_started/jupyter.ipynb 84 | - Registering and Invoking Apps: getting_started/registering_apps.md 85 | - Architecture: 86 | - architecture/index.md 87 | - Task: architecture/task.md 88 | - Sidecar: architecture/sidecar.md 89 | - Eric: architecture/eric.md 90 | - Contributor Guide: 91 | - contributor_guide/index.md 92 | - Developing on Kubernetes: contributor_guide/kubernetes.md 93 | - Eric: contributor_guide/eric.md 94 | - Sidecar: contributor_guide/sidecar.md 95 | - Tracing: contributor_guide/tracing.md 96 | 97 | # Exclude file list 98 | exclude_docs: | 99 | requirements.txt 100 | 101 | # For Mathjax 102 | extra_javascript: 103 | - assets/js/mathjax.js 104 | - https://polyfill.io/v3/polyfill.min.js?features=es6 105 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 106 | 107 | # Extra stuff 108 | extra: 109 | analytics: 110 | provider: !ENV SITE_ANALYTICS 111 | property: G-8YY3G9ZZW5 112 | social: 113 | - name: Cornserve GitHub repository 114 | icon: fontawesome/brands/github 115 | link: https://github.com/cornserve-ai/cornserve 116 | 117 | extra_css: 118 | - assets/css/extra.css 119 | -------------------------------------------------------------------------------- /proto/v1/common.proto: 
-------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package cornserve.common; 4 | 5 | // Whether something was successful or not 6 | enum Status { 7 | STATUS_UNSPECIFIED = 0; 8 | STATUS_OK = 1; 9 | STATUS_ERROR = 2; 10 | } 11 | 12 | // Concrete task instantiated from a unit task class. 13 | message UnitTask { 14 | // UnitTask Python class name. 15 | string task_class_name = 1; 16 | 17 | // JSON-serialized task object. 18 | string task_config = 2; 19 | } 20 | -------------------------------------------------------------------------------- /proto/v1/resource_manager.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package cornserve.resource_manager; 4 | 5 | import "common.proto"; 6 | 7 | service ResourceManager { 8 | // Deploy a new unit task 9 | rpc DeployUnitTask(DeployUnitTaskRequest) returns (DeployUnitTaskResponse); 10 | 11 | // Tear down a unit task 12 | rpc TeardownUnitTask(TeardownUnitTaskRequest) returns (TeardownUnitTaskResponse); 13 | 14 | // Health checking 15 | rpc Healthcheck(HealthcheckRequest) returns (HealthcheckResponse); 16 | } 17 | 18 | // Deploy a new task 19 | message DeployUnitTaskRequest { 20 | // Task to deploy 21 | common.UnitTask task = 1; 22 | } 23 | 24 | message DeployUnitTaskResponse { 25 | // Success or failure 26 | common.Status status = 1; 27 | } 28 | 29 | // Tear down a task 30 | message TeardownUnitTaskRequest { 31 | // Task to tear down 32 | common.UnitTask task = 1; 33 | } 34 | 35 | message TeardownUnitTaskResponse { 36 | // Success or failure 37 | common.Status status = 1; 38 | } 39 | 40 | // Health checking 41 | message TaskManagerStatus { 42 | common.UnitTask task = 1; 43 | common.Status status = 2; 44 | } 45 | 46 | message HealthcheckRequest {} 47 | 48 | message HealthcheckResponse { 49 | common.Status status = 1; 50 | repeated TaskManagerStatus task_manager_statuses = 2; 51 | } 52 | 
-------------------------------------------------------------------------------- /proto/v1/sidecar.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package cornserve.sidecar; 4 | 5 | import "common.proto"; 6 | 7 | service Sidecar { 8 | rpc Register(RegisterRequest) returns (RegisterResponse); 9 | rpc Send(SendRequest) returns (SendResponse); 10 | rpc Receive(ReceiveRequest) returns (ReceiveResponse); 11 | rpc MarkDone(MarkDoneRequest) returns (MarkDoneResponse); 12 | rpc Unlink(UnlinkRequest) returns (UnlinkResponse); 13 | rpc PrepareReceive(PrepareReceiveRequest) returns (PrepareReceiveResponse); 14 | 15 | rpc CheckHealth(CheckHealthRequest) returns (CheckHealthResponse); 16 | // TODO: add unregister 17 | } 18 | 19 | message RegisterRequest { 20 | int32 rank = 1; 21 | repeated int32 group = 2; 22 | string dtype = 3; 23 | int32 send_slot_numel = 4; 24 | int32 recv_slot_numel = 5; 25 | bool concurrent_copy = 6; 26 | } 27 | 28 | message RegisterResponse { 29 | common.Status status = 1; 30 | int64 shm_size = 2; // numel in the single sender/receiver slab 31 | int32 local_rank = 3; // the GPU index to use 32 | int32 num_local_sidecars = 4; // used for init_shmem 33 | } 34 | 35 | message RankGroup { 36 | repeated int32 ranks = 1; 37 | } 38 | 39 | message SendRequest { 40 | string id = 1; 41 | repeated RankGroup dst_ranks = 2; 42 | int32 shard_rank = 3; // tp rank 43 | bytes data = 4; // serialized obj 44 | int32 chunk_id = 5; 45 | int32 num_chunks = 6; 46 | } 47 | 48 | message SendResponse { 49 | common.Status status = 1; 50 | } 51 | 52 | message ReceiveRequest { 53 | string id = 1; 54 | int32 chunk_id = 2; 55 | } 56 | 57 | message ReceiveResponse { 58 | common.Status status = 1; 59 | bytes data = 2; 60 | } 61 | 62 | message MarkDoneRequest { 63 | string id = 1; 64 | int32 chunk_id = 2; 65 | int32 shard_rank = 3; // tp rank 66 | } 67 | 68 | message MarkDoneResponse { 69 | common.Status status = 1; 70 
syntax = "proto3";

package cornserve.task_dispatcher;

import "common.proto";

// RPCs invoked by the Resource Manager to notify the Task Dispatcher
// whenever unit tasks are deployed to or removed from the cluster.
service TaskDispatcher {
  // Called when a new unit task has been deployed.
  rpc NotifyUnitTaskDeployment(NotifyUnitTaskDeploymentRequest) returns (NotifyUnitTaskDeploymentResponse);

  // Called when an existing unit task has been removed.
  rpc NotifyUnitTaskTeardown(NotifyUnitTaskTeardownRequest) returns (NotifyUnitTaskTeardownResponse);
}

// Connection info for the task manager serving a deployed unit task.
message TaskManagerDeployment {
  // Task manager URL.
  string url = 1;
}

message NotifyUnitTaskDeploymentRequest {
  // Task that was deployed.
  common.UnitTask task = 1;

  // Task manager deployment info for the task above.
  TaskManagerDeployment task_manager = 2;
}

message NotifyUnitTaskDeploymentResponse {
  common.Status status = 1;
}

message NotifyUnitTaskTeardownRequest {
  // Task that was removed.
  common.UnitTask task = 1;
}

message NotifyUnitTaskTeardownResponse {
  common.Status status = 1;
}
common.Status status = 1; 63 | } 64 | 65 | // Update resources 66 | message UpdateResourcesRequest { 67 | // ID of the task manager 68 | string task_manager_id = 1; 69 | 70 | // Resources to add or remove 71 | repeated GPUResource gpus = 2; 72 | } 73 | 74 | message UpdateResourcesResponse { 75 | common.Status status = 1; 76 | } 77 | 78 | // Shutdown 79 | message ShutdownRequest {} 80 | 81 | message ShutdownResponse { 82 | common.Status status = 1; 83 | } 84 | 85 | // Load management 86 | message ReconcileTargetLoadRequest { 87 | string task_id = 1; 88 | float target_load = 2; 89 | } 90 | 91 | message ReconcileTargetLoadResponse { 92 | common.Status status = 1; 93 | string message = 2; 94 | } 95 | 96 | // Task profiling 97 | message ProfilePoint { 98 | int32 num_gpus = 1; 99 | float max_sustainable_load = 2; 100 | DeploymentConfig deployment_config = 3; 101 | } 102 | 103 | message DeploymentConfig { 104 | int32 num_replicas = 1; 105 | int32 tensor_parallel_degree = 2; 106 | int32 pipeline_parallel_degree = 3; 107 | repeated string gpu_assignments = 4; 108 | } 109 | 110 | message GetTaskProfileRequest { 111 | string task_id = 1; 112 | } 113 | 114 | message GetTaskProfileResponse { 115 | repeated ProfilePoint profile_points = 1; 116 | } 117 | 118 | // Request routing 119 | message GetRouteRequest { 120 | // ID of the request 121 | string request_id = 1; 122 | 123 | // Optional routing hint 124 | // e.g., hash of image URL, system prompt 125 | optional string routing_hint = 2; 126 | } 127 | 128 | message GetRouteResponse { 129 | // URL of the task executor to route the request to 130 | string task_executor_url = 1; 131 | 132 | // Sidecar ranks the task executor is registered with 133 | repeated int32 sidecar_ranks = 2; 134 | } 135 | 136 | // Healthcheck response 137 | message TaskExecutorStatus { 138 | common.Status status = 1; 139 | repeated int32 sidecar_ranks = 2; 140 | } 141 | 142 | message HealthcheckRequest {} 143 | 144 | message HealthcheckResponse { 145 | 
common.Status status = 1; 146 | map task_executor_statuses = 2; 147 | } 148 | -------------------------------------------------------------------------------- /python/cornserve/__init__.py: -------------------------------------------------------------------------------- 1 | """Easy, fast, and scalable multimodal agentic AI.""" 2 | 3 | __version__ = "0.0.1.post2" 4 | -------------------------------------------------------------------------------- /python/cornserve/app/__init__.py: -------------------------------------------------------------------------------- 1 | """Cornserve application package.""" 2 | -------------------------------------------------------------------------------- /python/cornserve/app/base.py: -------------------------------------------------------------------------------- 1 | """Base classes for cornserve applications.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import ClassVar 6 | 7 | from pydantic import BaseModel, ConfigDict, Field 8 | 9 | from cornserve.task.base import Task 10 | 11 | 12 | class AppRequest(BaseModel): 13 | """Base class for application requests. 14 | 15 | All user-defined request classes must inherit from this. 16 | """ 17 | 18 | 19 | class AppResponse(BaseModel): 20 | """Base class for application responses. 21 | 22 | All user-defined response classes must inherit from this. 23 | """ 24 | 25 | 26 | class AppConfig(BaseModel): 27 | """Base class for application configuration. 28 | 29 | All user-defined config classes must inherit from this. 30 | """ 31 | 32 | tasks: ClassVar[dict[str, Task]] = Field( 33 | default_factory=dict, 34 | description="Dictionary of tasks that the app requires.", 35 | ) 36 | 37 | model_config = ConfigDict(extra="forbid") 38 | -------------------------------------------------------------------------------- /python/cornserve/constants.py: -------------------------------------------------------------------------------- 1 | """Constants used throughout Cornserve. 
def _get_env_warn_default(var_name: str, default: str) -> str:
    """Return ``os.environ[var_name]``, or *default* (with a warning) if unset."""
    value = os.environ.get(var_name)
    if value is not None:
        return value
    warnings.warn(
        f"Environment variable {var_name} not set, using default '{default}'.",
        stacklevel=2,
    )
    return default


def _build_image_name(name: str) -> str:
    """Builds a full image name with prefix, tag, and pull policy."""
    prefix = _get_env_warn_default("CORNSERVE_IMAGE_PREFIX", "docker.io/cornserve")
    tag = _get_env_warn_default("CORNSERVE_IMAGE_TAG", "latest")
    return f"{prefix.strip('/')}/{name}:{tag}"


# Cache for lazy-loaded constants, filled on first attribute access.
_lazy_cache: dict[str, str] = {}

# Constants resolved lazily through the module-level ``__getattr__``, so that
# merely importing this module does not warn about unset image env vars.
_LAZY_CONSTANTS = {
    "CONTAINER_IMAGE_TASK_MANAGER": lambda: _build_image_name("task-manager"),
    "CONTAINER_IMAGE_SIDECAR": lambda: _build_image_name("sidecar"),
    "CONTAINER_IMAGE_ERIC": lambda: _build_image_name("eric"),
    "CONTAINER_IMAGE_VLLM": lambda: _build_image_name("vllm"),
    "CONTAINER_IMAGE_PULL_POLICY": lambda: _get_env_warn_default("CORNSERVE_IMAGE_PULL_POLICY", "IfNotPresent"),
}
def get_logger(
    name: str, adapters: list[type[logging.LoggerAdapter]] | None = None
) -> logging.Logger | logging.LoggerAdapter:
    """Get a logger with the given name with some formatting configs.

    Args:
        name: Logger name, typically the dotted module path.
        adapters: Optional ``LoggerAdapter`` subclasses to wrap the logger
            with, applied in order (each wraps the previous result).

    Returns:
        The configured logger, or the outermost adapter if any were given.
    """
    # Check *before* calling getLogger(), which itself registers the name.
    already_configured = name in logging.Logger.manager.loggerDict

    logger: logging.Logger | logging.LoggerAdapter = logging.getLogger(name)
    if not already_configured:
        logger.setLevel(os.environ.get("CORNSERVE_LOG_LEVEL", logging.INFO))
        formatter = logging.Formatter("%(levelname)s %(asctime)s [%(name)s:%(lineno)d] %(message)s")
        handler = logging.StreamHandler(sys.stderr)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # Bug fix: previously, an already-created logger was returned immediately,
    # silently skipping the requested adapters. Adapters are now applied on
    # every call (handler setup alone remains one-time).
    if adapters is not None:
        for adapter in adapters:
            logger = adapter(logger)
    return logger
6 | """ 7 | -------------------------------------------------------------------------------- /python/cornserve/services/gateway/app/models.py: -------------------------------------------------------------------------------- 1 | """Type definitions for the App Manager.""" 2 | 3 | from __future__ import annotations 4 | 5 | import enum 6 | from collections.abc import Callable, Coroutine 7 | from dataclasses import dataclass 8 | from types import ModuleType 9 | 10 | from cornserve.app.base import AppConfig, AppRequest, AppResponse 11 | 12 | 13 | class AppState(enum.StrEnum): 14 | """Possible states of a registered app.""" 15 | 16 | NOT_READY = "not ready" 17 | READY = "ready" 18 | 19 | 20 | @dataclass 21 | class AppClasses: 22 | """Container for a registered app. 23 | 24 | Attributes: 25 | request_cls: The class that defines the app's request schema. 26 | response_cls: The class that defines the app's response schema. 27 | config_cls: The class that specifies the app's configuration. 28 | serve_fn: The function that implements the app's logic. 29 | """ 30 | 31 | request_cls: type[AppRequest] 32 | response_cls: type[AppResponse] 33 | config_cls: type[AppConfig] 34 | serve_fn: Callable[[AppRequest], Coroutine[None, None, AppResponse]] 35 | 36 | 37 | @dataclass 38 | class AppDefinition: 39 | """Full definition of a registered app. 40 | 41 | Attributes: 42 | app_id: The ID of the app. 43 | module: The module that contains the app's code. 44 | source_code: The Python source code of the app. 45 | classes: The classes that define the app's schema and logic. 
46 | """ 47 | 48 | app_id: str 49 | module: ModuleType 50 | source_code: str 51 | classes: AppClasses 52 | -------------------------------------------------------------------------------- /python/cornserve/services/gateway/entrypoint.py: -------------------------------------------------------------------------------- 1 | """Spins up the Gateway service.""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | import os 7 | import signal 8 | from typing import TYPE_CHECKING 9 | 10 | import uvicorn 11 | from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor 12 | from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorClient 13 | from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor 14 | 15 | from cornserve.logging import get_logger 16 | from cornserve.services.gateway.router import create_app 17 | from cornserve.tracing import configure_otel 18 | 19 | if TYPE_CHECKING: 20 | from cornserve.services.gateway.app.manager import AppManager 21 | 22 | logger = get_logger("cornserve.services.gateway.entrypoint") 23 | 24 | 25 | async def serve() -> None: 26 | """Serve the Gateway as a FastAPI app.""" 27 | logger.info("Starting Gateway service") 28 | 29 | configure_otel("gateway") 30 | 31 | app = create_app() 32 | FastAPIInstrumentor.instrument_app(app) 33 | GrpcAioInstrumentorClient().instrument() 34 | HTTPXClientInstrumentor().instrument() 35 | 36 | logger.info("Available routes are:") 37 | for route in app.routes: 38 | methods = getattr(route, "methods", None) 39 | path = getattr(route, "path", None) 40 | 41 | if methods is None or path is None: 42 | continue 43 | 44 | logger.info( 45 | "%s %s", 46 | list(methods)[0] if len(methods) == 1 else "{" + ",".join(methods) + "}", 47 | path, 48 | ) 49 | 50 | config = uvicorn.Config(app, host="0.0.0.0", port=8000) 51 | server = uvicorn.Server(config) 52 | app_manager: AppManager = app.state.app_manager 53 | 54 | # `TaskContext` reads this environment variable to determine 
the URL of the Gateway. 55 | os.environ["CORNSERVE_GATEWAY_URL"] = "http://localhost:8000" 56 | 57 | loop = asyncio.get_running_loop() 58 | server_task = loop.create_task(server.serve()) 59 | 60 | def shutdown() -> None: 61 | server_task.cancel() 62 | 63 | loop.add_signal_handler(signal.SIGINT, shutdown) 64 | loop.add_signal_handler(signal.SIGTERM, shutdown) 65 | 66 | try: 67 | await server_task 68 | except asyncio.CancelledError: 69 | logger.info("Shutting down Gateway service") 70 | await app_manager.shutdown() 71 | await server.shutdown() 72 | 73 | 74 | if __name__ == "__main__": 75 | asyncio.run(serve()) 76 | -------------------------------------------------------------------------------- /python/cornserve/services/gateway/models.py: -------------------------------------------------------------------------------- 1 | """Gateway request and response models.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | from pydantic import BaseModel 8 | 9 | 10 | class AppRegistrationRequest(BaseModel): 11 | """Request for registering a new application. 12 | 13 | Attributes: 14 | source_code: The Python source code of the application. 15 | """ 16 | 17 | source_code: str 18 | 19 | 20 | class AppRegistrationResponse(BaseModel): 21 | """Response for registering a new application. 22 | 23 | Attributes: 24 | app_id: The unique identifier for the registered application. 25 | """ 26 | 27 | app_id: str 28 | 29 | 30 | class AppInvocationRequest(BaseModel): 31 | """Request for invoking a registered application. 32 | 33 | Attributes: 34 | request_data: The input data for the application. Should be a valid 35 | JSON object that matches the `Request` schema of the application. 
36 | """ 37 | 38 | request_data: dict[str, Any] 39 | -------------------------------------------------------------------------------- /python/cornserve/services/gateway/session.py: -------------------------------------------------------------------------------- 1 | """A session manager for the Cornserve gateway.""" 2 | 3 | import asyncio 4 | import uuid 5 | from dataclasses import dataclass, field 6 | 7 | from cornserve.frontend import TaskRequest, TaskRequestVerb, TaskResponse 8 | from cornserve.logging import get_logger 9 | from cornserve.services.gateway.task_manager import TaskManager 10 | from cornserve.task.base import UnitTask 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | @dataclass 16 | class Session: 17 | """The session state. 18 | 19 | Attributes: 20 | tasks: A dictionary of tasks that are currently in use by this session. 21 | """ 22 | 23 | tasks: dict[str, UnitTask] = field(default_factory=dict) 24 | 25 | 26 | class SessionManager: 27 | """Manages debug sessions for the Cornserve gateway.""" 28 | 29 | def __init__(self, task_manager: TaskManager) -> None: 30 | """Initialize the session manager.""" 31 | self.task_manager = task_manager 32 | self.lock = asyncio.Lock() 33 | self.sessions: dict[str, Session] = {} 34 | 35 | async def create_session(self) -> str: 36 | """Create a new session.""" 37 | session_id = str(uuid.uuid4()) 38 | async with self.lock: 39 | while session_id in self.sessions: 40 | session_id = str(uuid.uuid4()) 41 | self.sessions[session_id] = Session() 42 | logger.info("Created session with ID: %s", session_id) 43 | return session_id 44 | 45 | async def handle_request(self, session_id: str, request: dict) -> TaskResponse: 46 | """Handle a request for a session. 47 | 48 | Args: 49 | session_id: The ID of the session. 50 | request: The request data. 
51 | """ 52 | async with self.lock: 53 | if session_id not in self.sessions: 54 | logger.warning("Session ID %s not found", session_id) 55 | return TaskResponse(status=404, content="Session invalid") 56 | try: 57 | task_request = TaskRequest.model_validate(request) 58 | except Exception: 59 | logger.exception("Invalid request") 60 | return TaskResponse(status=400, content="Invalid request") 61 | if task_request.verb == TaskRequestVerb.DECLARE_USED: 62 | logger.info("Declaring tasks as used: %s", task_request.task_list) 63 | await self.task_manager.declare_used(task_request.get_tasks()) 64 | self.sessions[session_id].tasks.update({task.id: task for task in task_request.get_tasks()}) 65 | return TaskResponse(status=200, content="Tasks declared used") 66 | elif task_request.verb == TaskRequestVerb.DECLARE_NOT_USED: 67 | logger.info("Declaring tasks as not used: %s", task_request.task_list) 68 | await self.task_manager.declare_not_used(task_request.get_tasks()) 69 | for task in task_request.get_tasks(): 70 | if task.id in self.sessions[session_id].tasks: 71 | del self.sessions[session_id].tasks[task.id] 72 | return TaskResponse(status=200, content="Tasks declared not used") 73 | elif task_request.verb == TaskRequestVerb.HEARTBEAT: 74 | return TaskResponse(status=200, content="Session is alive") 75 | else: 76 | logger.warning("Unknown method %s", task_request.verb) 77 | return TaskResponse(status=400, content="Unknown method") 78 | 79 | async def destroy_session(self, session_id: str) -> bool: 80 | """Destroy a session. Clean up all tasks in use by this session. 81 | 82 | Args: 83 | session_id: The ID of the session to destroy. 
84 | """ 85 | async with self.lock: 86 | if session_id in self.sessions: 87 | logger.info("Destroying session with ID: %s", session_id) 88 | tasks = list(self.sessions[session_id].tasks.values()) 89 | await self.task_manager.declare_not_used(tasks) 90 | del self.sessions[session_id] 91 | return True 92 | return False 93 | -------------------------------------------------------------------------------- /python/cornserve/services/resource_manager/__init__.py: -------------------------------------------------------------------------------- 1 | """The Resource Manager manages resources (GPUs) and task managers. 2 | 3 | It allocates resources to task managers and spawns and kills them. 4 | It also reconciles the state of task managers with the task dispatcher. 5 | """ 6 | -------------------------------------------------------------------------------- /python/cornserve/services/resource_manager/entrypoint.py: -------------------------------------------------------------------------------- 1 | """Entrypoint for the Resource Manager service.""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | import signal 7 | 8 | from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorClient, GrpcAioInstrumentorServer 9 | 10 | from cornserve.logging import get_logger 11 | from cornserve.services.resource_manager.grpc import create_server 12 | from cornserve.services.resource_manager.manager import ResourceManager 13 | from cornserve.tracing import configure_otel 14 | 15 | logger = get_logger("cornserve.services.resource_manager.entrypoint") 16 | 17 | 18 | async def serve() -> None: 19 | """Start the gRPC server.""" 20 | configure_otel("resource_manager") 21 | 22 | GrpcAioInstrumentorServer().instrument() 23 | GrpcAioInstrumentorClient().instrument() 24 | 25 | resource_manager = await ResourceManager.init() 26 | 27 | server = create_server(resource_manager) 28 | await server.start() 29 | 30 | logger.info("gRPC server started") 31 | 32 | loop = 
asyncio.get_running_loop() 33 | server_task = loop.create_task(server.wait_for_termination()) 34 | 35 | def shutdown() -> None: 36 | server_task.cancel() 37 | 38 | loop.add_signal_handler(signal.SIGINT, shutdown) 39 | loop.add_signal_handler(signal.SIGTERM, shutdown) 40 | 41 | try: 42 | await server_task 43 | except asyncio.CancelledError: 44 | logger.info("Shutting down Resource Manager service") 45 | await server.stop(5) 46 | logger.info("Shutting down resource manager...") 47 | await resource_manager.shutdown() 48 | logger.info("Resource Manager service shutdown complete") 49 | 50 | 51 | if __name__ == "__main__": 52 | asyncio.run(serve()) 53 | -------------------------------------------------------------------------------- /python/cornserve/services/resource_manager/grpc.py: -------------------------------------------------------------------------------- 1 | """Resource Manager gRPC server.""" 2 | 3 | from __future__ import annotations 4 | 5 | import grpc 6 | 7 | from cornserve.logging import get_logger 8 | from cornserve.services.pb import common_pb2, resource_manager_pb2, resource_manager_pb2_grpc 9 | from cornserve.services.resource_manager.manager import ResourceManager 10 | from cornserve.task.base import UnitTask 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | class ResourceManagerServicer(resource_manager_pb2_grpc.ResourceManagerServicer): 16 | """Resource Manager gRPC service implementation.""" 17 | 18 | def __init__(self, manager: ResourceManager) -> None: 19 | """Initialize the ResourceManagerServicer.""" 20 | self.manager = manager 21 | 22 | async def DeployUnitTask( 23 | self, 24 | request: resource_manager_pb2.DeployUnitTaskRequest, 25 | context: grpc.aio.ServicerContext, 26 | ) -> resource_manager_pb2.DeployUnitTaskResponse: 27 | """Deploy a unit task in the cluster.""" 28 | await self.manager.deploy_unit_task(UnitTask.from_pb(request.task)) 29 | return resource_manager_pb2.DeployUnitTaskResponse(status=common_pb2.Status.STATUS_OK) 30 | 31 
| async def TeardownUnitTask( 32 | self, 33 | request: resource_manager_pb2.TeardownUnitTaskRequest, 34 | context: grpc.aio.ServicerContext, 35 | ) -> resource_manager_pb2.TeardownUnitTaskResponse: 36 | """Reconcile a removed app by shutting down task managers if needed.""" 37 | await self.manager.teardown_unit_task(UnitTask.from_pb(request.task)) 38 | return resource_manager_pb2.TeardownUnitTaskResponse(status=common_pb2.Status.STATUS_OK) 39 | 40 | async def Healthcheck( 41 | self, 42 | request: resource_manager_pb2.HealthcheckRequest, 43 | context: grpc.aio.ServicerContext, 44 | ) -> resource_manager_pb2.HealthcheckResponse: 45 | """Recursively check and report the health of task managers.""" 46 | try: 47 | overall_status, task_manager_statuses = await self.manager.healthcheck() 48 | 49 | status_map = { 50 | True: common_pb2.Status.STATUS_OK, 51 | False: common_pb2.Status.STATUS_ERROR, 52 | } 53 | 54 | # Convert the statuses into proto message format 55 | proto_statuses = [ 56 | resource_manager_pb2.TaskManagerStatus(task=task.to_pb(), status=status_map[status]) 57 | for task, status in task_manager_statuses 58 | ] 59 | 60 | return resource_manager_pb2.HealthcheckResponse( 61 | status=status_map[overall_status], task_manager_statuses=proto_statuses 62 | ) 63 | except Exception as e: 64 | logger.exception("Healthcheck failed: %s", e) 65 | return resource_manager_pb2.HealthcheckResponse( 66 | status=common_pb2.Status.STATUS_ERROR, task_manager_statuses=[] 67 | ) 68 | 69 | 70 | def create_server(resource_manager: ResourceManager) -> grpc.aio.Server: 71 | """Create the gRPC server for the Resource Manager.""" 72 | servicer = ResourceManagerServicer(resource_manager) 73 | server = grpc.aio.server() 74 | resource_manager_pb2_grpc.add_ResourceManagerServicer_to_server(servicer, server) 75 | listen_addr = "[::]:50051" 76 | server.add_insecure_port(listen_addr) 77 | logger.info("Starting server on %s", listen_addr) 78 | return server 79 | 
@dataclass
class _Job:
    """A queued unit of work and the future its result is delivered on."""

    fn: Callable[..., Any]
    args: tuple
    kwargs: dict
    fut: asyncio.Future


# ── async no‑op context manager ─────────────────────────────────────
class _AsyncNullCM:
    """`async with _ASYNC_NULL:` does nothing (like contextlib.nullcontext)."""

    async def __aenter__(self) -> None:  # noqa: D401
        return None

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> bool:
        return False


_ASYNC_NULL = _AsyncNullCM()


class Scheduler:
    """Central launch‑controller.

    Jobs submitted via `submit()` are queued and executed concurrently
    (optionally bounded by a semaphore); each caller awaits its own future.
    """

    def __init__(self, max_concurrency: int | None = None) -> None:
        """Initialize the scheduler.

        Args:
            max_concurrency: Maximum number of jobs running at once;
                `None` means unbounded.
        """
        self._q: asyncio.Queue[_Job] = asyncio.Queue()
        self._runner_task: asyncio.Task | None = None
        self._sema = asyncio.Semaphore(max_concurrency) if max_concurrency else None
        # Bug fix: keep strong references to in-flight executor tasks.
        # The event loop only keeps a weak reference to tasks made with
        # asyncio.create_task, so without this set they could be
        # garbage-collected mid-flight; it also lets stop() cancel them.
        self._inflight: set[asyncio.Task] = set()

    async def submit(self, fn: Callable[..., Any], *args, **kwargs) -> Any:
        """Submit a job to the queue and await its result.

        Raises whatever exception `fn` raised, if any.
        """
        loop = asyncio.get_running_loop()
        fut: asyncio.Future = loop.create_future()
        await self._q.put(_Job(fn, args, kwargs, fut))
        return await fut

    async def schedule(self) -> _Job:
        """Schedule the next job in the queue."""
        return await self._q.get()

    async def _execute(self, job: _Job) -> None:
        """Run one job under the concurrency limit and resolve its future."""
        cm = self._sema or _ASYNC_NULL
        async with cm:
            try:
                res = job.fn(*job.args, **job.kwargs)
                # Support both sync callables and coroutine functions.
                if asyncio.iscoroutine(res):
                    res = await res
                job.fut.set_result(res)
            except Exception as exc:
                job.fut.set_exception(exc)

    async def _runner(self) -> None:
        """Infinite loop to process jobs in the queue."""
        while True:
            job = await self.schedule()
            task = asyncio.create_task(self._execute(job))
            self._inflight.add(task)
            task.add_done_callback(self._inflight.discard)

    def start(self) -> None:
        """Start the scheduler and begin processing jobs (idempotent)."""
        if self._runner_task is None:
            self._runner_task = asyncio.create_task(self._runner())

    async def stop(self) -> None:
        """Stop the scheduler and cancel all jobs in flight."""
        if self._runner_task:
            self._runner_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._runner_task
        # Bug fix: actually cancel in-flight jobs, as documented.
        for task in list(self._inflight):
            task.cancel()
        for task in list(self._inflight):
            with contextlib.suppress(asyncio.CancelledError):
                await task
11 | """ 12 | -------------------------------------------------------------------------------- /python/cornserve/services/task_dispatcher/entrypoint.py: -------------------------------------------------------------------------------- 1 | """Spins up the Task Dispatcher service with gRPC and FastAPI.""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | import signal 7 | from typing import TYPE_CHECKING 8 | 9 | import uvicorn 10 | from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor 11 | from opentelemetry.instrumentation.grpc import GrpcInstrumentorClient, GrpcInstrumentorServer 12 | from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor 13 | 14 | from cornserve.logging import get_logger 15 | from cornserve.services.task_dispatcher.grpc import create_server 16 | from cornserve.services.task_dispatcher.router import create_app 17 | from cornserve.tracing import configure_otel 18 | 19 | if TYPE_CHECKING: 20 | from cornserve.services.task_dispatcher.dispatcher import TaskDispatcher 21 | 22 | logger = get_logger("cornserve.services.task_dispatcher.entrypoint") 23 | 24 | 25 | async def serve() -> None: 26 | """Serve the Task Dispatcher service.""" 27 | logger.info("Starting Gateway service") 28 | 29 | configure_otel("task_dispatcher") 30 | 31 | # FastAPI server 32 | app = create_app() 33 | 34 | FastAPIInstrumentor.instrument_app(app) 35 | HTTPXClientInstrumentor().instrument() 36 | GrpcInstrumentorClient().instrument() 37 | GrpcInstrumentorServer().instrument() 38 | 39 | logger.info("Available HTTP routes are:") 40 | for route in app.routes: 41 | methods = getattr(route, "methods", None) 42 | path = getattr(route, "path", None) 43 | 44 | if methods is None or path is None: 45 | continue 46 | 47 | logger.info( 48 | "%s %s", 49 | list(methods)[0] if len(methods) == 1 else "{" + ",".join(methods) + "}", 50 | path, 51 | ) 52 | 53 | config = uvicorn.Config(app, host="0.0.0.0", port=8000) 54 | uvicorn_server = 
uvicorn.Server(config) 55 | dispatcher: TaskDispatcher = app.state.dispatcher 56 | 57 | # gRPC server 58 | grpc_server = create_server(dispatcher) 59 | 60 | loop = asyncio.get_running_loop() 61 | uvicorn_server_task = loop.create_task(uvicorn_server.serve()) 62 | await grpc_server.start() 63 | 64 | def shutdown() -> None: 65 | uvicorn_server_task.cancel() 66 | 67 | loop.add_signal_handler(signal.SIGINT, shutdown) 68 | loop.add_signal_handler(signal.SIGTERM, shutdown) 69 | 70 | try: 71 | await uvicorn_server_task 72 | except asyncio.CancelledError: 73 | logger.info("Shutting down Task Dispatcher service") 74 | await dispatcher.shutdown() 75 | await uvicorn_server.shutdown() 76 | await grpc_server.stop(5) 77 | logger.info("Task Dispatcher service shutdown complete") 78 | 79 | 80 | if __name__ == "__main__": 81 | asyncio.run(serve()) 82 | -------------------------------------------------------------------------------- /python/cornserve/services/task_dispatcher/grpc.py: -------------------------------------------------------------------------------- 1 | """Task Dispatcher gRPC server.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import grpc 8 | 9 | from cornserve.logging import get_logger 10 | from cornserve.services.pb import common_pb2, task_dispatcher_pb2, task_dispatcher_pb2_grpc 11 | from cornserve.task.base import UnitTask 12 | 13 | if TYPE_CHECKING: 14 | from cornserve.services.task_dispatcher.dispatcher import TaskDispatcher 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | class TaskDispatcherServicer(task_dispatcher_pb2_grpc.TaskDispatcherServicer): 20 | """Task Dispatcher gRPC service implementation.""" 21 | 22 | def __init__(self, task_dispatcher: TaskDispatcher) -> None: 23 | """Initializer the TaskDispatcherServicer.""" 24 | self.task_dispatcher = task_dispatcher 25 | 26 | async def NotifyUnitTaskDeployment( 27 | self, 28 | request: task_dispatcher_pb2.NotifyUnitTaskDeploymentRequest, 29 | context: 
grpc.aio.ServicerContext, 30 | ) -> task_dispatcher_pb2.NotifyUnitTaskDeploymentResponse: 31 | """Register new task managers with the task dispatcher.""" 32 | await self.task_dispatcher.notify_task_deployment( 33 | task=UnitTask.from_pb(request.task), 34 | task_manager_url=request.task_manager.url, 35 | ) 36 | return task_dispatcher_pb2.NotifyUnitTaskDeploymentResponse(status=common_pb2.Status.STATUS_OK) 37 | 38 | async def NotifyUnitTaskTeardown( 39 | self, 40 | request: task_dispatcher_pb2.NotifyUnitTaskTeardownRequest, 41 | context: grpc.aio.ServicerContext, 42 | ) -> task_dispatcher_pb2.NotifyUnitTaskTeardownResponse: 43 | """Remove task managers from the task dispatcher.""" 44 | await self.task_dispatcher.notify_task_teardown(task=UnitTask.from_pb(request.task)) 45 | return task_dispatcher_pb2.NotifyUnitTaskTeardownResponse(status=common_pb2.Status.STATUS_OK) 46 | 47 | 48 | def create_server(task_dispatcher: TaskDispatcher) -> grpc.aio.Server: 49 | """Create the gRPC server for the Task Dispatcher.""" 50 | servicer = TaskDispatcherServicer(task_dispatcher) 51 | server = grpc.aio.server() 52 | task_dispatcher_pb2_grpc.add_TaskDispatcherServicer_to_server(servicer, server) 53 | listen_addr = "[::]:50051" 54 | server.add_insecure_port(listen_addr) 55 | logger.info("gRPC server listening on %s", listen_addr) 56 | return server 57 | -------------------------------------------------------------------------------- /python/cornserve/services/task_dispatcher/router.py: -------------------------------------------------------------------------------- 1 | """Task Dispatcher REST API server.""" 2 | 3 | from __future__ import annotations 4 | 5 | from fastapi import APIRouter, FastAPI, Request, Response, status 6 | from opentelemetry import trace 7 | 8 | from cornserve.logging import get_logger 9 | from cornserve.services.task_dispatcher.dispatcher import TaskDispatcher 10 | from cornserve.task.base import TaskGraphDispatch 11 | 12 | router = APIRouter() 13 | logger = 
get_logger(__name__) 14 | tracer = trace.get_tracer(__name__) 15 | 16 | 17 | @router.post("/task") 18 | async def invoke_task(request: TaskGraphDispatch, raw_request: Request): 19 | """Invoke a task with the given request data.""" 20 | logger.info("Task dispatch received: %s", request) 21 | 22 | dispatcher: TaskDispatcher = raw_request.app.state.dispatcher 23 | try: 24 | response = await dispatcher.invoke(request.invocations) 25 | return response 26 | except Exception as e: 27 | logger.exception("Error while invoking task") 28 | return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=str(e)) 29 | 30 | 31 | @router.get("/health") 32 | async def health_check(): 33 | """Health check endpoint.""" 34 | return Response(status_code=status.HTTP_200_OK) 35 | 36 | 37 | def init_app_state(app: FastAPI) -> None: 38 | """Initialize the app state for the Task Dispatcher.""" 39 | app.state.dispatcher = TaskDispatcher() 40 | 41 | 42 | def create_app() -> FastAPI: 43 | """Build the FastAPI app for the Task Dispatcher.""" 44 | app = FastAPI(title="Cornserve Task Dispatcher") 45 | app.include_router(router) 46 | init_app_state(app) 47 | return app 48 | -------------------------------------------------------------------------------- /python/cornserve/services/task_manager/__init__.py: -------------------------------------------------------------------------------- 1 | """The Task Manager manages task executors. 2 | 3 | A Task Manager handles exactly one type of task, for instance, 4 | multimodal data embedding (Eric) or LLM inference (vLLM). 5 | It spawns and kills task executors given the resource (GPUs) allocated to it by 6 | the resource manager. 7 | 8 | It's primarily responsible for 9 | 1. Spawning and killing task executors 10 | 2. 
Routing requests to the appropriate task executor 11 | """ 12 | -------------------------------------------------------------------------------- /python/cornserve/services/task_manager/entrypoint.py: -------------------------------------------------------------------------------- 1 | """Entrypoint for the Task Manager service.""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | import signal 7 | 8 | from cornserve.logging import get_logger 9 | from cornserve.services.task_manager.grpc import create_server 10 | 11 | logger = get_logger("cornserve.services.task_manager.entrypoint") 12 | 13 | 14 | async def serve() -> None: 15 | """Serve the Task Manager service.""" 16 | logger.info("Starting Task Manager service") 17 | 18 | server, servicer = create_server() 19 | await server.start() 20 | 21 | logger.info("gRPC server started") 22 | 23 | loop = asyncio.get_running_loop() 24 | server_task = asyncio.create_task(server.wait_for_termination()) 25 | 26 | def shutdown() -> None: 27 | server_task.cancel() 28 | 29 | loop.add_signal_handler(signal.SIGINT, shutdown) 30 | loop.add_signal_handler(signal.SIGTERM, shutdown) 31 | 32 | try: 33 | await server_task 34 | except asyncio.CancelledError: 35 | logger.info("Shutting down Task Manager service") 36 | await server.stop(5) 37 | if servicer.manager is not None: 38 | logger.info("Shutting down task manager...") 39 | await servicer.manager.shutdown() 40 | logger.info("Task Manager service shutdown complete") 41 | 42 | 43 | if __name__ == "__main__": 44 | asyncio.run(serve()) 45 | -------------------------------------------------------------------------------- /python/cornserve/sidecar/__init__.py: -------------------------------------------------------------------------------- 1 | """Sidecar APIs for Task Executors.""" 2 | -------------------------------------------------------------------------------- /python/cornserve/sidecar/constants.py: 
-------------------------------------------------------------------------------- 1 | """Sidecar constants.""" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | 7 | from cornserve import constants 8 | 9 | RANK_OFFSET = 1000000 10 | CHUNK_OFFSET = 1000 11 | GRPC_BASE_PORT = 10000 12 | UCX_BASE_PORT = 12000 13 | 14 | 15 | def chunk_tag(id: str, rank: int, chunk_id: int, shard_rank: int) -> int: 16 | """Generate a tag for the chunk. 17 | 18 | The tag is a unique id for a chunk during transmission. 19 | """ 20 | # convert the hex uuid to int 21 | base = int(id, 16) 22 | return base + RANK_OFFSET * (rank) + CHUNK_OFFSET * (chunk_id) + shard_rank 23 | 24 | 25 | def shm_filename() -> str: 26 | """Shared memory filename in each node.""" 27 | return "/dev/shm/sc_shm" 28 | 29 | 30 | def grpc_url_from_rank(rank: int) -> str: 31 | """GRPC channel url from rank.""" 32 | assert rank >= 0, "Rank should be non-negative" 33 | is_local = os.environ.get("SIDECAR_IS_LOCAL", "false").lower() == "true" 34 | if is_local: 35 | return f"localhost:{GRPC_BASE_PORT + rank}" 36 | parts = [ 37 | f"sidecar-{rank}", 38 | constants.K8S_SIDECAR_SERVICE_NAME, 39 | constants.K8S_NAMESPACE, 40 | "svc.cluster.local", 41 | ] 42 | return ".".join(parts) + f":{GRPC_BASE_PORT + rank}" 43 | 44 | 45 | def ucx_url_from_rank(rank: int) -> str: 46 | """UCX connection host url from rank.""" 47 | assert rank >= 0, "Rank should be non-negative" 48 | is_local = os.environ.get("SIDECAR_IS_LOCAL", "false").lower() == "true" 49 | if is_local: 50 | return "127.0.0.1" 51 | parts = [ 52 | f"sidecar-{rank}", 53 | constants.K8S_SIDECAR_SERVICE_NAME, 54 | constants.K8S_NAMESPACE, 55 | "svc.cluster.local", 56 | ] 57 | return ".".join(parts) 58 | 59 | 60 | def ucx_port_from_rank(rank: int) -> int: 61 | """UCX connection host port from rank.""" 62 | assert rank >= 0, "Rank should be non-negative" 63 | return UCX_BASE_PORT + rank 64 | -------------------------------------------------------------------------------- 
def buffer_from_tensor(t: torch.Tensor) -> ctypes.Array:
    """Convert a torch tensor to a ctypes buffer for ucx-py.

    The buffer aliases the tensor's memory (no copy), so the tensor must stay
    alive for as long as the buffer is in use.
    """
    data_ptr = t.data_ptr()
    size_bytes = t.numel() * t.element_size()
    buffer = (ctypes.c_byte * size_bytes).from_address(data_ptr)
    return buffer


def device_from_rank(rank: int) -> torch.device:
    """Torch device from rank.

    Raises:
        ValueError: If `rank` is negative.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if rank < 0:
        raise ValueError("Rank should be non-negative")
    return torch.device(f"cuda:{rank}")


def init_shmem(
    filename: str,
    local_ranks: list[int],
    num_local_sidecars: int,
    partition_numel: int,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Initialize a shared memory buffer between the sidecar client and server.

    All sidecars within the same node will share the same buffer but at different offsets.
    Each sidecar will only access its own slice of the buffer, and each slice has the same size.

    Args:
        filename: The filename of the shared memory buffer.
        local_ranks: The local ranks of the sidecars that will share the buffer, must be consecutive.
        num_local_sidecars: Total number of sidecars within the same node.
        partition_numel: Number of elements of given dtype in the shared memory buffer used by each device/sidecar.
        dtype: Data type of the shared memory buffer.

    Returns:
        A tuple of (the full shared tensor, the slice owned by `local_ranks`).

    Raises:
        ValueError: If `local_ranks` is empty or not consecutive.
    """
    # Explicit validation instead of `assert` (stripped under `python -O`);
    # an empty `local_ranks` previously fell through to an IndexError below.
    if not local_ranks:
        raise ValueError("local_ranks must not be empty")
    for prev, curr in zip(local_ranks, local_ranks[1:]):
        if prev + 1 != curr:
            raise ValueError("Device IDs must be consecutive")

    total_element_count = partition_numel * num_local_sidecars
    # Memory-mapped, shared across processes on the same node.
    full_tensor = torch.from_file(
        filename=filename,
        shared=True,
        size=total_element_count,
        dtype=dtype,
    )
    # Slice out the contiguous region owned by the given (consecutive) ranks.
    start = partition_numel * local_ranks[0]
    end = partition_numel * (local_ranks[-1] + 1)
    return full_tensor, full_tensor[start:end]
31 | 32 | Attributes: 33 | embeddings: The embeddings from the encoder. 34 | """ 35 | 36 | embeddings: list[DataForward[Tensor]] 37 | 38 | 39 | class EncoderTask(UnitTask[EncoderInput, EncoderOutput]): 40 | """A task that invokes an encoder. 41 | 42 | Attributes: 43 | model_id: The ID of the model to use for the task. 44 | modality: Modality of data this encoder can embed. 45 | """ 46 | 47 | model_id: str 48 | modality: Modality 49 | 50 | def make_record_output(self, task_input: EncoderInput) -> EncoderOutput: 51 | """Create a task output for task invocation recording.""" 52 | return EncoderOutput(embeddings=[DataForward[Tensor]() for _ in task_input.data_urls]) 53 | 54 | def make_name(self) -> str: 55 | """Create a concise string representation of the task.""" 56 | return f"encoder-{self.modality.lower()}-{self.model_id.split('/')[-1].lower()}" 57 | -------------------------------------------------------------------------------- /python/cornserve/task/builtins/llm.py: -------------------------------------------------------------------------------- 1 | """Build-in task for LLMs.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Generic, TypeVar 6 | 7 | from cornserve.task.base import TaskInput, TaskOutput, UnitTask 8 | from cornserve.task.forward import DataForward, Tensor 9 | 10 | 11 | class LLMInput(TaskInput): 12 | """Input model for LLM tasks. 13 | 14 | Attributes: 15 | prompt: The prompt to send to the LLM. 16 | multimodal_data: List of tuples (modality, data URL). 17 | "image", "audio", "video", etc. for modality. 18 | embeddings: Multimodal embeddings to send to the LLM. 
19 | """ 20 | 21 | prompt: str 22 | multimodal_data: list[tuple[str, str]] = [] 23 | embeddings: list[DataForward[Tensor]] = [] 24 | 25 | 26 | class LLMOutputBase(TaskOutput): 27 | """Base output model for LLM tasks.""" 28 | 29 | 30 | InputT = TypeVar("InputT", bound=TaskInput) 31 | OutputT = TypeVar("OutputT", bound=TaskOutput) 32 | 33 | 34 | class LLMBaseTask(UnitTask[InputT, OutputT], Generic[InputT, OutputT]): 35 | """A task that invokes an LLM. 36 | 37 | Attributes: 38 | model_id: The ID of the model to use for the task. 39 | """ 40 | 41 | model_id: str 42 | 43 | def make_name(self) -> str: 44 | """Create a concise string representation of the task.""" 45 | return f"llm-{self.model_id.split('/')[-1].lower()}" 46 | 47 | 48 | class LLMOutput(LLMOutputBase): 49 | """Output model for LLM tasks. 50 | 51 | Attributes: 52 | response: The response from the LLM. 53 | """ 54 | 55 | response: str 56 | 57 | 58 | class LLMTask(LLMBaseTask[LLMInput, LLMOutput]): 59 | """A task that invokes an LLM and returns the response. 60 | 61 | Attributes: 62 | model_id: The ID of the model to use for the task. 63 | """ 64 | 65 | model_id: str 66 | 67 | def make_record_output(self, task_input: LLMInput) -> LLMOutput: 68 | """Create a task output for task invocation recording.""" 69 | return LLMOutput(response="") 70 | 71 | 72 | class LLMForwardOutput(LLMOutputBase): 73 | """Output model for LLM tasks with the response forwarded. 74 | 75 | Attributes: 76 | response: The response from the LLM. 77 | """ 78 | 79 | response: DataForward[str] 80 | 81 | 82 | class LLMForwardOutputTask(LLMBaseTask[LLMInput, LLMForwardOutput]): 83 | """A task that invokes an LLM and forwards the response. 84 | 85 | Attributes: 86 | model_id: The ID of the model to use for the task. 
87 | """ 88 | 89 | model_id: str 90 | 91 | def make_record_output(self, task_input: LLMInput) -> LLMForwardOutput: 92 | """Create a task output for task invocation recording.""" 93 | return LLMForwardOutput(response=DataForward[str]()) 94 | -------------------------------------------------------------------------------- /python/cornserve/task/builtins/mllm.py: -------------------------------------------------------------------------------- 1 | """Built-in task for Multimodal LLMs.""" 2 | 3 | from __future__ import annotations 4 | 5 | from cornserve.task.base import Task, TaskInput, TaskOutput 6 | from cornserve.task.builtins.encoder import EncoderInput, EncoderTask, Modality 7 | from cornserve.task.builtins.llm import LLMInput, LLMTask 8 | from cornserve.task.forward import DataForward, Tensor 9 | 10 | 11 | class MLLMInput(TaskInput): 12 | """Input model for Multimodal LLM tasks. 13 | 14 | Attributes: 15 | prompt: The prompt to send to the LLM. 16 | multimodal_data: List of tuples (modality, data URL). 17 | "image", "audio", "video", etc. for modality. 18 | """ 19 | 20 | prompt: str 21 | multimodal_data: list[tuple[str, str]] = [] 22 | 23 | 24 | class MLLMOutput(TaskOutput): 25 | """Output model for Multimodal LLM tasks. 26 | 27 | Attributes: 28 | response: The response from the LLM. 29 | """ 30 | 31 | response: str 32 | 33 | 34 | class MLLMTask(Task[MLLMInput, MLLMOutput]): 35 | """A task that invokes a Multimodal LLM. 36 | 37 | Attributes: 38 | model_id: The ID of the model to use for the task. 39 | modalities: List of input modalities other than text. 
40 | """ 41 | 42 | model_id: str 43 | modalities: list[Modality] = [] 44 | 45 | def post_init(self) -> None: 46 | """Initialize subtasks.""" 47 | if Modality.IMAGE in self.modalities: 48 | self.image_encoder = EncoderTask(model_id=self.model_id, modality=Modality.IMAGE) 49 | if Modality.VIDEO in self.modalities: 50 | self.video_encoder = EncoderTask(model_id=self.model_id, modality=Modality.VIDEO) 51 | self.llm = LLMTask(model_id=self.model_id) 52 | 53 | def invoke(self, task_input: MLLMInput) -> MLLMOutput: 54 | """Invoke the task. 55 | 56 | Given multimodal data and a text prompt, run the corresponding encoder 57 | for multimodal data and then pass the embeddings and text prompt to the LLM. 58 | """ 59 | image_data = [] 60 | video_data = [] 61 | for modality, data in task_input.multimodal_data: 62 | if modality == Modality.IMAGE: 63 | image_data.append(data) 64 | elif modality == Modality.VIDEO: 65 | video_data.append(data) 66 | else: 67 | raise ValueError(f"Unsupported modality: {modality}") 68 | 69 | if image_data: 70 | if not hasattr(self, "image_encoder"): 71 | raise ValueError("Image modality is not supported.") 72 | image_task_input = EncoderInput(data_urls=image_data) 73 | image_embeddings = self.image_encoder.invoke(image_task_input).embeddings 74 | else: 75 | image_embeddings = [] 76 | 77 | if video_data: 78 | if not hasattr(self, "video_encoder"): 79 | raise ValueError("Video modality is not supported.") 80 | video_task_input = EncoderInput(data_urls=video_data) 81 | video_embeddings = self.video_encoder.invoke(video_task_input).embeddings 82 | else: 83 | video_embeddings = [] 84 | 85 | # Retain the order of multimodal data 86 | embeddings: list[DataForward[Tensor]] = [] 87 | for modality, _ in task_input.multimodal_data: 88 | if modality == Modality.IMAGE: 89 | embeddings.append(image_embeddings.pop(0)) 90 | elif modality == Modality.VIDEO: 91 | embeddings.append(video_embeddings.pop(0)) 92 | 93 | llm_input = LLMInput( 94 | prompt=task_input.prompt, 95 
| multimodal_data=task_input.multimodal_data, 96 | embeddings=embeddings, 97 | ) 98 | llm_output = self.llm.invoke(llm_input) 99 | 100 | return MLLMOutput(response=llm_output.response) 101 | -------------------------------------------------------------------------------- /python/cornserve/task/forward.py: -------------------------------------------------------------------------------- 1 | """Classes for representing data forwarding between tasks in the data plane.""" 2 | 3 | from __future__ import annotations 4 | 5 | import enum 6 | import uuid 7 | from typing import Generic, Self, TypeVar 8 | 9 | from pydantic import BaseModel, Field, model_validator 10 | 11 | 12 | class Tensor: 13 | """Represents a tensor object for data forwarding.""" 14 | 15 | 16 | class ForwardableType(enum.StrEnum): 17 | """Types of data that can be forwarded between tasks.""" 18 | 19 | BYTES = "bytes" 20 | STR = "str" 21 | INT = "int" 22 | FLOAT = "float" 23 | BOOL = "bool" 24 | TENSOR = "Tensor" 25 | 26 | 27 | DataT = TypeVar("DataT") 28 | 29 | 30 | class DataForward(BaseModel, Generic[DataT]): 31 | """Represents data that is forwarded between tasks in the data plane.""" 32 | 33 | # This ID identifies `DataForward` objects and ties them together in task input/outputs. 34 | id: str = Field(default_factory=lambda: uuid.uuid4().hex) 35 | 36 | # The data type automatically parsed out of the generic type argument. 37 | data_type: ForwardableType = Field(init=False, default=ForwardableType.TENSOR) 38 | 39 | # Producer (source) sidecar ranks. 40 | src_sidecar_ranks: list[int] | None = Field(init=False, default=None) 41 | 42 | # Consumer (destination) sidecar ranks. This is a list of lists because the data 43 | # can be forwarded to multiple tasks (i.e., broadcasted) to more than one task executor. 44 | dst_sidecar_ranks: list[list[int]] | None = Field(init=False, default=None) 45 | 46 | @model_validator(mode="after") 47 | def _data_type(self) -> Self: 48 | """Validate the generic type argument. 
class TaskRegistry:
    """Registry of unit tasks.

    Composite tasks are not registered here; they can simply be invoked
    and will be decomposed into a series of unit task invocations.
    """

    def __init__(self) -> None:
        """Initialize an empty task registry."""
        # Maps a unique task name to (task class, input model, output model).
        self._tasks: dict[str, tuple[type[UnitTask], type[TaskInput], type[TaskOutput]]] = {}

    def register(
        self,
        task: type[UnitTask],
        task_input: type[TaskInput],
        task_output: type[TaskOutput],
        name: str | None = None,
    ) -> None:
        """Register a task, keyed by `name` (defaults to the task class name).

        Raises:
            ValueError: If a task with the same name was already registered.
        """
        if not name:
            name = task.__name__
        if name in self._tasks:
            raise ValueError(f"Unit task with {name=} already exists. Unit task names must be unique.")
        self._tasks[name] = (task, task_input, task_output)

    def get(self, name: str) -> tuple[type[UnitTask], type[TaskInput], type[TaskOutput]]:
        """Look up a registered task by name.

        Raises:
            KeyError: If no task was registered under `name`.
        """
        # Lazy import builtin tasks to avoid circular import issues
        import cornserve.task.builtins  # noqa: F401

        if name not in self._tasks:
            raise KeyError(f"Unit task with {name=} not found")
        return self._tasks[name]

    def __contains__(self, name: str) -> bool:
        """Check whether a task with this name is registered."""
        return name in self._tasks


TASK_REGISTRY = TaskRegistry()
TaskT = TypeVar("TaskT", bound=UnitTask)
InputT = TypeVar("InputT", bound=TaskInput)
OutputT = TypeVar("OutputT", bound=TaskOutput)


class TaskExecutionDescriptor(BaseModel, ABC, Generic[TaskT, InputT, OutputT]):
    """Base class for task execution descriptors.

    A descriptor knows how to launch a task executor (container image, args,
    volumes), how to reach it (API URL), and how to translate between the
    task-level input/output models and the executor's request/response bodies.

    Attributes:
        task: The task to be executed.
    """

    task: TaskT

    @abstractmethod
    def create_executor_name(self) -> str:
        """Create a name for the task executor."""

    @abstractmethod
    def get_container_image(self) -> str:
        """Get the container image name for the task executor."""

    @abstractmethod
    def get_container_args(self, gpus: list[GPU], port: int) -> list[str]:
        """Get the container command for the task executor."""

    def get_container_volumes(self) -> list[tuple[str, str, str]]:
        """Get the container volumes for the task manager.

        Subclasses may override; the default mounts the HuggingFace cache
        and shared memory.

        Returns:
            A list of tuples: name, host path, container path.
        """
        return [
            ("hf-cache", constants.VOLUME_HF_CACHE, "/root/.cache/huggingface"),
            ("shm", constants.VOLUME_SHM, "/dev/shm"),
        ]

    @abstractmethod
    def get_api_url(self, base: str) -> str:
        """Get the task executor's base URL for API calls.

        Args:
            base: The base URL of the task executor.
        """

    @abstractmethod
    def to_request(self, task_input: InputT, task_output: OutputT) -> dict[str, Any]:
        """Convert TaskInput to a request object for the task executor.

        The task output object is needed because this specific task executor may
        have to forward data to the next task executor, and for that, we need to
        know the destination sidecar ranks annotated in the task output.
        """

    @abstractmethod
    def from_response(self, task_output: OutputT, response: dict[str, Any]) -> OutputT:
        """Convert the task executor response to TaskOutput.

        In general, the `task_output` object will be deep-copied and concrete values
        will be filled in from the response.
        """
20 | """ 21 | 22 | def create_executor_name(self) -> str: 23 | """Create a name for the task executor.""" 24 | name = "-".join( 25 | [ 26 | "eric", 27 | self.task.modality, 28 | self.task.model_id.split("/")[-1].lower(), 29 | ] 30 | ).lower() 31 | return name 32 | 33 | def get_container_image(self) -> str: 34 | """Get the container image name for the task executor.""" 35 | return constants.CONTAINER_IMAGE_ERIC 36 | 37 | def get_container_args(self, gpus: list[GPU], port: int) -> list[str]: 38 | """Get the container command for the task executor.""" 39 | # fmt: off 40 | cmd = [ 41 | "--model.id", self.task.model_id, 42 | "--model.tp-size", str(len(gpus)), 43 | "--model.modality", self.task.modality.value.upper(), 44 | "--server.port", str(port), 45 | "--sidecar.ranks", *[str(gpu.global_rank) for gpu in gpus], 46 | ] 47 | # fmt: on 48 | return cmd 49 | 50 | def get_api_url(self, base: str) -> str: 51 | """Get the task executor's base URL for API calls.""" 52 | return f"{base}/embeddings" 53 | 54 | def to_request(self, task_input: EncoderInput, task_output: EncoderOutput) -> dict[str, Any]: 55 | """Convert TaskInput to a request object for the task executor.""" 56 | data: list[EmbeddingData] = [] 57 | for url, forward in zip(task_input.data_urls, task_output.embeddings, strict=True): 58 | if forward.dst_sidecar_ranks is None: 59 | raise ValueError("Destination sidecar ranks must be specified for each forward.") 60 | data.append( 61 | EmbeddingData( 62 | id=forward.id, 63 | modality=Modality(self.task.modality.value), 64 | url=url, 65 | receiver_sidecar_ranks=forward.dst_sidecar_ranks, 66 | ) 67 | ) 68 | req = EmbeddingRequest(data=data) 69 | return req.model_dump() 70 | 71 | def from_response(self, task_output: EncoderOutput, response: dict[str, Any]) -> EncoderOutput: 72 | """Convert the task executor response to TaskOutput.""" 73 | resp = EmbeddingResponse.model_validate(response) 74 | if resp.status == Status.SUCCESS: 75 | return 
class VLLMDescriptor(TaskExecutionDescriptor[LLMBaseTask, LLMInput, LLMOutputBase]):
    """Task execution descriptor for LLM tasks.

    This descriptor handles launching vLLM task executors and converting between
    the external task API types and internal executor types.
    """

    def create_executor_name(self) -> str:
        """Create a name for the task executor."""
        # e.g. "vllm-llama-3.1-8b-instruct" for "meta-llama/Llama-3.1-8B-Instruct".
        return "-".join(["vllm", self.task.model_id.split("/")[-1]]).lower()

    def get_container_image(self) -> str:
        """Get the container image name for the task executor."""
        return constants.CONTAINER_IMAGE_VLLM

    def get_container_args(self, gpus: list[GPU], port: int) -> list[str]:
        """Get the container command for the task executor.

        Args:
            gpus: GPUs assigned to this executor; their count sets the TP degree.
            port: Port the vLLM server should listen on.
        """
        # fmt: off
        cmd = [
            self.task.model_id,
            "--tensor-parallel-size", str(len(gpus)),
            "--port", str(port),
            "--limit-mm-per-prompt", "image=5",  # TODO: Is this still needed?
            "--cornserve-sidecar-ranks", *[str(gpu.global_rank) for gpu in gpus],
        ]
        # fmt: on
        return cmd

    def get_api_url(self, base: str) -> str:
        """Get the task executor's base URL for API calls."""
        return f"{base}/v1/chat/completions"

    def to_request(self, task_input: LLMInput, task_output: LLMOutputBase) -> dict[str, Any]:
        """Convert TaskInput to a request object for the task executor.

        Builds an OpenAI Chat Completion-compatible request. Multimodal data is
        encoded as custom `data:` URIs carrying the forward ID so vLLM can fetch
        the precomputed embeddings through the sidecar instead of re-encoding.
        """
        # XXX: `DataForward[str]` not supported yet.
        # Compatibility with OpenAI Chat Completion API is kept.
        content: list[dict[str, Any]] = [dict(type="text", text=task_input.prompt)]
        for (modality, data_url), forward in zip(task_input.multimodal_data, task_input.embeddings, strict=True):
            data_uri = f"data:{modality}/uuid;data_id={forward.id};url={data_url},"
            content.append({"type": f"{modality}_url", f"{modality}_url": {"url": data_uri}})

        request = dict(
            model=self.task.model_id,
            # NOTE(review): max_completion_tokens is hard-coded to 512 —
            # presumably a placeholder; confirm whether it should be configurable.
            messages=[dict(role="user", content=content)],
            max_completion_tokens=512,
        )

        return request

    def from_response(self, task_output: LLMOutputBase, response: dict[str, Any]) -> LLMOutputBase:
        """Convert the task executor response to TaskOutput.

        Raises:
            ValueError: If `task_output` is neither LLMOutput nor LLMForwardOutput.
        """
        if isinstance(task_output, LLMOutput):
            return LLMOutput(response=response["choices"][0]["message"]["content"])
        if isinstance(task_output, LLMForwardOutput):
            # NOTE(review): this branch ignores `response` and echoes
            # task_output.response — looks like the forwarded text travels via
            # the sidecar rather than HTTP; confirm against the forward path.
            return LLMForwardOutput(
                response=task_output.response,
            )
        raise ValueError(f"Unexpected task output type: {type(task_output)}")


DESCRIPTOR_REGISTRY.register(LLMBaseTask, VLLMDescriptor, default=True)
class TaskExecutionDescriptorRegistry:
    """Registry for task execution descriptors."""

    def __init__(self) -> None:
        """Initialize the registry."""
        # task class -> {descriptor name -> descriptor class}
        self.registry: dict[type[UnitTask], dict[str, type[TaskExecutionDescriptor]]] = defaultdict(dict)
        # task class -> the single default descriptor class for that task
        self.default_registry: dict[type[UnitTask], type[TaskExecutionDescriptor]] = {}

    def register(
        self,
        task: type[UnitTask],
        descriptor: type[TaskExecutionDescriptor],
        name: str | None = None,
        default: bool = False,
    ) -> None:
        """Register a task execution descriptor.

        Args:
            task: The task class to register the descriptor for.
            descriptor: The task execution descriptor class.
            name: The name of the descriptor. If None, use the class name.
            default: Whether this is the default descriptor for the task.

        Raises:
            ValueError: If the name or the default slot is already taken.
        """
        key = descriptor.__name__ if name is None else name

        named_descriptors = self.registry[task]
        if key in named_descriptors:
            raise ValueError(f"Descriptor {key} already registered for task {task.__name__}")
        named_descriptors[key] = descriptor

        if not default:
            return
        if task in self.default_registry:
            raise ValueError(f"Default descriptor already registered for task {task.__name__}")
        self.default_registry[task] = descriptor

    def get(self, task: type[UnitTask], name: str | None = None) -> type[TaskExecutionDescriptor]:
        """Get the task execution descriptor for a task.

        Args:
            task: The task class to get the descriptor for.
            name: The name of the descriptor. If None, use the default descriptor.

        Raises:
            ValueError: If no matching descriptor is registered.
        """
        # Lazily import built-in descriptors to avoid circular imports
        import cornserve.task_executors.descriptor.builtins  # noqa: F401

        if task not in self.registry:
            raise ValueError(f"No descriptors registered for task {task.__name__}")

        if name is None:
            if task not in self.default_registry:
                raise ValueError(f"No default descriptor registered for task {task.__name__}")
            return self.default_registry[task]

        if name not in self.registry[task]:
            raise ValueError(f"Descriptor {name} not registered for task {task.__name__}")
        return self.registry[task][name]


DESCRIPTOR_REGISTRY = TaskExecutionDescriptorRegistry()
import annotations 4 | 5 | import enum 6 | 7 | from pydantic import BaseModel 8 | 9 | ID = str 10 | 11 | 12 | class Modality(enum.StrEnum): 13 | """Modality of the data to be embedded.""" 14 | 15 | IMAGE = "image" 16 | VIDEO = "video" 17 | AUDIO = "audio" 18 | 19 | 20 | class EmbeddingData(BaseModel): 21 | """The data to be embedded. 22 | 23 | Attributes: 24 | id: Modality data ID unique within the request. 25 | modality: The modality of the data. 26 | url: The URL where the data can be downloaded from. 27 | receiver_sidecar_ranks: List of sidecar ranks to send the embeddings to. 28 | If omitted, tensors will not be sent to any sidecar. 29 | """ 30 | 31 | id: ID 32 | modality: Modality 33 | url: str 34 | receiver_sidecar_ranks: list[list[int]] | None = None 35 | 36 | 37 | class EmbeddingRequest(BaseModel): 38 | """Request to embed data. 39 | 40 | Attributes: 41 | data: List of data to be embedded. 42 | """ 43 | 44 | data: list[EmbeddingData] 45 | 46 | 47 | class Status(enum.IntEnum): 48 | """Status of various operations.""" 49 | 50 | SUCCESS = 0 51 | ERROR = 1 52 | 53 | 54 | class EmbeddingResponse(BaseModel): 55 | """Response containing the embedding.""" 56 | 57 | status: Status 58 | error_message: str | None = None 59 | -------------------------------------------------------------------------------- /python/cornserve/task_executors/eric/engine/__init__.py: -------------------------------------------------------------------------------- 1 | """The engine runs in a separate process and handles model inference.""" 2 | -------------------------------------------------------------------------------- /python/cornserve/task_executors/eric/entrypoint.py: -------------------------------------------------------------------------------- 1 | """Spins up Eric.""" 2 | 3 | from __future__ import annotations 4 | 5 | import asyncio 6 | import signal 7 | 8 | import tyro 9 | import uvicorn 10 | from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor 11 | from 
opentelemetry.instrumentation.threading import ThreadingInstrumentor 12 | 13 | from cornserve.logging import get_logger 14 | from cornserve.task_executors.eric.config import EricConfig 15 | from cornserve.task_executors.eric.engine.client import EngineClient 16 | from cornserve.task_executors.eric.router.app import create_app 17 | from cornserve.tracing import configure_otel 18 | 19 | logger = get_logger("cornserve.task_executors.eric.entrypoint") 20 | 21 | 22 | async def serve(eric_config: EricConfig) -> None: 23 | """Serve the Eric model as a FastAPI app.""" 24 | logger.info("Starting Eric with %s", eric_config) 25 | 26 | configure_otel(f"eric{str(eric_config.sidecar.ranks).replace(' ', '')}") 27 | 28 | app = create_app(eric_config) 29 | 30 | FastAPIInstrumentor().instrument_app(app) 31 | ThreadingInstrumentor().instrument() 32 | 33 | logger.info("Available routes are:") 34 | for route in app.routes: 35 | methods = getattr(route, "methods", None) 36 | path = getattr(route, "path", None) 37 | 38 | if methods is None or path is None: 39 | continue 40 | 41 | logger.info( 42 | "%s %s", 43 | list(methods)[0] if len(methods) == 1 else "{" + ",".join(methods) + "}", 44 | path, 45 | ) 46 | 47 | config = uvicorn.Config(app, host=eric_config.server.host, port=eric_config.server.port) 48 | server = uvicorn.Server(config) 49 | 50 | loop = asyncio.get_running_loop() 51 | server_task = loop.create_task(server.serve()) 52 | 53 | def shutdown() -> None: 54 | engine_client: EngineClient = app.state.engine_client 55 | engine_client.shutdown() 56 | server_task.cancel() 57 | 58 | loop.add_signal_handler(signal.SIGINT, shutdown) 59 | loop.add_signal_handler(signal.SIGTERM, shutdown) 60 | 61 | try: 62 | await server_task 63 | except asyncio.CancelledError: 64 | logger.info("Shutting down FastAPI server.") 65 | await server.shutdown() 66 | 67 | 68 | if __name__ == "__main__": 69 | asyncio.run(serve(tyro.cli(EricConfig))) 70 | 
class EricModel(nn.Module, ABC):
    """Base class for all models in Eric."""

    @abstractmethod
    def forward(self, modality: Modality, batch: dict[str, list[torch.Tensor]]) -> list[torch.Tensor]:
        """Forward pass for the model.

        Args:
            modality: The modality of the data.
            batch: The input data.

        Returns:
            A list of output tensors.
        """

    @property
    @abstractmethod
    def dtype(self) -> torch.dtype:
        """Return the data type of the model's embeddings."""

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """Return the device where inputs should be in."""

    @property
    @abstractmethod
    def chunk_shape(self) -> tuple[int, ...]:
        """Return the shape of the chunks to be sent to the sidecar."""


class BaseModalityProcessor:
    """Base class for modality processors.

    Each model definition module contains a `ModalityProcessor` class that
    inherits from this class. It should override `get_image_processor`,
    `get_video_processor`, etc. to return the appropriate processor for the
    given modality. The processor should be a callable that takes the input
    modality data as a Numpy array and returns the processed data as a
    dictionary of Numpy arrays.
    """

    def __init__(self, model_id: str) -> None:
        """Initialize the processor.

        Args:
            model_id: Model ID the processor belongs to (e.g. an HF repo ID).
        """
        self.model_id = model_id

    def get_image_processor(self) -> Callable | None:
        """Get the image processor for this modality.

        The callable should take a single image numpy array.
        Returns None when the model does not support images.
        """
        return None

    def get_audio_processor(self) -> Callable | None:
        """Get the audio processor for this modality.

        The callable should take a tuple of (audio data numpy array, sample rate).
        Returns None when the model does not support audio.
        """
        return None

    def get_video_processor(self) -> Callable | None:
        """Get the video processor for this modality.

        The callable should take a single video numpy array.
        Returns None when the model does not support video.
        """
        return None

    def process(self, modality: Modality, data: npt.NDArray) -> dict[str, npt.NDArray]:
        """Process the input data for the given modality.

        Raises:
            ValueError: If the modality is unsupported or the corresponding
                processor getter returned None.
        """
        match modality:
            case Modality.IMAGE:
                image_processor = self.get_image_processor()
                if image_processor is None:
                    raise ValueError("Image processor not available.")
                return image_processor(data)
            case Modality.AUDIO:
                audio_processor = self.get_audio_processor()
                if audio_processor is None:
                    raise ValueError("Audio processor not available.")
                return audio_processor(data)
            case Modality.VIDEO:
                video_processor = self.get_video_processor()
                if video_processor is None:
                    raise ValueError("Video processor not available.")
                return video_processor(data)
            case _:
                raise ValueError(f"Unsupported modality: {modality}")
79 | """ 80 | return None 81 | 82 | def process(self, modality: Modality, data: npt.NDArray) -> dict[str, npt.NDArray]: 83 | """Process the input data for the given modality.""" 84 | match modality: 85 | case Modality.IMAGE: 86 | image_processor = self.get_image_processor() 87 | if image_processor is None: 88 | raise ValueError("Image processor not available.") 89 | return image_processor(data) 90 | case Modality.AUDIO: 91 | audio_processor = self.get_audio_processor() 92 | if audio_processor is None: 93 | raise ValueError("Audio processor not available.") 94 | return audio_processor(data) 95 | case Modality.VIDEO: 96 | video_processor = self.get_video_processor() 97 | if video_processor is None: 98 | raise ValueError("Video processor not available.") 99 | return video_processor(data) 100 | case _: 101 | raise ValueError(f"Unsupported modality: {modality}") 102 | -------------------------------------------------------------------------------- /python/cornserve/task_executors/eric/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """A collection of reusable layers.""" 2 | -------------------------------------------------------------------------------- /python/cornserve/task_executors/eric/models/layers/activations.py: -------------------------------------------------------------------------------- 1 | """Activation functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class QuickGELU(nn.Module): 10 | # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 11 | def forward(self, x: torch.Tensor) -> torch.Tensor: 12 | return x * torch.sigmoid(1.702 * x) 13 | 14 | 15 | def get_act_fn(name: str) -> nn.Module: 16 | """Get an activation function by name.""" 17 | match name.lower(): 18 | case "gelu": 19 | return nn.GELU() 20 | case "quick_gelu": 21 | return QuickGELU() 22 | case "gelu_pytorch_tanh": 23 | return 
class Attention(nn.Module):
    """Full attention implementation for ViTs."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
    ) -> None:
        """Initialize the attention module.

        Args:
            num_heads: Number of query heads.
            head_size: Dimension of each attention head.
            scale: Scaling factor applied inside the attention kernel.
            num_kv_heads: Number of key/value heads; defaults to `num_heads`
                (plain multi-head attention when equal, MQA/GQA when fewer).
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.num_queries_per_kv, leftover = divmod(self.num_heads, self.num_kv_heads)
        assert leftover == 0, f"{num_heads=} must be divisible by {num_kv_heads=}"

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Run forward.

        Args:
            query, key, value: [batch_size, seq_len, hidden_size]
        """
        batch_size, query_len, _ = query.size()
        kv_len = key.size(1)

        # Reshape to per-head layout expected by the attention kernel.
        query = query.view(batch_size, query_len, self.num_heads, self.head_size)
        key = key.view(batch_size, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, kv_len, self.num_kv_heads, self.head_size)

        # Handle MQA and GQA by replicating KV heads to match query heads.
        repeats = self.num_queries_per_kv
        if repeats > 1:
            key = torch.repeat_interleave(key, repeats, dim=2)
            value = torch.repeat_interleave(value, repeats, dim=2)

        attn_out = memory_efficient_attention_forward(query, key, value, scale=self.scale)

        # Merge heads back into the hidden dimension.
        return attn_out.reshape(batch_size, query_len, -1)
17 | """ 18 | feature_layers = hf_config.vision_feature_layer 19 | num_hidden_layers = hf_config.vision_config.num_hidden_layers 20 | # If we have one feature layer, initialize up to that layer 21 | if isinstance(feature_layers, int): 22 | return _get_layer_index(feature_layers, num_hidden_layers) 23 | # If we have multiple feature layers, initialize up to the deepest one 24 | elif isinstance(feature_layers, (list, tuple)): 25 | return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers) 26 | raise TypeError(f"vision_layer_feature type: {type(feature_layers)} is not supported") 27 | 28 | 29 | def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: 30 | """Given a signed vision feature layer, get the number of hidden layers 31 | needed to leverage it. 32 | 33 | Args: 34 | feature_layer_index: Index of a required layer in the visual encoder. 35 | num_hidden_layers: The total number of hidden layers in the visual 36 | encoder. 37 | """ 38 | if feature_layer_index < 0: 39 | return num_hidden_layers + feature_layer_index + 1 40 | return feature_layer_index 41 | 42 | 43 | def init_vision_tower_for_llava( 44 | hf_config, 45 | *, 46 | require_post_norm: bool | None = None, 47 | ) -> nn.Module: 48 | """Initialize the vision tower for Llava-family models.""" 49 | vision_config = hf_config.vision_config 50 | 51 | # Initialize the vision tower only up to the deepest required feature layer 52 | num_hidden_layers = _get_num_hidden_layers(hf_config) 53 | 54 | if isinstance(vision_config, SiglipVisionConfig): 55 | return SiglipVisionModel( 56 | vision_config, 57 | num_hidden_layers_override=num_hidden_layers, 58 | require_post_norm=require_post_norm, 59 | ) 60 | 61 | msg = f"Unsupported vision config: {type(vision_config)}" 62 | raise NotImplementedError(msg) 63 | -------------------------------------------------------------------------------- /python/cornserve/task_executors/eric/router/__init__.py: 
router = APIRouter()
logger = get_logger(__name__)
tracer = trace.get_tracer(__name__)


@router.get("/health")
async def health_check(request: Request) -> Response:
    """Checks whether the router and the engine are alive."""
    engine_client: EngineClient = request.app.state.engine_client
    # 200 when the engine process responds, 503 otherwise.
    match engine_client.health_check():
        case True:
            return Response(status_code=status.HTTP_200_OK)
        case False:
            return Response(status_code=status.HTTP_503_SERVICE_UNAVAILABLE)


@router.get("/info")
async def info(raw_request: Request) -> EricConfig:
    """Returns Eric's configuration information."""
    return raw_request.app.state.config


@router.get("/modalities")
async def modalities(raw_request: Request) -> list[Modality]:
    """Return the list of modalities supported by this model."""
    config: EricConfig = raw_request.app.state.config
    # Model type (from the HF config) keys the registry entry, whose
    # `modality` mapping enumerates the supported modalities.
    return list(MODEL_REGISTRY[config.model.hf_config.model_type].modality.keys())


@router.post("/embeddings")
async def embeddings(
    request: EmbeddingRequest,
    raw_request: Request,
    raw_response: Response,
) -> EmbeddingResponse:
    """Handler for embedding requests.

    Downloads and preprocesses each data item, then dispatches the batch to
    the engine process, which embeds it and streams tensors out via the
    sidecar. The HTTP response only carries a status.
    """
    span = trace.get_current_span()
    for data_item in request.data:
        span.set_attribute(
            f"eric.embeddings.data.{data_item.id}.url",
            data_item.url,
        )
    processor: Processor = raw_request.app.state.processor
    engine_client: EngineClient = raw_request.app.state.engine_client

    # Load data from URLs and apply processing
    processed = await processor.process(request.data)

    # Send to engine process (embedding + transmission via Tensor Sidecar)
    response = await engine_client.embed(uuid.uuid4().hex, processed)

    # Map engine status to an HTTP status code; body carries the detail.
    match response.status:
        case Status.SUCCESS:
            raw_response.status_code = status.HTTP_200_OK
        case Status.ERROR:
            raw_response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        case _:
            logger.error("Unexpected status: %s", response.status)
            raw_response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
    return response


def init_app_state(app: FastAPI, config: EricConfig) -> None:
    """Initialize the app state with the configuration and engine client."""
    app.state.config = config
    app.state.processor = Processor(config.model.id, config.modality)
    app.state.engine_client = EngineClient(config)


def create_app(config: EricConfig) -> FastAPI:
    """Create a FastAPI app with the given configuration."""
    app = FastAPI()
    app.include_router(router)
    init_app_state(app, config)
    return app
"""Utility functions and classes for Eric.""" 2 | -------------------------------------------------------------------------------- /python/cornserve/task_executors/eric/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """Utilities for distributed inference.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Sequence 6 | 7 | import torch 8 | 9 | 10 | def divide(numerator, denominator): 11 | """Ensure numerator is divisible by the denominator and return the quotient.""" 12 | quotient, remainder = divmod(numerator, denominator) 13 | assert remainder == 0, f"{numerator} is not divisible by {denominator}" 14 | return quotient 15 | 16 | 17 | def split_tensor_along_last_dim( 18 | tensor: torch.Tensor, 19 | num_partitions: int, 20 | contiguous_split_chunks: bool = False, 21 | ) -> Sequence[torch.Tensor]: 22 | """Split a tensor along its last dimension. 23 | 24 | Arguments: 25 | tensor: input tensor. 26 | num_partitions: number of partitions to split the tensor 27 | contiguous_split_chunks: If True, make each chunk contiguous. 28 | 29 | Returns: 30 | A list of Tensors 31 | """ 32 | # Get the size and dimension. 33 | last_dim = tensor.dim() - 1 34 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 35 | # Split. 36 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 37 | # NOTE: torch.split does not create contiguous tensors by default. 
38 | if contiguous_split_chunks: 39 | return tuple(chunk.contiguous() for chunk in tensor_list) 40 | 41 | return tensor_list 42 | -------------------------------------------------------------------------------- /python/cornserve/task_executors/eric/utils/network.py: -------------------------------------------------------------------------------- 1 | """Utilities for networking.""" 2 | 3 | from __future__ import annotations 4 | 5 | import socket 6 | 7 | 8 | def get_open_port() -> int: 9 | """Get an open port on the local machine.""" 10 | # Try IPv4 first 11 | try: 12 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 13 | s.bind(("", 0)) 14 | return s.getsockname()[1] 15 | except OSError: 16 | # Try IPv6 17 | with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: 18 | s.bind(("", 0)) 19 | return s.getsockname()[1] 20 | -------------------------------------------------------------------------------- /python/cornserve/task_executors/eric/utils/package.py: -------------------------------------------------------------------------------- 1 | """Utilities for handling optional package imports.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | 8 | class PlaceholderModule: 9 | """A placeholder class that replaces optional packages when they are not installed.""" 10 | 11 | def __init__(self, name: str, optional_dependency_name: str) -> None: 12 | """Instantiate a placeholder module with the package's name.""" 13 | self.__name = name 14 | self.__optional_dependency_name = optional_dependency_name 15 | 16 | def __getattr__(self, name: str) -> Any: 17 | """Raise an error when any attribute is accessed.""" 18 | if name == "__name": 19 | return self.__name 20 | 21 | if name == "__optional_dependency_name": 22 | return self.__optional_dependency_name 23 | 24 | raise RuntimeError( 25 | f"Optional package '{self.__name}' is not installed. 
Please install " 26 | f"optional dependencies with `pip install cornserve[{self.__optional_dependency_name}]`." 27 | ) 28 | -------------------------------------------------------------------------------- /python/cornserve/task_executors/eric/utils/process.py: -------------------------------------------------------------------------------- 1 | """Utilities for process management.""" 2 | 3 | from __future__ import annotations 4 | 5 | import contextlib 6 | import os 7 | import signal 8 | 9 | import psutil 10 | 11 | 12 | def kill_process_tree(pid: int | None) -> None: 13 | """Kill all descendant processes of the given pid by sending SIGKILL. 14 | 15 | Args: 16 | pid: Process ID of the parent process. 17 | """ 18 | # None might be passed in if mp.Process hasn't been spawned yet 19 | if pid is None: 20 | return 21 | 22 | try: 23 | parent = psutil.Process(pid) 24 | except psutil.NoSuchProcess: 25 | return 26 | 27 | # Get all children recursively 28 | children = parent.children(recursive=True) 29 | 30 | # Send SIGKILL to all children first 31 | for child in children: 32 | with contextlib.suppress(ProcessLookupError): 33 | os.kill(child.pid, signal.SIGKILL) 34 | 35 | # Finally kill the parent 36 | with contextlib.suppress(ProcessLookupError): 37 | os.kill(pid, signal.SIGKILL) 38 | -------------------------------------------------------------------------------- /python/cornserve/task_executors/eric/utils/serde.py: -------------------------------------------------------------------------------- 1 | """Utilities for serializing and deserializing objects.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pickle 6 | from typing import Any 7 | 8 | import numpy as np 9 | import torch 10 | from msgspec import msgpack 11 | 12 | CUSTOM_TYPE_NUMPY = 1 13 | CUSTOM_TYPE_TORCH = 2 14 | CUSTOM_TYPE_PICKLE = 3 15 | 16 | 17 | class MsgpackEncoder: 18 | """Msgpack encoder that implements custom serialization.""" 19 | 20 | def __init__(self) -> None: 21 | """Initialize the 
def ext_hook(code: int, data: memoryview) -> Any:
    """Reconstruct an object from a msgpack extension payload.

    Counterpart of `enc_hook`: every payload is a pickle blob, and the
    extension code tells us whether the unpickled result must additionally
    be wrapped back into a torch.Tensor.

    Raises:
        ValueError: If the extension code is not one produced by `enc_hook`.
    """
    if code not in (CUSTOM_TYPE_NUMPY, CUSTOM_TYPE_TORCH, CUSTOM_TYPE_PICKLE):
        raise ValueError(f"Unknown custom serialization code: {code}")

    obj = pickle.loads(data)
    # Torch tensors were serialized as Numpy arrays; rewrap them here.
    return torch.from_numpy(obj) if code == CUSTOM_TYPE_TORCH else obj
def get_open_zmq_ipc_path(description: str | None = None) -> str:
    """Get an open IPC path for ZMQ sockets.

    A fresh UUID makes every returned path unique, so concurrently created
    sockets never collide on the same filesystem path.

    Args:
        description: An optional string description for where the socket is used.

    Returns:
        An `ipc://` URL pointing under the system temp directory.
    """
    filename = f"{description}-{uuid4()}" if description is not None else str(uuid4())
    # Bug fix: the generated unique filename was previously discarded and a
    # fixed literal returned, so all callers shared one IPC path. Use the
    # freshly generated filename instead.
    return f"ipc://{TMP_DIR}/{filename}"
TracerProvider(resource=resource) 18 | exporter = OTLPSpanExporter(endpoint=K8S_OTEL_GRPC_URL) 19 | processor = BatchSpanProcessor(exporter) 20 | provider.add_span_processor(processor) 21 | trace.set_tracer_provider(provider) 22 | -------------------------------------------------------------------------------- /python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "cornserve" 7 | description = "Easy, fast, and scalable multimodal agentic AI" 8 | authors = [ 9 | { name = "Cornserve Team" }, 10 | ] 11 | readme = "README.md" 12 | license = { file = "LICENSE" } 13 | classifiers = [ 14 | "Environment :: GPU :: NVIDIA CUDA", 15 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 16 | "Programming Language :: Python :: 3.11", 17 | "Programming Language :: Python :: 3.12", 18 | "Programming Language :: Python :: 3.13", 19 | ] 20 | requires-python = ">=3.11" 21 | dependencies = [ 22 | "grpcio-tools==1.71.0", 23 | "rich", 24 | "requests", 25 | "tyro", 26 | "kubernetes_asyncio", 27 | "httpx", 28 | "pydantic>=2.11", 29 | "opentelemetry-api", 30 | "opentelemetry-sdk", 31 | "opentelemetry-exporter-otlp-proto-grpc", 32 | "websocket-client", 33 | ] 34 | dynamic = ["version"] 35 | 36 | [project.scripts] 37 | cornserve = "cornserve.cli:main" 38 | 39 | [project.optional-dependencies] 40 | sidecar-api = [ 41 | "torch>=2.5.0", 42 | "opentelemetry-instrumentation-grpc", 43 | "opentelemetry-instrumentation-threading", 44 | ] 45 | sidecar = [ 46 | "torch>=2.5.0", 47 | "ucxx-cu12", 48 | "msgspec", 49 | "opentelemetry-instrumentation-grpc", 50 | ] 51 | gateway = [ 52 | "fastapi", 53 | "uvicorn[standard]", 54 | "opentelemetry-instrumentation-fastapi", 55 | "opentelemetry-instrumentation-grpc", 56 | "opentelemetry-instrumentation-httpx", 57 | "websocket-client", 58 | ] 59 | resource-manager = [ 60 
| "opentelemetry-instrumentation-grpc", 61 | ] 62 | task-manager = [] 63 | task-dispatcher = [ 64 | "fastapi", 65 | "uvicorn[standard]", 66 | "opentelemetry-instrumentation-fastapi", 67 | "opentelemetry-instrumentation-httpx", 68 | "opentelemetry-instrumentation-grpc", 69 | ] 70 | audio = ["librosa"] 71 | eric-no-gpu = [ 72 | "fastapi", 73 | "uvicorn[standard]", 74 | "pyzmq", 75 | "msgspec", 76 | "psutil", 77 | "torch>=2.5.0", 78 | "transformers", 79 | "huggingface_hub", 80 | "pillow", 81 | "opencv-python-headless", 82 | "einops", 83 | "cornserve[sidecar-api,audio]", 84 | "opentelemetry-instrumentation-fastapi", 85 | "opentelemetry-instrumentation-threading", 86 | ] 87 | eric = ["flash-attn", "xformers", "cornserve[eric-no-gpu]"] 88 | eric-audio = ["cornserve[eric,audio]"] 89 | dev-common = [ 90 | "grpcio-tools", 91 | "pyright!=1.1.401", 92 | "ruff", 93 | "pytest", 94 | "pytest-asyncio", 95 | "pytest-dependency", 96 | "cornserve[sidecar-api,gateway,resource-manager,task-manager,task-dispatcher]", 97 | ] 98 | dev = ["cornserve[dev-common,sidecar,eric]"] 99 | dev-no-gpu = ["cornserve[dev-common,eric-no-gpu]"] 100 | 101 | [tool.setuptools.dynamic] 102 | version = { attr = "cornserve.__version__" } 103 | 104 | [tool.ruff] 105 | line-length = 120 106 | 107 | [tool.ruff.lint] 108 | pydocstyle.convention = "google" 109 | select = [ 110 | "E", # pycodestyle error 111 | "F", # pyflakes 112 | "D", # pydocstyle 113 | "PL", # pylint 114 | "N", # pep8-naming 115 | "UP", # pyupgrade 116 | "B", # flake8-bugbear (detects likely bugs) 117 | "G", # flake8-logging-format (complains about logging) 118 | "SIM", # flake8-simplify (suggests code simplifications) 119 | ] 120 | exclude = [ 121 | "cornserve/task_executors/eric/models/*.py", 122 | ] 123 | ignore = [ 124 | "PLW0603", # Global statement 125 | "PLR2004", # Magic value 126 | "PLR0912", # Too many branches 127 | "PLR0913", # Too many arguments 128 | "PLR0915", # Too many statements 129 | "PLR0402", # `import torch.nn as nn` is 
#!/usr/bin/env bash
# Format, import-sort, lint, and type-check the cornserve Python package.

# Exit on first error (-e) and echo each command before running it (-v).
set -ev

echo ${BASH_SOURCE[0]}

# Always operate from the `python/` directory, wherever the script was invoked.
cd "$(dirname "${BASH_SOURCE[0]}")/.."

# Locally, apply formatting and import-sorting fixes in place.
# In CI (GITHUB_ACTION is set), only check and fail on violations.
if [[ -z $GITHUB_ACTION ]]; then
    ruff format --target-version py311 cornserve tests
    ruff check --fix-only --select I cornserve tests
else
    ruff format --target-version py311 --check cornserve tests
    ruff check --select I cornserve tests
fi

# Full lint and static type check on the package itself.
ruff check --target-version py311 cornserve
pyright cornserve
def device_from_rank(rank: int) -> torch.device:
    """Map a sidecar rank to a torch device.

    Ranks are assigned round-robin over the visible CUDA devices; when CUDA
    is unavailable, every rank maps to the CPU.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(f"cuda:{rank % torch.cuda.device_count()}")
def terminate_processes(processes: list[multiprocessing.Process]) -> None:
    """Shut down sidecar server processes, escalating force as needed.

    Every process first receives SIGINT for a graceful shutdown; processes
    that do not exit in time are terminated and, as a last resort, killed.
    """
    # Graceful phase: interrupt every process that was actually spawned.
    for proc in processes:
        if proc.pid:
            os.kill(proc.pid, signal.SIGINT)

    # Escalation phase: wait, then terminate(), then kill().
    for proc in processes:
        proc.join(timeout=5)
        if proc.is_alive():
            proc.terminate()
            proc.join(timeout=2)
        if proc.is_alive():
            proc.kill()
            proc.join()
def test_mllm_record():
    """Test MLLM task invocation recording."""
    # One image modality: the composite MLLM task is expected to record an
    # image-encoder invocation followed by an LLM invocation.
    task = MLLMTask(model_id="llava", modalities=[Modality.IMAGE])
    task_input = MLLMInput(prompt="Hello, world!", multimodal_data=[("image", "http://example.com/image.jpg")])

    ctx = TaskContext(task_id="mllm-test")
    task_context.set(ctx)
    # Inside record(), invoke() captures unit-task invocations on the context.
    with ctx.record():
        task_output = task.invoke(task_input)

    # A recorded invocation yields an empty placeholder response.
    assert isinstance(task_output.response, str)
    assert task_output.response == ""

    # Invocation 0 is the image encoder, invocation 1 the LLM.
    assert len(ctx.invocations) == 2
    assert ctx.invocations[0].task == task.image_encoder
    assert ctx.invocations[0].task_input.data_urls == ["http://example.com/image.jpg"]
    assert len(ctx.invocations[0].task_output.embeddings) == 1
    # The encoder output is a forwarded tensor.
    assert (
        ctx.invocations[0].task_output.embeddings[0].data_type
        == DataForward[Tensor]().data_type
        == ForwardableType.TENSOR
    )
    assert ctx.invocations[1].task_input.prompt == "Hello, world!"
    # The embedding the encoder produced must be the very one the LLM consumes.
    assert ctx.invocations[0].task_output.embeddings[0] == ctx.invocations[1].task_input.embeddings[0]
def test_serde_one():
    """A single unit task invocation must survive a JSON round trip unchanged."""
    original = TaskInvocation(
        task=LLMTask(model_id="llama"),
        task_input=LLMInput(prompt="Hello", multimodal_data=[]),
        task_output=LLMOutput(response="Hello"),
    )

    # Serialize to JSON and parse it back; equality means nothing was lost.
    round_tripped = TaskInvocation.model_validate_json(original.model_dump_json())
    assert round_tripped == original
def test_task_registry():
    """Tests whether the task registry is initialized correctly."""
    llm_task = TASK_REGISTRY.get("LLMTask")
    llm_forward_output_task = TASK_REGISTRY.get("LLMForwardOutputTask")
    encoder_task = TASK_REGISTRY.get("EncoderTask")

    # NOTE(review): imported after the lookups above — presumably to verify
    # the registry is populated before these modules are imported by the
    # test itself; confirm against registry initialization semantics.
    from cornserve.task.builtins.encoder import EncoderInput, EncoderOutput, EncoderTask
    from cornserve.task.builtins.llm import LLMForwardOutput, LLMForwardOutputTask, LLMInput, LLMOutput, LLMTask

    # Each entry is a (task class, input model, output model) triple.
    assert llm_task == (LLMTask, LLMInput, LLMOutput)
    assert llm_forward_output_task == (LLMForwardOutputTask, LLMInput, LLMForwardOutput)
    assert encoder_task == (EncoderTask, EncoderInput, EncoderOutput)

    # Unknown names must be absent and must raise on lookup.
    # Fix: the lookup previously used the typo "_NonEistentTask"; use the
    # same name as the membership check so both branches test one name.
    assert "_NonExistentTask" not in TASK_REGISTRY
    with pytest.raises(KeyError):
        TASK_REGISTRY.get("_NonExistentTask")
def test_mixed_modality():
    """Batches should only have a single modality."""
    scheduler = Scheduler()

    # Request 1: two images. Request 2: a video then an image.
    # Request 3: two videos.
    scheduler.enqueue(
        EngineEnqueueRequest(
            request_id="1",
            data=[
                ProcessedEmbeddingData(id="im1", modality=Modality.IMAGE, data={}),
                ProcessedEmbeddingData(id="im2", modality=Modality.IMAGE, data={}),
            ],
        )
    )
    scheduler.enqueue(
        EngineEnqueueRequest(
            request_id="2",
            data=[
                ProcessedEmbeddingData(id="vid1", modality=Modality.VIDEO, data={}),
                ProcessedEmbeddingData(id="im3", modality=Modality.IMAGE, data={}),
            ],
        )
    )
    scheduler.enqueue(
        EngineEnqueueRequest(
            request_id="3",
            data=[
                ProcessedEmbeddingData(id="vid2", modality=Modality.VIDEO, data={}),
                ProcessedEmbeddingData(id="vid3", modality=Modality.VIDEO, data={}),
            ],
        )
    )

    # Batch 1: request 1's two images go out together.
    assert scheduler.has_waiting_requests()
    batch = scheduler.schedule()
    assert batch.modality == Modality.IMAGE
    assert len(batch.request_ids) == 2
    scheduler.process_batch_result(request_ids=["1", "1"], data_ids=["im1", "im2"])

    # Batch 2: request 2's video is scheduled alone — its image must not
    # join a video batch.
    assert scheduler.has_waiting_requests()
    batch = scheduler.schedule()
    assert batch.modality == Modality.VIDEO
    assert len(batch.request_ids) == 1
    scheduler.process_batch_result(request_ids=["2"], data_ids=["vid1"])

    # Batch 3: request 2's leftover image.
    assert scheduler.has_waiting_requests()
    batch = scheduler.schedule()
    assert batch.modality == Modality.IMAGE
    assert len(batch.request_ids) == 1
    scheduler.process_batch_result(request_ids=["2"], data_ids=["im3"])

    # Batch 4: request 3's two videos go out together.
    assert scheduler.has_waiting_requests()
    batch = scheduler.schedule()
    assert batch.modality == Modality.VIDEO
    assert len(batch.request_ids) == 2
    scheduler.process_batch_result(request_ids=["3", "3"], data_ids=["vid2", "vid3"])

    # Everything enqueued has now been scheduled and completed.
    assert not scheduler.has_waiting_requests()
@pytest.fixture(scope="session")
def dump_tensors() -> Generator[str, None, None]:
    """Fixture to set the `CORNSERVE_TEST_DUMP_TENSOR_DIR` environment variable.

    When the user already exported the variable, that directory is used and
    left intact afterwards (they asked for the dumps to be kept). When it is
    unset, a throwaway temp directory is created for the session, exported,
    and removed once the session ends.
    """
    preset = os.getenv("CORNSERVE_TEST_DUMP_TENSOR_DIR")
    if preset is not None:
        yield preset
        return

    scratch = tempfile.mkdtemp()
    os.environ["CORNSERVE_TEST_DUMP_TENSOR_DIR"] = scratch
    yield scratch
    # Session teardown: undo the env var and delete the scratch directory.
    del os.environ["CORNSERVE_TEST_DUMP_TENSOR_DIR"]
    shutil.rmtree(scratch)
@param_tp_size
def test_image_inference(test_images: list[ModalityData], tp_size: int, dump_tensors: str) -> None:
    """Test if inference works correctly.

    Spins up a ModelExecutor at the given tensor-parallel degree, runs one
    batch built from the test images, and checks that execution succeeds.
    """
    executor = ModelExecutor(model_id=model_id, tp_size=tp_size, sender_sidecar_ranks=None)

    try:
        result = executor.execute_model(batch=batch_builder(model_id, model_shorthand, test_images))
        assert result.status == Status.SUCCESS
    finally:
        # Fix: always release the executor, even when the assertion fails,
        # so a failing test does not leak worker processes into later tests.
        executor.shutdown()
def test_weight_loading() -> None:
    """Check if weights are loaded correctly.

    Compares the parameters of Eric's Qwen2-VL vision encoder against the
    Hugging Face reference implementation's `.visual` module.
    """
    # Hugging Face model output
    hf_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto").visual

    # Load our model on CPU in a single-rank (world size 1) distributed
    # context — presumably so parameters stay unsharded for comparison;
    # confirm against init_distributed semantics.
    init_distributed(world_size=1, rank=0)
    our_model = load_model(model_id, torch_device=torch.device("cpu"))
    destroy_distributed()

    # Check if parameters are the same
    assert_same_weights(hf_model, our_model)
= executor.execute_model(batch=batch_builder(model_id, model_shorthand, test_images)) 45 | 46 | assert result.status == Status.SUCCESS 47 | 48 | executor.shutdown() 49 | 50 | 51 | @param_tp_size 52 | def test_video_inference(test_videos: list[ModalityData], tp_size: int, dump_tensors: str) -> None: 53 | """Test if inference works correctly.""" 54 | executor = ModelExecutor(model_id=model_id, tp_size=tp_size, sender_sidecar_ranks=None) 55 | 56 | result = executor.execute_model(batch=batch_builder(model_id, model_shorthand, test_videos[:2])) 57 | 58 | assert result.status == Status.SUCCESS 59 | 60 | executor.shutdown() 61 | 62 | 63 | @depends_on("test_image_inference", "test_video_inference") 64 | def test_hf_reference(test_images: list[ModalityData], test_videos: list[ModalityData], dump_tensors: str) -> None: 65 | """Generate reference outputs from the Hugging Face model.""" 66 | torch.set_grad_enabled(False) 67 | 68 | hf_model = Qwen2VLForConditionalGeneration.from_pretrained( 69 | model_id, 70 | torch_dtype="auto", 71 | attn_implementation="flash_attention_2", 72 | ) 73 | model = hf_model.model.cuda().eval() 74 | 75 | image1 = test_images[0].processed(model_id) 76 | pixel_values = torch.asarray(image1["pixel_values"]).cuda() 77 | image_grid_thw = torch.asarray(image1["image_grid_thw"]).cuda() 78 | output1 = model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw).cpu() 79 | 80 | image2 = test_images[1].processed(model_id) 81 | pixel_values = torch.asarray(image2["pixel_values"]).cuda() 82 | image_grid_thw = torch.asarray(image2["image_grid_thw"]).cuda() 83 | output2 = model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw).cpu() 84 | 85 | for tp_degree in TP_SIZES: 86 | output = torch.load(f"{dump_tensors}/{model_shorthand}-image-tp{tp_degree}.pt") 87 | assert_similar([output1, output2], output) 88 | 89 | del output1, output2 90 | 91 | video1 = test_videos[0].processed(model_id) 92 | pixel_values_video = 
#!/usr/bin/env bash
# Generate Python gRPC/protobuf bindings from the proto/v1 definitions.
#
# BUGFIX: the shebang previously read `#!/usr/local/env bash`; `env` lives in
# /usr/bin on virtually all systems, so the script failed with "bad
# interpreter" when executed directly.

set -evo pipefail

PROTO_DIR="proto/v1"
PROTO_FILES=$(find $PROTO_DIR -name "*.proto")

PYTHON_OUTPUT_DIR="python/cornserve/services/pb"
mkdir -p "$PYTHON_OUTPUT_DIR"
# Remove stale generated files; `|| true` keeps the script going on first run
# when no generated files exist yet.
rm "$PYTHON_OUTPUT_DIR"/*pb2.py "$PYTHON_OUTPUT_DIR"/*pb2.pyi "$PYTHON_OUTPUT_DIR"/*pb2_grpc.py || true
for proto in $PROTO_FILES; do
    echo "Generating code for $proto in $PYTHON_OUTPUT_DIR"
    python -m grpc_tools.protoc \
        -I$PROTO_DIR \
        --python_out=$PYTHON_OUTPUT_DIR \
        --grpc_python_out=$PYTHON_OUTPUT_DIR \
        --pyi_out=$PYTHON_OUTPUT_DIR \
        $proto
done

# The generated `import common_pb2`, for example, doesn't work.
# We need to change it manually to `from . import common_pb2`.
# `-i.bak` is done for GNU and BSD sed compatibility.
find "$PYTHON_OUTPUT_DIR" -type f -name "*.py" -exec sed -i.bak -e 's/^\(import.*pb2\)/from . \1/g' {} \;
find "$PYTHON_OUTPUT_DIR" -type f -name "*.bak" -exec rm {} \;