├── .python-version ├── .dockerignore ├── .gitattributes ├── .gitignore ├── requirements-dev.txt ├── prompt_templates.py ├── .pylintrc ├── test.sh ├── pyproject.toml ├── cog.yaml ├── requirements.txt ├── tests ├── end_to_end │ └── local │ │ └── test_predict.py └── unit │ └── test_predict.py ├── .github └── workflows │ ├── lint-and-test.yml │ └── build-and-push.yml ├── requirements.lock ├── requirements-dev.lock ├── utils.py ├── README.md ├── train.py └── predict.py /.python-version: -------------------------------------------------------------------------------- 1 | 3.11.9 2 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | *.tar 2 | models/** -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | requirements* linguist-generated=true 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | model_path 2 | .cog 3 | __pycache__ 4 | 5 | models/** 6 | hf-cache/ 7 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest-asyncio 2 | pytest-mock 3 | pylint 4 | black 5 | cog==0.10.0a18 6 | attrs>=20.1,<24 7 | jinja2 -------------------------------------------------------------------------------- /prompt_templates.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-module-docstring, line-too-long 2 | COMPLETION = "{prompt}" 3 | LLAMA_3_INSTRUCT = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 4 | LLAMA_2_INSTRUCT = "[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{prompt} [/INST]" 5 | MISTRAL_INSTRUCT = "[INST] {system_prompt} {prompt} [/INST]" 6 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | disable=C0114, # Missing module docstring 3 | E0611, # No name in module 4 | W0201, # Attribute defined outside init 5 | C0116, # Missing function docstring 6 | W0621, # Redefined outer name 7 | E0401, # Import error 8 | R0903, # Too few public methods 9 | R0913, # Too many arguments 10 | R0914, # Too many local variables 11 | C0115, # Missing class docstring 12 | 13 | [REPORTS] 14 | output-format=colorized -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Mirror of https://huggingface.co/EleutherAI/pythia-70m 4 | COG_WEIGHTS="https://replicate.delivery/czjl/HUTgHv0M6FbnJxzkbe7Ly1fN19tabwYOZTFLuJld3f7MifpLB/model.tar" 5 | 6 | exec cog predict \ 7 | -e COG_WEIGHTS=$COG_WEIGHTS \ 8 | -i prompt="write a python program that prints Hello World!"
\ 9 | -i max_new_tokens=512 \ 10 | -i temperature=0.6 \ 11 | -i top_p=0.9 \ 12 | -i top_k=50 \ 13 | -i presence_penalty=0.0 \ 14 | -i frequency_penalty=0.0 \ 15 | -i prompt_template="[INST] {prompt} [/INST] " 16 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "cog-vllm" 3 | version = "0.1.0" 4 | description = "" 5 | authors = [] 6 | dependencies = [ 7 | "cog==0.10.0a10", 8 | "aiohttp[speedups]>=3.9.5", 9 | "scipy>=1.13.1", 10 | "sentencepiece>=0.2.0", 11 | "protobuf>=5.27.0", 12 | "huggingface-hub>=0.23.2", 13 | "httpx>=0.27.0", 14 | "tqdm>=4.66.4", 15 | "torch>=2.3.0", 16 | "jinja2>=3.1.4", 17 | ] 18 | readme = "README.md" 19 | requires-python = "== 3.11.9" 20 | 21 | [build-system] 22 | requires = ["hatchling"] 23 | build-backend = "hatchling.build" 24 | 25 | [tool.rye] 26 | managed = true 27 | dev-dependencies = [] 28 | 29 | [tool.hatch.metadata] 30 | allow-direct-references = true 31 | 32 | [tool.hatch.build.targets.wheel] 33 | packages = ["src/cog_vllm"] 34 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | cuda: "12.1" 7 | 8 | python_version: "3.11.9" 9 | 10 | python_requirements: requirements.txt 11 | 12 | run: 13 | - --mount=type=cache,target=/root/.cache/pip TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" CUDA_HOME=/usr/local/cuda pip install --ignore-installed vllm==0.5.3.post1 14 | - --mount=type=cache,target=/root/.cache/pip pip install cog==0.10.0a18 15 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 16 | - sed -i "s/from vllm.model_executor.layers.quantization.schema import QuantParamSchema/# from vllm.model_executor.layers.quantization.schema import QuantParamSchema/" /root/.pyenv/versions/3.11.9/lib/python3.11/site-packages/vllm/model_executor/model_loader/weight_utils.py 17 | 18 | predict: "predict.py:Predictor" 19 | train: "train.py:train" 20 | 21 | concurrency: 22 | max: 32 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile pyproject.toml -o requirements.txt 3 | aiodns==3.2.0 4 | # via aiohttp 5 | aiohttp==3.9.5 6 | aiosignal==1.3.1 7 | # via aiohttp 8 | anyio==4.4.0 9 | # via httpx 10 | attrs==23.2.0 11 | # via aiohttp 12 | brotli==1.1.0 13 | # via aiohttp 14 | certifi==2024.6.2 15 | # via 16 | # httpcore 17 | # httpx 18 | # requests 19 | cffi==1.16.0 20 | # via pycares 21 | charset-normalizer==3.3.2 22 | # via requests 23 | filelock==3.14.0 24 | # via huggingface-hub 25 | frozenlist==1.4.1 26 | # via 27 | # aiohttp 28 | # aiosignal 29 | fsspec==2024.5.0 30 | # via huggingface-hub 31 | h11==0.14.0 32 | # via httpcore 33 | httpcore==1.0.5 34 | # via httpx 35 | httpx==0.27.0 36 | huggingface-hub==0.23.2 37 | idna==3.7 38 | # via 39 | # anyio 40 | # httpx 41 | # requests 42 | # yarl 43 | multidict==6.0.5 44 | # via 45 | # aiohttp 46 | # yarl 47 | numpy==1.26.4 48 | # via scipy 49 | packaging==24.0 50 | # via huggingface-hub 51 | protobuf==5.27.0 52 | 
pycares==4.4.0 53 | # via aiodns 54 | pycparser==2.22 55 | # via cffi 56 | pyyaml==6.0.1 57 | # via huggingface-hub 58 | requests==2.32.3 59 | # via huggingface-hub 60 | scipy==1.13.1 61 | sentencepiece==0.2.0 62 | sniffio==1.3.1 63 | # via 64 | # anyio 65 | # httpx 66 | tqdm==4.66.4 67 | # via huggingface-hub 68 | typing-extensions==4.12.1 69 | # via huggingface-hub 70 | urllib3==2.2.1 71 | # via requests 72 | yarl==1.9.4 73 | # via aiohttp 74 | -------------------------------------------------------------------------------- /tests/end_to_end/local/test_predict.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import json 3 | import os 4 | import shutil 5 | 6 | 7 | def test_predict(): 8 | # Mirror of https://huggingface.co/EleutherAI/pythia-70m 9 | # pylint: disable=line-too-long 10 | predictor_config = { 11 | "engine_args": {"enforce_eager": True}, 12 | } 13 | 14 | config_filename = "predictor_config.json" 15 | backup_filename = "predictor_config.json.bak" 16 | 17 | if os.path.exists(config_filename): 18 | shutil.move(config_filename, backup_filename) 19 | 20 | try: 21 | 22 | with open(config_filename, "w", encoding="utf-8") as temp_config: 23 | json.dump(predictor_config, temp_config, indent=4) 24 | weights_url = "https://weights.replicate.delivery/default/internal-testing/EleutherAI/pythia-70m/model.tar" # pylint: disable=line-too-long 25 | 26 | result = subprocess.run( 27 | [ 28 | "cog", 29 | "predict", 30 | "-e", 31 | f"COG_WEIGHTS={weights_url}", 32 | "-i", 33 | "prompt=Hello!", 34 | "-i", 35 | "max_tokens=10", 36 | ], 37 | capture_output=True, 38 | text=True, 39 | check=True, 40 | ) 41 | 42 | finally: 43 | os.remove(config_filename) 44 | if os.path.exists(backup_filename): 45 | shutil.move(backup_filename, config_filename) 46 | 47 | # Check that the cog predict command completed successfully 48 | assert result.returncode == 0, f"Cog predict failed with error: {result.stderr}" 49 | 50 | # Parse the output 51 | output = result.stdout.strip().splitlines() 52 | 53 | # Make assertions based on the expected output 54 | assert isinstance(output, list), "Output is not a list of strings" 55 | assert len(output) > 0, "Output list is empty" 56 | for line in output: 57 | assert isinstance(line, str), f"Output contains a non-string element: {line}" 58 | 59 | # Optionally print the output for debugging 60 | print("Output from cog predict:", output) 61 | -------------------------------------------------------------------------------- /.github/workflows/lint-and-test.yml: -------------------------------------------------------------------------------- 1 | name: Lint and Test 2 | 3 | 4 | on: 5 | pull_request: 6 | workflow_dispatch: 7 | 8 | # Ensure only one workflow instance runs at a time. For branches other than the 9 | # default branch, cancel the pending jobs in the group. For the default branch, 10 | # queue them up. This avoids cancelling jobs that are in the middle of deploying 11 | # to production. 
12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.ref }} 14 | cancel-in-progress: ${{ github.ref != format('refs/heads/{0}', github.event.repository.default_branch) }} 15 | 16 | jobs: 17 | lint: 18 | name: Lint 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: '3.10' 26 | - name: Create and activate virtual environment 27 | run: | 28 | python -m venv venv 29 | source venv/bin/activate 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install -r requirements-dev.txt 34 | - name: Run pylint 35 | run: | 36 | pylint --recursive y tests/**/*.py 37 | pylint --recursive y ./*.py 38 | 39 | unit-tests: 40 | name: Unit tests 41 | runs-on: ubuntu-latest 42 | steps: 43 | - uses: actions/checkout@v3 44 | - name: Set up Python 45 | uses: actions/setup-python@v4 46 | with: 47 | python-version: '3.10' 48 | - name: Create and activate virtual environment 49 | run: | 50 | python -m venv venv 51 | source venv/bin/activate 52 | - name: Install dependencies 53 | run: | 54 | python -m pip install --upgrade pip 55 | pip install -r requirements-dev.txt 56 | - name: Debug information 57 | run: | 58 | which python 59 | python --version 60 | pip list 61 | python -c "import sys; print(sys.path)" 62 | python -c "import attrs; print(attrs.__file__)" 63 | - name: Run unit tests 64 | run: pytest tests/unit -------------------------------------------------------------------------------- /requirements.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | # with-sources: false 9 | # generate-hashes: false 10 | 11 | -e file:. 
12 | aiodns==3.2.0 13 | # via aiohttp 14 | aiohttp==3.9.5 15 | # via cog-vllm 16 | aiosignal==1.3.1 17 | # via aiohttp 18 | anyio==4.4.0 19 | # via httpx 20 | # via starlette 21 | # via watchfiles 22 | attrs==23.2.0 23 | # via aiohttp 24 | # via cog 25 | brotli==1.1.0 26 | # via aiohttp 27 | certifi==2024.6.2 28 | # via httpcore 29 | # via httpx 30 | # via requests 31 | cffi==1.16.0 32 | # via pycares 33 | charset-normalizer==3.3.2 34 | # via requests 35 | click==8.1.7 36 | # via uvicorn 37 | cog==0.10.0a10 38 | # via cog-vllm 39 | fastapi==0.98.0 40 | # via cog 41 | filelock==3.14.0 42 | # via huggingface-hub 43 | # via torch 44 | frozenlist==1.4.1 45 | # via aiohttp 46 | # via aiosignal 47 | fsspec==2024.5.0 48 | # via huggingface-hub 49 | # via torch 50 | h11==0.14.0 51 | # via httpcore 52 | # via uvicorn 53 | h2==4.1.0 54 | # via httpx 55 | hpack==4.0.0 56 | # via h2 57 | httpcore==1.0.5 58 | # via httpx 59 | httptools==0.6.1 60 | # via uvicorn 61 | httpx==0.27.0 62 | # via cog 63 | # via cog-vllm 64 | huggingface-hub==0.23.2 65 | # via cog-vllm 66 | hyperframe==6.0.1 67 | # via h2 68 | idna==3.7 69 | # via anyio 70 | # via httpx 71 | # via requests 72 | # via yarl 73 | jinja2==3.1.4 74 | # via cog-vllm 75 | # via torch 76 | markupsafe==2.1.5 77 | # via jinja2 78 | mpmath==1.3.0 79 | # via sympy 80 | multidict==6.0.5 81 | # via aiohttp 82 | # via yarl 83 | networkx==3.3 84 | # via torch 85 | numpy==1.26.4 86 | # via scipy 87 | packaging==24.0 88 | # via huggingface-hub 89 | protobuf==5.27.0 90 | # via cog-vllm 91 | pycares==4.4.0 92 | # via aiodns 93 | pycparser==2.22 94 | # via cffi 95 | pydantic==1.10.15 96 | # via cog 97 | # via fastapi 98 | python-dotenv==1.0.1 99 | # via uvicorn 100 | pyyaml==6.0.1 101 | # via cog 102 | # via huggingface-hub 103 | # via uvicorn 104 | requests==2.32.3 105 | # via cog 106 | # via huggingface-hub 107 | scipy==1.13.1 108 | # via cog-vllm 109 | sentencepiece==0.2.0 110 | # via cog-vllm 111 | sniffio==1.3.1 112 | # via anyio 113 | # via httpx 114 | starlette==0.27.0 115 | # via fastapi 116 | structlog==24.2.0 117 | # via cog 118 | sympy==1.12.1 119 | # via torch 120 | torch==2.3.0 121 | # via cog-vllm 122 | tqdm==4.66.4 123 | # via cog-vllm 124 | # via huggingface-hub 125 | typing-extensions==4.12.1 126 | # via cog 127 | # via huggingface-hub 128 | # via pydantic 129 | # via torch 130 | urllib3==2.2.1 131 | # via requests 132 | uvicorn==0.30.1 133 | # via cog 134 | uvloop==0.19.0 135 | # via uvicorn 136 | watchfiles==0.22.0 137 | # via uvicorn 138 | websockets==12.0 139 | # via uvicorn 140 | yarl==1.9.4 141 | # via aiohttp 142 | -------------------------------------------------------------------------------- /requirements-dev.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | # with-sources: false 9 | # generate-hashes: false 10 | 11 | -e file:. 
12 | aiodns==3.2.0 13 | # via aiohttp 14 | aiohttp==3.9.5 15 | # via cog-vllm 16 | aiosignal==1.3.1 17 | # via aiohttp 18 | anyio==4.4.0 19 | # via httpx 20 | # via starlette 21 | # via watchfiles 22 | attrs==23.2.0 23 | # via aiohttp 24 | # via cog 25 | brotli==1.1.0 26 | # via aiohttp 27 | certifi==2024.6.2 28 | # via httpcore 29 | # via httpx 30 | # via requests 31 | cffi==1.16.0 32 | # via pycares 33 | charset-normalizer==3.3.2 34 | # via requests 35 | click==8.1.7 36 | # via uvicorn 37 | cog==0.10.0a10 38 | # via cog-vllm 39 | fastapi==0.98.0 40 | # via cog 41 | filelock==3.14.0 42 | # via huggingface-hub 43 | # via torch 44 | frozenlist==1.4.1 45 | # via aiohttp 46 | # via aiosignal 47 | fsspec==2024.5.0 48 | # via huggingface-hub 49 | # via torch 50 | h11==0.14.0 51 | # via httpcore 52 | # via uvicorn 53 | h2==4.1.0 54 | # via httpx 55 | hpack==4.0.0 56 | # via h2 57 | httpcore==1.0.5 58 | # via httpx 59 | httptools==0.6.1 60 | # via uvicorn 61 | httpx==0.27.0 62 | # via cog 63 | # via cog-vllm 64 | huggingface-hub==0.23.2 65 | # via cog-vllm 66 | hyperframe==6.0.1 67 | # via h2 68 | idna==3.7 69 | # via anyio 70 | # via httpx 71 | # via requests 72 | # via yarl 73 | jinja2==3.1.4 74 | # via cog-vllm 75 | # via torch 76 | markupsafe==2.1.5 77 | # via jinja2 78 | mpmath==1.3.0 79 | # via sympy 80 | multidict==6.0.5 81 | # via aiohttp 82 | # via yarl 83 | networkx==3.3 84 | # via torch 85 | numpy==1.26.4 86 | # via scipy 87 | packaging==24.0 88 | # via huggingface-hub 89 | protobuf==5.27.0 90 | # via cog-vllm 91 | pycares==4.4.0 92 | # via aiodns 93 | pycparser==2.22 94 | # via cffi 95 | pydantic==1.10.15 96 | # via cog 97 | # via fastapi 98 | python-dotenv==1.0.1 99 | # via uvicorn 100 | pyyaml==6.0.1 101 | # via cog 102 | # via huggingface-hub 103 | # via uvicorn 104 | requests==2.32.3 105 | # via cog 106 | # via huggingface-hub 107 | scipy==1.13.1 108 | # via cog-vllm 109 | sentencepiece==0.2.0 110 | # via cog-vllm 111 | sniffio==1.3.1 112 | # via anyio 113 | # via httpx 114 | starlette==0.27.0 115 | # via fastapi 116 | structlog==24.2.0 117 | # via cog 118 | sympy==1.12.1 119 | # via torch 120 | torch==2.3.0 121 | # via cog-vllm 122 | tqdm==4.66.4 123 | # via cog-vllm 124 | # via huggingface-hub 125 | typing-extensions==4.12.1 126 | # via cog 127 | # via huggingface-hub 128 | # via pydantic 129 | # via torch 130 | urllib3==2.2.1 131 | # via requests 132 | uvicorn==0.30.1 133 | # via cog 134 | uvloop==0.19.0 135 | # via uvicorn 136 | watchfiles==0.22.0 137 | # via uvicorn 138 | websockets==12.0 139 | # via uvicorn 140 | yarl==1.9.4 141 | # via aiohttp 142 | -------------------------------------------------------------------------------- /.github/workflows/build-and-push.yml: -------------------------------------------------------------------------------- 1 | name: Build and Push to Replicate 2 | 3 | 4 | on: 5 | workflow_dispatch: 6 | inputs: 7 | git_branch: 8 | description: 'Enter the git branch name to check out and push' 9 | required: true 10 | default: 'main' 11 | 12 | # Ensure only one workflow instance runs at a time. For branches other than the 13 | # default branch, cancel the pending jobs in the group. For the default branch, 14 | # queue them up. This avoids cancelling jobs that are in the middle of deploying 15 | # to production. 
16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.ref }} 18 | cancel-in-progress: ${{ github.ref != format('refs/heads/{0}', github.event.repository.default_branch) }} 19 | 20 | jobs: 21 | lint: 22 | name: Lint 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v3 26 | - name: Set up Python 27 | uses: actions/setup-python@v4 28 | with: 29 | python-version: '3.10' 30 | - name: Create and activate virtual environment 31 | run: | 32 | python -m venv venv 33 | source venv/bin/activate 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install -r requirements-dev.txt 38 | - name: Run pylint 39 | run: | 40 | pylint --recursive y tests/**/*.py 41 | pylint --recursive y ./*.py 42 | 43 | unit-tests: 44 | name: Unit tests 45 | runs-on: ubuntu-latest 46 | steps: 47 | - uses: actions/checkout@v3 48 | - name: Set up Python 49 | uses: actions/setup-python@v4 50 | with: 51 | python-version: '3.10' 52 | - name: Create and activate virtual environment 53 | run: | 54 | python -m venv venv 55 | source venv/bin/activate 56 | - name: Install dependencies 57 | run: | 58 | python -m pip install --upgrade pip 59 | pip install -r requirements-dev.txt 60 | - name: Run unit tests 61 | run: pytest tests/unit 62 | 63 | build-and-push: 64 | name: Build and push 65 | needs: [lint, unit-tests] 66 | permissions: 67 | contents: 'read' 68 | id-token: 'write' 69 | runs-on: ubuntu-latest-16-cores 70 | steps: 71 | - name: Checkout 72 | uses: actions/checkout@v4 73 | with: 74 | ref: ${{ inputs.git_branch }} 75 | 76 | - name: Setup Cog 77 | uses: replicate/setup-cog@v2 78 | with: 79 | token: ${{ secrets.REPLICATE_API_TOKEN }} 80 | 81 | - name: Install Cog 82 | run: | 83 | COG_URL="https://github.com/replicate/cog/releases/download/v0.10.0-alpha20/cog_$(uname -s)_$(uname -m)" 84 | sudo curl -o /usr/local/bin/cog -L "$COG_URL" 85 | sudo chmod +x /usr/local/bin/cog 86 | 87 | - name: Push to Replicate 88 | run: | 89 | cog push r8.im/replicate/vllm 90 | 91 | - name: Setup Cog 92 | uses: replicate/setup-cog@v2 93 | with: 94 | token: ${{ secrets.REPLICATE_API_TOKEN_ORG_REPLICATE_INTERNAL }} 95 | 96 | - name: Install Cog 97 | run: | 98 | COG_URL="https://github.com/replicate/cog/releases/download/v0.10.0-alpha20/cog_$(uname -s)_$(uname -m)" 99 | sudo curl -o /usr/local/bin/cog -L "$COG_URL" 100 | sudo chmod +x /usr/local/bin/cog 101 | 102 | - name: Push to replicate-internal 103 | run: | 104 | cog push r8.im/replicate-internal/vllm 105 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | import warnings 5 | from urllib.parse import urlparse 6 | from pathlib import Path 7 | import asyncio 8 | import shutil 9 | 10 | 11 | async def resolve_model_path(url_or_local_path: str) -> str: 12 | """ 13 | Resolves the model path, downloading if necessary. 14 | 15 | Args: 16 | url_or_local_path (str): URL to the tarball or local path to a 17 | directory containing the model artifacts. 18 | 19 | Returns: 20 | str: Path to the directory containing the model artifacts. 
21 | """ 22 | 23 | parsed_url = urlparse(url_or_local_path) 24 | if parsed_url.scheme in ["http", "https"]: 25 | return await download_tarball(url_or_local_path) 26 | 27 | if parsed_url.scheme in ["file", ""]: 28 | if not os.path.exists(parsed_url.path): 29 | raise ValueError( 30 | f"E1000: The provided local path '{parsed_url.path}' does not exist." 31 | ) 32 | if not os.listdir(parsed_url.path): 33 | raise ValueError( 34 | f"E1000: The provided local path '{parsed_url.path}' is empty." 35 | ) 36 | 37 | warnings.warn( 38 | "Using local model artifacts for development is okay, but not optimal for production. " 39 | "To minimize boot time, store model assets externally on Replicate." 40 | ) 41 | return url_or_local_path 42 | raise ValueError(f"E1000: Unsupported model path scheme: {parsed_url.scheme}") 43 | 44 | 45 | async def maybe_download_tarball_with_pget( 46 | url: str, 47 | dest: str, 48 | ): 49 | """ 50 | Checks for existing model weights in a local volume, downloads if necessary, 51 | and sets up symlinks. 52 | 53 | This function first checks if a local volume (/weights) exists and can be used. If so, it uses 54 | this as the primary destination. If the weights already exist in the local volume or the 55 | specified destination, no download occurs. Otherwise, it downloads the tarball from the 56 | provided URL using pget and extracts it. 57 | 58 | Args: 59 | url (str): URL to the model tarball. 60 | dest (str): Intended destination path for the model weights. 61 | 62 | Returns: 63 | str: Path to the directory containing the model weights, which may be either 64 | the original destination or a symlink to the local volume. 65 | 66 | Note: 67 | - If weights are in the local volume, a symlink is created to `dest`. 68 | - If weights are already present in either location, no download occurs. 69 | - The function prioritizes using the local volume (/weights) if available. 70 | """ 71 | try: 72 | Path("/weights").mkdir(exist_ok=True) 73 | first_dest = "/weights/vllm" 74 | except PermissionError: 75 | print("/weights doesn't exist, and we couldn't create it") 76 | first_dest = dest 77 | 78 | # if dest exists and is not empty, return 79 | if os.path.exists(first_dest) and os.listdir(first_dest): 80 | print(f"Files already present in `{first_dest}`, nothing will be downloaded.") 81 | if first_dest != dest: 82 | try: 83 | if os.path.islink(dest): 84 | os.unlink(dest) 85 | os.symlink(first_dest, dest) 86 | except FileExistsError: 87 | print(f"Ignoring existing file at {dest}") 88 | return dest 89 | 90 | # if dest exists but is empty, remove it so we can pull with pget 91 | if os.path.exists(first_dest): 92 | shutil.rmtree(first_dest) 93 | 94 | print("Downloading model assets...") 95 | start_time = time.time() 96 | command = ["pget", url, first_dest, "-x"] 97 | # subprocess.check_call(command, close_fds=True) 98 | 99 | process = await asyncio.create_subprocess_exec(*command, close_fds=True) 100 | await process.wait() 101 | if process.returncode != 0: 102 | raise subprocess.CalledProcessError(process.returncode, command) 103 | 104 | print(f"Downloaded model assets in {time.time() - start_time:.2f}s") 105 | if first_dest != dest: 106 | if os.path.islink(dest): 107 | os.unlink(dest) 108 | os.symlink(first_dest, dest) 109 | 110 | return dest 111 | 112 | 113 | async def download_tarball(url: str) -> str: 114 | """ 115 | Downloads a tarball from a URL and extracts it. 116 | 117 | Args: 118 | url (str): URL to the tarball. 
119 | 120 | Returns: 121 | str: Path to the directory where the tarball was extracted. 122 | """ 123 | filename = os.path.splitext(os.path.basename(url))[0] 124 | dest = os.path.join(os.getcwd(), filename) 125 | return await maybe_download_tarball_with_pget(url, dest) 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cog-vLLM: Run vLLM on Replicate 2 | 3 | [Cog](https://github.com/replicate/cog) 4 | is an open-source tool that lets you package machine learning models 5 | in a standard, production-ready container. 6 | vLLM is a fast and easy-to-use library for LLM inference and serving. 7 | 8 | You can deploy your packaged model to your own infrastructure, 9 | or to [Replicate]. 10 | 11 | ## Highlights 12 | 13 | * 🚀 **Run vLLM in the cloud with an API**. 14 | Deploy any [vLLM-supported language model] at scale on Replicate. 15 | 16 | * 🏭 **Support multiple concurrent requests**. 17 | Continuous batching works out of the box. 18 | 19 | * 🐢 **Open Source, all the way down**. 20 | Look inside, take it apart, make it do exactly what you need. 21 | 22 | ## Quickstart 23 | 24 | Go to [replicate.com/replicate/vllm](https://replicate.com/replicate/vllm) 25 | and create a new vLLM model from a [supported Hugging Face repo][vLLM-supported language model], 26 | such as [google/gemma-2b](https://huggingface.co/google/gemma-2b). 27 | 28 | > [!IMPORTANT] 29 | > Gated models require a [Hugging Face API token](https://huggingface.co/settings/tokens), 30 | > which you can set in the `hf_token` field of the model creation form. 31 | 32 | Create a new vLLM model on Replicate 33 | 34 | Replicate downloads the model files, packages them into a `.tar` archive, 35 | and pushes a new version of your model that's ready to use. 36 | 37 | Trained vLLM model on Replicate 38 | 39 | From here, you can either use your model as-is, 40 | or customize it and push up your changes. 41 | 42 | ## Local Development 43 | 44 | If you're on a machine or VM with a GPU, 45 | you can try out changes before pushing them to Replicate. 46 | 47 | Start by [installing or upgrading Cog](https://cog.run/#install). 48 | You'll need Cog [v0.10.0-alpha11](https://github.com/replicate/cog/releases/tag/v0.10.0-alpha11): 49 | 50 | ```console 51 | $ sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/download/v0.10.0-alpha11/cog_$(uname -s)_$(uname -m)" 52 | $ sudo chmod +x /usr/local/bin/cog 53 | ``` 54 | 55 | Then clone this repository: 56 | 57 | ```console 58 | $ git clone https://github.com/replicate/cog-vllm 59 | $ cd cog-vllm 60 | ``` 61 | 62 | Go to the [Replicate dashboard](https://replicate.com/trainings) and 63 | navigate to the training for your vLLM model. 64 | From that page, copy the weights URL from the Download weights button. 65 | 66 | Copy weights URL from Replicate training 67 | 68 | Set the `COG_WEIGHTS` environment variable with that copied value: 69 | 70 | ```console 71 | $ export COG_WEIGHTS="..." 72 | ``` 73 | 74 | Now, make your first prediction against the model locally: 75 | 76 | ```console 77 | $ cog predict -e "COG_WEIGHTS=$COG_WEIGHTS" \ 78 | -i prompt="Hello!" 79 | ``` 80 | 81 | The first time you run this command, 82 | Cog downloads the model weights and saves them to the `models` subdirectory. 83 | 84 | To make multiple predictions, 85 | start up the HTTP server and send it `POST /predictions` requests.
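The console block below starts the server and sends a single request with `curl`. As a rough sketch of driving the same endpoint from Python (and of exercising the concurrent-request support highlighted above), something like the following should work. It assumes the server is listening on port 5000, as in that example, and uses `httpx`, which this project already depends on:

```python
# Sketch: send several prediction requests concurrently to the local Cog server.
# Assumes the server from the console block below is running on port 5000.
import asyncio
import httpx

async def predict(client: httpx.AsyncClient, prompt: str) -> str:
    resp = await client.post(
        "http://localhost:5000/predictions",
        json={"input": {"prompt": prompt, "max_tokens": 32}},
    )
    resp.raise_for_status()
    output = resp.json()["output"]  # may be a list of streamed text chunks
    return "".join(output) if isinstance(output, list) else str(output)

async def main() -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        prompts = ["Hello!", "Tell me a joke.", "What is vLLM?"]
        results = await asyncio.gather(*(predict(client, p) for p in prompts))
        for prompt, result in zip(prompts, results):
            print(f"{prompt!r} -> {result!r}")

asyncio.run(main())
```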
86 | 87 | ```console 88 | # Start the HTTP server 89 | $ cog run -p 5000 -e "COG_WEIGHTS=$COG_WEIGHTS" python -m cog.server.http 90 | 91 | # In a different terminal session, send requests to the server 92 | $ curl http://localhost:5000/predictions -X POST \ 93 | -H 'Content-Type: application/json' \ 94 | -d '{"input": {"prompt": "Hello!"}}' 95 | ``` 96 | 97 | When you're finished working, 98 | you can push your changes to Replicate. 99 | 100 | Grab your token from [replicate.com/account](https://replicate.com/account) 101 | and set it as an environment variable: 102 | 103 | ```shell 104 | export REPLICATE_API_TOKEN= 105 | ``` 106 | 107 | ```console 108 | $ echo $REPLICATE_API_TOKEN | cog login --token-stdin 109 | $ cog push r8.im/<your-username>/<your-model-name> 110 | --> ... 111 | --> Pushing image 'r8.im/...' 112 | ``` 113 | 114 | After you push your model, you can try running it on Replicate. 115 | 116 | Install the [Replicate Python SDK][replicate-python]: 117 | 118 | ```console 119 | $ pip install replicate 120 | ``` 121 | 122 | Create a prediction and stream its output: 123 | 124 | ```python 125 | import replicate 126 | 127 | model = replicate.models.get("<your-username>/<your-model-name>") 128 | prediction = replicate.predictions.create( 129 | version=model.latest_version, 130 | input={ "prompt": "Hello" }, 131 | stream=True 132 | ) 133 | 134 | for event in prediction.stream(): 135 | print(str(event), end="") 136 | ``` 137 | 138 | [Replicate]: https://replicate.com 139 | [vLLM-supported language model]: https://docs.vllm.ai/en/latest/models/supported_models.html 140 | [replicate-python]: https://github.com/replicate/replicate-python 141 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import tarfile 4 | import time 5 | from collections import namedtuple 6 | from dataclasses import asdict 7 | 8 | import httpx 9 | import tqdm 10 | from cog import BaseModel, Input, Path, Secret 11 | from huggingface_hub import ( 12 | HfApi, 13 | get_hf_file_metadata, 14 | hf_hub_url, 15 | ) 16 | from huggingface_hub._login import _login as hf_login 17 | from huggingface_hub.utils import filter_repo_objects 18 | 19 | from predict import PredictorConfig 20 | 21 | 22 | class TrainingOutput(BaseModel): 23 | weights: Path 24 | 25 | 26 | def train( 27 | hf_model_id: str = Input( 28 | description=""" 29 | Hugging Face model identifier 30 | (e.g. NousResearch/Hermes-2-Theta-Llama-3-8B). 31 | """, 32 | ), 33 | hf_model_sha: str = Input( 34 | description=""" 35 | The version of the model. 36 | If unspecified, the latest version is used. 37 | """, 38 | default=None, 39 | ), 40 | hf_token: Secret = Input( 41 | description=""" 42 | Hugging Face API token. 43 | Get your token at https://huggingface.co/settings/tokens 44 | """, 45 | default=None, 46 | ), 47 | allow_patterns: str = Input( 48 | description=""" 49 | Patterns constituting the allowlist. 50 | If provided, item paths must match at least one pattern from the allowlist. 51 | (e.g. "*.safetensors"). 52 | """, 53 | default=None, 54 | ), 55 | ignore_patterns: str = Input( 56 | description=""" 57 | Patterns constituting the denylist. 58 | If provided, item paths must not match any patterns from the denylist. 59 | (e.g. "*.gguf"). 60 | """, 61 | default="*.gguf", 62 | ), 63 | prompt_template: str = Input( 64 | description=""" 65 | Prompt template. This is a Jinja2 template that overrides the 66 | HuggingFace tokenizer configuration.
If this is set to None and nothing 67 | is configured on HuggingFace, no formatting is applied. 68 | To override HuggingFace configuration, set it to the string 69 | `{{messages[0]['content']}}`.""", 70 | default=None, 71 | ), 72 | ) -> TrainingOutput: 73 | if hf_token is not None and isinstance(hf_token, Secret): 74 | print("Logging in to Hugging Face Hub...") 75 | hf_token = hf_token.get_secret_value().strip() 76 | hf_login(token=hf_token, add_to_git_credential=False) 77 | else: 78 | print("No HuggingFace token provided.") 79 | 80 | api = HfApi() 81 | 82 | # Fetch the model info 83 | model = api.model_info( 84 | hf_model_id, revision=hf_model_sha, token=hf_token, files_metadata=True 85 | ) 86 | print(f"Using model {model.id} with SHA {model.sha}") 87 | 88 | # Determine which files to download 89 | files = list( 90 | filter_repo_objects( 91 | items=[f.rfilename for f in model.siblings], 92 | allow_patterns=allow_patterns, 93 | ignore_patterns=ignore_patterns, 94 | ) 95 | ) 96 | if len(files) == 0: 97 | raise ValueError("No files to download") 98 | 99 | Entry = namedtuple("Entry", ["filename", "url", "metadata"]) 100 | entries = [ 101 | Entry( 102 | filename=x, 103 | url=hf_hub_url(repo_id=hf_model_id, filename=x), 104 | metadata=get_hf_file_metadata( 105 | hf_hub_url(repo_id=hf_model_id, filename=x), token=hf_token 106 | ), 107 | ) 108 | for x in files 109 | ] 110 | 111 | config = PredictorConfig(prompt_template=prompt_template) 112 | 113 | start = time.time() 114 | print(f"Downloading {len(files)} files...") 115 | 116 | # Download the files and write them to a tar file 117 | weights = Path("model.tar") 118 | with tarfile.open(name=str(weights), mode="w:") as tar: 119 | # Add predictor_config.json 120 | 121 | predictor_config_data = json.dumps(asdict(config)).encode("utf-8") 122 | tar_info = tarfile.TarInfo("predictor_config.json") 123 | tar_info.mtime = int(time.time()) 124 | tar_info.size = len(predictor_config_data) 125 | tar.addfile(tar_info, fileobj=io.BytesIO(predictor_config_data)) 126 | 127 | with tqdm.tqdm( 128 | total=sum(entry.metadata.size for entry in entries), 129 | unit="B", 130 | unit_divisor=1024, 131 | unit_scale=True, 132 | mininterval=1, 133 | ) as pbar: 134 | headers = {"Authorization": f"Bearer {hf_token}"} 135 | with httpx.Client( 136 | headers=headers, follow_redirects=True, timeout=None 137 | ) as client: 138 | for n, entry in enumerate(entries, start=1): 139 | pbar.update(0) 140 | pbar.set_postfix( 141 | n=f"{n}/{len(entries)}", 142 | file=entry.filename, 143 | refresh=True, 144 | ) 145 | 146 | with client.stream("GET", entry.url) as response: 147 | response.raise_for_status() 148 | 149 | with io.BytesIO() as buffer: 150 | for chunk in response.iter_bytes(chunk_size=32 * 1024): 151 | buffer.write(chunk) 152 | 153 | pbar.update(len(chunk)) 154 | pbar.set_postfix( 155 | n=f"{n}/{len(entries)}", 156 | file=entry.filename, 157 | refresh=False, 158 | ) 159 | 160 | buffer.seek(0) 161 | 162 | tar_info = tarfile.TarInfo(entry.filename) 163 | tar_info.mtime = int(time.time()) 164 | tar_info.size = entry.metadata.size 165 | tar.addfile(tar_info, fileobj=buffer) 166 | 167 | print(f"Downloaded {len(files)} files in {time.time() - start:.2f} seconds") 168 | 169 | return TrainingOutput(weights=weights) 170 | -------------------------------------------------------------------------------- /tests/unit/test_predict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | import json 4 | from unittest.mock 
import AsyncMock, MagicMock, patch, mock_open, Mock 5 | import pytest 6 | 7 | # Add the project root directory to the Python path 8 | project_root = Path(__file__).resolve().parent.parent.parent 9 | sys.path.insert(0, str(project_root)) 10 | 11 | # Mock required modules 12 | mock_torch = MagicMock() 13 | mock_torch.cuda.device_count.return_value = 1 # Set a default return value 14 | 15 | sys.modules["torch"] = mock_torch 16 | sys.modules["vllm"] = MagicMock() 17 | sys.modules["vllm.engine"] = MagicMock() 18 | sys.modules["vllm.engine.arg_utils"] = MagicMock() 19 | sys.modules["vllm.sampling_params"] = MagicMock() 20 | 21 | 22 | from predict import Predictor, PredictorConfig, UserError # pylint: disable=import-error, wrong-import-position 23 | 24 | 25 | class MockInput: # pylint: disable=too-few-public-methods 26 | """ 27 | Use this to mock default inputs for the Predictor class. 28 | """ 29 | def __init__(self, default=None, **kwargs): 30 | self.default = default 31 | self.__dict__.update(kwargs) 32 | 33 | def __bool__(self): 34 | return bool(self.default) 35 | 36 | 37 | sys.modules["cog"] = Mock() 38 | sys.modules["cog"].Input = MockInput 39 | 40 | 41 | @pytest.fixture 42 | def mock_dependencies(): 43 | with patch("predict.AsyncLLMEngine") as mock_engine_class, patch( 44 | "predict.AsyncEngineArgs" 45 | ) as mock_engine_args, patch( 46 | "predict.resolve_model_path" 47 | ) as mock_resolve_model_path, patch( 48 | "predict.torch", mock_torch 49 | ): # Explicitly patch torch in predict.py 50 | 51 | mock_engine = AsyncMock() 52 | mock_engine_class.from_engine_args.return_value = mock_engine 53 | mock_resolve_model_path.return_value = "/path/to/weights" 54 | 55 | yield { 56 | "engine": mock_engine, 57 | "engine_class": mock_engine_class, 58 | "engine_args": mock_engine_args, 59 | "resolve_model_path": mock_resolve_model_path, 60 | "torch": mock_torch, 61 | } 62 | 63 | 64 | @pytest.fixture 65 | def mock_predictor_config(): 66 | return { 67 | "prompt_template": "Test template: {prompt}", 68 | "engine_args": {"dtype": "float16", "tensor_parallel_size": 2}, 69 | } 70 | 71 | 72 | @pytest.mark.asyncio 73 | async def test_setup_with_predictor_config(mock_dependencies, mock_predictor_config): 74 | with patch("builtins.open", mock_open(read_data=json.dumps(mock_predictor_config))): 75 | with patch("os.path.exists", return_value=True): 76 | with patch.object(Predictor, 'predict') as mock_predict: 77 | # Create a mock async generator 78 | async def mock_generator(): 79 | yield "test" 80 | mock_predict.return_value = mock_generator() 81 | 82 | predictor = Predictor() 83 | await predictor.setup("dummy_weights") 84 | 85 | 86 | 87 | assert isinstance(predictor.config, PredictorConfig) 88 | assert predictor.config.prompt_template == mock_predictor_config["prompt_template"] 89 | mock_predict.assert_called_once() 90 | assert hasattr(predictor, 'prompt_template') 91 | 92 | mock_dependencies["engine_args"].assert_called_once_with( 93 | model="/path/to/weights", dtype="float16", tensor_parallel_size=2 94 | ) 95 | 96 | 97 | @pytest.mark.asyncio 98 | async def test_setup_without_predictor_config(mock_dependencies): 99 | with patch("os.path.exists", return_value=False): 100 | with patch.object(Predictor, 'predict') as mock_predict: 101 | # Create a mock async generator 102 | async def mock_generator(): 103 | yield "test" 104 | mock_predict.return_value = mock_generator() 105 | 106 | predictor = Predictor() 107 | await predictor.setup("dummy_weights") 108 | 109 | 110 | assert isinstance(predictor.config, 
PredictorConfig) 111 | assert predictor.config.prompt_template is None 112 | assert predictor.config.engine_args == {} 113 | mock_predict.assert_called_once() 114 | assert hasattr(predictor, 'prompt_template') 115 | 116 | 117 | mock_dependencies["engine_args"].assert_called_once_with( 118 | model="/path/to/weights", dtype="auto", tensor_parallel_size=1 119 | ) 120 | 121 | 122 | @pytest.mark.asyncio 123 | async def test_setup_with_invalid_predictor_config(): 124 | invalid_config = { 125 | "prompt_template": 123, # Should be a string 126 | "engine_args": "not a dict", # Should be a dictionary 127 | } 128 | with patch("builtins.open", mock_open(read_data=json.dumps(invalid_config))): 129 | with patch("os.path.exists", return_value=True): 130 | with patch("predict.resolve_model_path", return_value="dummy_weights"): 131 | predictor = Predictor() 132 | with pytest.raises(UserError) as exc_info: 133 | await predictor.setup("dummy_weights") 134 | assert "E1202 InvalidPredictorConfig:" in str(exc_info.value) 135 | 136 | @pytest.mark.asyncio 137 | async def test_predict(mock_dependencies): 138 | 139 | with patch('predict.cog.emit_metric') as mock_emit_metric: 140 | class MockOutput: 141 | def __init__(self, text): 142 | self.text = text 143 | self.token_ids = [4, 5, 6] # Generated tokens 144 | 145 | class MockResult: 146 | def __init__(self, text): 147 | self.outputs = [MockOutput(text)] 148 | self.prompt_token_ids = [1, 2, 3] # Input tokens 149 | 150 | 151 | # Define an async generator function 152 | async def mock_generate(*args, **kwargs): # pylint: disable=unused-argument 153 | yield MockResult("Generated text") 154 | 155 | mock_dependencies["engine"].generate = mock_generate 156 | 157 | predictor = Predictor() 158 | predictor.log = MagicMock() 159 | with patch.object(Predictor, 'setup') as mock_setup: 160 | def setup_side_effect(*args): # pylint: disable=unused-argument 161 | predictor.engine = mock_dependencies["engine"] 162 | predictor.prompt_template = None 163 | predictor.config = PredictorConfig() 164 | predictor._testing = False # pylint: disable=protected-access 165 | mock_setup.side_effect = setup_side_effect 166 | await predictor.setup("dummy_weights") 167 | 168 | # Mock the tokenizer 169 | predictor.tokenizer = MagicMock() 170 | predictor.tokenizer.chat_template = None 171 | predictor.tokenizer.eos_token_id = 0 172 | 173 | # Call the predict method 174 | result = predictor.predict( 175 | prompt="Test prompt", prompt_template=MockInput(default=None) 176 | ) 177 | 178 | # Consume the async generator 179 | texts = [] 180 | async for item in result: 181 | texts.append(item) 182 | 183 | assert texts == ["Generated text"] 184 | # Assert that emit_metric was called with the expected arguments 185 | mock_emit_metric.assert_any_call("input_token_count", 3) 186 | mock_emit_metric.assert_any_call("output_token_count", 3) 187 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-module-docstring, no-name-in-module, attribute-defined-outside-init 2 | import json 3 | import os 4 | import time 5 | from typing import Optional, Dict 6 | from uuid import uuid4 7 | from dataclasses import dataclass, field 8 | from pprint import pprint 9 | import inspect 10 | import random 11 | import jinja2 12 | import torch # pylint: disable=import-error 13 | import cog # pylint: disable=import-error 14 | from cog import BasePredictor, ConcatenateIterator, Input 15 | 
from vllm import AsyncLLMEngine 16 | from vllm.engine.arg_utils import AsyncEngineArgs # pylint: disable=import-error 17 | from vllm.sampling_params import SamplingParams # pylint: disable=import-error 18 | 19 | import prompt_templates 20 | from utils import resolve_model_path 21 | 22 | PROMPT_TEMPLATE = prompt_templates.COMPLETION # Change this for instruct models 23 | 24 | SYSTEM_PROMPT = "You are a helpful assistant." 25 | 26 | 27 | @dataclass 28 | class PredictorConfig: 29 | """ 30 | PredictorConfig is a configuration class for the Predictor. 31 | 32 | Attributes: 33 | prompt_template (Optional[str]): A template to format the prompt with. If not provided, 34 | the default prompt template will be used. 35 | engine_args (Optional[Dict]): A dictionary of engine arguments. If not provided, 36 | an empty dictionary will be used. 37 | """ 38 | 39 | prompt_template: Optional[str] = None 40 | engine_args: Optional[Dict] = field(default_factory=dict) 41 | 42 | def __post_init__(self): 43 | if self.engine_args is None: 44 | self.engine_args = {} 45 | if not isinstance(self.engine_args, dict): 46 | raise UserError( 47 | "E1202 InvalidPredictorConfig: engine_args must be " 48 | "a valid JSON object that maps to a dictionary." 49 | ) 50 | 51 | 52 | # pylint: disable=missing-class-docstring 53 | class UserError(Exception): 54 | pass 55 | 56 | 57 | # pylint: disable=missing-class-docstring 58 | class VLLMError(Exception): 59 | pass 60 | 61 | 62 | def format_prompt( 63 | prompt: str, prompt_template: str, system_prompt: Optional[str] 64 | ) -> str: 65 | """ 66 | Formats the given prompt using the provided prompt template and system prompt. 67 | 68 | Args: 69 | prompt (str): The user-provided prompt to be formatted. 70 | prompt_template (str): The template string that includes placeholders for the prompt 71 | and, optionally, system prompt. Must include {prompt}. 72 | system_prompt (Optional[str]): An optional system prompt to be included in the 73 | formatted prompt. 74 | 75 | Returns: 76 | str: The formatted prompt string. 77 | 78 | Raises: 79 | UserError: If the prompt template does not include the '{prompt}' placeholder or if 80 | there is an error in formatting. 81 | """ 82 | if not prompt_template: 83 | prompt_template = "{prompt}" 84 | if prompt and "{prompt}" not in prompt_template: 85 | raise UserError( 86 | "E1003 BadPromptTemplate: You have submitted both a prompt and a " 87 | "prompt template that doesn't include '{prompt}'. Your prompt would " 88 | "not be used. If don't want to use formatting, use your full prompt " 89 | "for the prompt argument and set prompt_template to '{prompt}'." 90 | ) 91 | try: 92 | return prompt_template.format(system_prompt=system_prompt or "", prompt=prompt) 93 | except (ValueError, KeyError, IndexError) as e: 94 | # sometimes people put the prompt in prompt_template 95 | if len(prompt_template) > len(prompt): 96 | raise UserError( 97 | "E1004 PromptTemplateError: Prompt template must be a valid " 98 | "python format spec. Did you submit your prompt as " 99 | "`prompt_template` instead of `prompt`? If you want finer " 100 | 'control over templating, set prompt_template to `"{prompt}"` ' 101 | "to disable formatting. You can't put JSON in prompt_template, " 102 | "because braces will be parsed as a python format string. 
" 103 | f"Detail: {repr(e)}" 104 | ) from e 105 | # most common case is "unmatched '{' in format spec", 106 | # but IndexError/KeyError and other formatting errors can happen 107 | # str(KeyError) is only the missing key which can be confusing 108 | raise UserError( 109 | f"E1004 PromptTemplateError: Prompt template must be a valid " 110 | f"python format spec: {repr(e)}" 111 | ) from e 112 | 113 | 114 | # pylint: disable=missing-class-docstring 115 | class Predictor(BasePredictor): 116 | async def setup( 117 | self, weights: str 118 | ): # pylint: disable=invalid-overridden-method, signature-differs 119 | if not weights: 120 | raise ValueError( 121 | "Weights must be provided. " 122 | "Set COG_WEIGHTS environment variable to " 123 | "a URL to a tarball containing the weights file " 124 | "or a path to the weights file." 125 | ) 126 | 127 | weights = await resolve_model_path(str(weights)) 128 | self.config = self.load_config(weights) 129 | 130 | engine_args = self.config.engine_args or {} 131 | engine_args["model"] = weights 132 | if "dtype" not in engine_args: 133 | engine_args["dtype"] = "auto" 134 | if "tensor_parallel_size" not in engine_args: 135 | engine_args["tensor_parallel_size"] = max(torch.cuda.device_count(), 1) 136 | 137 | engine_args = AsyncEngineArgs(**engine_args) 138 | 139 | try: 140 | # pylint: disable=attribute-defined-outside-init 141 | self.engine = AsyncLLMEngine.from_engine_args( 142 | engine_args 143 | ) # pylint: disable=attribute-defined-outside-init 144 | except TypeError as e: 145 | print(f"E1201 UnexpectedEngineArg: {e}") 146 | raise 147 | except Exception as e: 148 | print(f"E1200 VLLMUnknownError: {e}") 149 | raise 150 | 151 | # pylint: disable=attribute-defined-outside-init 152 | self.tokenizer = ( 153 | self.engine.engine.tokenizer.tokenizer 154 | if hasattr(self.engine.engine.tokenizer, "tokenizer") 155 | else self.engine.engine.tokenizer 156 | ) 157 | 158 | if self.config.prompt_template: 159 | print( 160 | f"Using prompt template from `predictor_config.json`: {self.config.prompt_template}" 161 | ) 162 | self.tokenizer.chat_template = self.config.prompt_template 163 | self.prompt_template = None 164 | 165 | elif self.tokenizer.chat_template: 166 | print( 167 | f"Using prompt template from `tokenizer`: {self.tokenizer.chat_template}" 168 | ) 169 | self.prompt_template = None 170 | else: 171 | print( 172 | "No prompt template specified in `predictor_config.json` or " 173 | f"`tokenizer`, defaulting to: {PROMPT_TEMPLATE}" 174 | ) 175 | self.tokenizer.chat_template = None 176 | self.prompt_template = PROMPT_TEMPLATE 177 | 178 | self._testing = True 179 | generator = self.predict( 180 | **dict(self._defaults, **{"max_tokens": 3, "prompt": "hi"}) 181 | ) 182 | test_output = "".join([tok async for tok in generator]) 183 | print("Test prediction output:", test_output) 184 | self._testing = False 185 | 186 | async def predict( # pylint: disable=invalid-overridden-method, arguments-differ, too-many-arguments, too-many-locals 187 | self, 188 | prompt: str = Input(description="Prompt", default=""), 189 | system_prompt: str = Input( 190 | description="System prompt to send to the model. This is prepended to " 191 | "the prompt and helps guide system behavior. 
Ignored for non-chat models.", 192 | default="You are a helpful assistant.", 193 | ), 194 | min_tokens: int = Input( 195 | description="The minimum number of tokens the model should generate as output.", 196 | default=0, 197 | ), 198 | max_tokens: int = Input( 199 | description="The maximum number of tokens the model should generate as output.", 200 | default=512, 201 | ), 202 | temperature: float = Input( 203 | description="The value used to modulate the next token probabilities.", 204 | default=0.6, 205 | ), 206 | top_p: float = Input( 207 | description="A probability threshold for generating the output. If < 1.0, only keep " 208 | "the top tokens with cumulative probability >= top_p (nucleus filtering). " 209 | "Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).", 210 | default=0.9, 211 | ), 212 | top_k: int = Input( 213 | description="The number of highest probability tokens to consider for generating " 214 | "the output. If > 0, only keep the top k tokens with highest probability " 215 | "(top-k filtering).", 216 | default=50, 217 | ), 218 | presence_penalty: float = Input(description="Presence penalty", default=0.0), 219 | frequency_penalty: float = Input(description="Frequency penalty", default=0.0), 220 | stop_sequences: str = Input( 221 | description="A comma-separated list of sequences to stop generation at. " 222 | "For example, '<end>,<stop>' will stop generation at the first instance of " 223 | "'end' or '<stop>'.", 224 | default=None, 225 | ), 226 | prompt_template: str = Input( 227 | description="A template to format the prompt with. If not provided, " 228 | "the default prompt template will be used.", 229 | default=None, 230 | ), 231 | seed: int = Input( 232 | description="Random seed. Leave blank to randomize the seed.", 233 | default=None, 234 | ), 235 | ) -> ConcatenateIterator[str]: 236 | start = time.time() 237 | if not seed: 238 | seed = int(random.randint(0, 100000)) 239 | 240 | if prompt_template or self.prompt_template: 241 | prompt_template = prompt_template or self.prompt_template 242 | prompt = format_prompt( 243 | prompt=prompt, 244 | prompt_template=prompt_template, 245 | system_prompt=system_prompt, 246 | ) 247 | 248 | elif self.tokenizer.chat_template: 249 | system_prompt = "" if system_prompt is None else system_prompt 250 | try: 251 | messages = [ 252 | {"role": "system", "content": system_prompt}, 253 | {"role": "user", "content": prompt}, 254 | ] 255 | prompt = self.tokenizer.apply_chat_template( 256 | messages, tokenize=False, add_generation_prompt=True 257 | ) 258 | except jinja2.exceptions.TemplateError: 259 | messages = [ 260 | {"role": "user", "content": "\n\n".join([system_prompt, prompt])} 261 | ] 262 | prompt = self.tokenizer.apply_chat_template( 263 | messages, tokenize=False, add_generation_prompt=True 264 | ) 265 | elif system_prompt: 266 | # pylint: disable=no-member 267 | self.log( 268 | "Warning: ignoring system prompt because no chat template was configured" 269 | ) 270 | 271 | sampling_params = SamplingParams( 272 | n=1, 273 | top_k=(-1 if (top_k or 0) == 0 else top_k), 274 | top_p=top_p, 275 | temperature=temperature, 276 | min_tokens=min_tokens, 277 | max_tokens=max_tokens, 278 | stop_token_ids=[self.tokenizer.eos_token_id], 279 | frequency_penalty=frequency_penalty, 280 | presence_penalty=presence_penalty, 281 | use_beam_search=False, 282 | seed=seed, 283 | ) 284 | if isinstance(stop_sequences, str) and stop_sequences: 285 | sampling_params.stop = stop_sequences.split(",") 286 | else: 287 | sampling_params.stop = ( 288
| list(stop_sequences) if isinstance(stop_sequences, list) else [] 289 | ) 290 | 291 | request_id = uuid4().hex 292 | 293 | generator = self.engine.generate( 294 | prompt, 295 | sampling_params, 296 | request_id, 297 | ) 298 | offset = 0  # number of characters already yielded to the caller 299 | 300 | async for result in generator: 301 | assert ( 302 | len(result.outputs) == 1 303 | ), "Expected exactly one output from generation request." 304 | 305 | if result.outputs[0].finish_reason == "length" and offset != 0: 306 | # hard to find the max length though, sorry 307 | raise UserError( 308 | "E1002 PromptTooLong: Prompt length exceeds maximum input length" 309 | ) 310 | text = result.outputs[0].text 311 | 312 | # Normalize text by removing any incomplete surrogate pairs (common with emojis) 313 | text = text.replace("\N{REPLACEMENT CHARACTER}", "") 314 | 315 | yield text[offset:] 316 | 317 | offset = len(text) 318 | 319 | # pylint: disable=no-member 320 | self.log(f"Generation took {time.time() - start:.2f}s") 321 | self.log(f"Formatted prompt: {prompt}") 322 | self.log(f"Random seed used: `{seed}`\n") 323 | self.log( 324 | "Note: Random seed will not impact output if greedy decoding is used.\n" 325 | ) 326 | 327 | if not self._testing: 328 | # pylint: disable=no-member, undefined-loop-variable 329 | cog.emit_metric("input_token_count", len(result.prompt_token_ids)) 330 | cog.emit_metric("output_token_count", len(result.outputs[0].token_ids)) 331 | 332 | def load_config(self, weights: str) -> PredictorConfig: 333 | """ 334 | Load the predictor configuration from the specified weights directory or 335 | the current directory. 336 | 337 | Load `predictor_config.json` from the weights directory or current directory. 338 | Return a default PredictorConfig object if not found or an error occurs. 339 | 340 | Priority: 341 | 1. Load `predictor_config.json` from the specified weights directory. 342 | 2. If not found, load `predictor_config.json` from the current directory. 343 | 3. If not found or an error occurs, return a default PredictorConfig object. 344 | 345 | Args: 346 | weights (str): The path to the weights directory. 347 | 348 | Returns: 349 | PredictorConfig: The loaded predictor configuration. 350 | """ 351 | if os.path.exists(os.path.join(weights, "predictor_config.json")): 352 | predictor_config_path = os.path.join(weights, "predictor_config.json") 353 | elif os.path.exists("./predictor_config.json"): 354 | predictor_config_path = "./predictor_config.json" 355 | else: 356 | predictor_config_path = None 357 | if predictor_config_path: 358 | try: 359 | print("Loading predictor_config.json") 360 | with open( 361 | predictor_config_path, 362 | "r", 363 | encoding="utf-8", 364 | ) as f: 365 | config = json.load(f) 366 | # pylint: disable=attribute-defined-outside-init 367 | config = PredictorConfig(**config) 368 | except Exception as e: 369 | raise UserError(f"E1202 InvalidPredictorConfig: {e}") from e 370 | 371 | else: 372 | config = PredictorConfig() 373 | pprint(config) 374 | return config 375 | 376 | _defaults = { 377 | key: param.default.default 378 | for key, param in inspect.signature(predict).parameters.items() 379 | if hasattr(param.default, "default") 380 | } 381 | --------------------------------------------------------------------------------
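Note: `train.py` above bundles a `predictor_config.json` into `model.tar`, and `predict.py` loads it in `load_config`. A minimal sketch of writing such a file by hand is below; the field values are illustrative assumptions rather than project defaults, `prompt_template` is treated as a Jinja2 chat template, and `engine_args` is passed straight through to vLLM's `AsyncEngineArgs` (mirroring what the end-to-end test does):

```python
# Sketch: write a predictor_config.json of the shape train.py bundles with the weights.
# Field values are illustrative; engine_args keys must be valid vLLM AsyncEngineArgs fields.
import json

predictor_config = {
    "prompt_template": "{{ messages[0]['content'] }}",  # Jinja2 chat template override
    "engine_args": {"dtype": "auto", "enforce_eager": True},
}

with open("predictor_config.json", "w", encoding="utf-8") as f:
    json.dump(predictor_config, f, indent=4)
```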