├── MANIFEST.in ├── tests ├── __init__.py ├── cli │ ├── __init__.py │ ├── test_cli_generate.py │ ├── test_cli_files.py │ ├── test_cli_models.py │ ├── test_cli_hub.py │ ├── test_cli_fine_tuning.py │ ├── test_cli_predictions.py │ └── test_cli_config.py ├── test_data │ ├── test.mp4 │ └── image_dataset │ │ ├── test1.jpg │ │ ├── test2.png │ │ └── test3.jpg ├── test_models.py ├── test_artifacts.py ├── common │ ├── test_pdf.py │ ├── test_utils.py │ ├── test_dependencies.py │ ├── test_image.py │ ├── test_video.py │ └── test_viz.py ├── test_feedback.py ├── test_hub.py ├── test_dataset.py ├── test_files.py ├── test_client.py ├── test_exceptions.py └── test_agent.py ├── vlmrun ├── version.py ├── cli │ ├── __init__.py │ ├── _cli │ │ ├── __init__.py │ │ ├── generate.py │ │ ├── models.py │ │ ├── files.py │ │ ├── hub.py │ │ ├── fine_tuning.py │ │ ├── config.py │ │ ├── predictions.py │ │ └── datasets.py │ ├── README.md │ └── cli.py ├── types │ ├── __init__.py │ ├── refs.py │ └── abstract.py ├── client │ ├── __init__.py │ ├── models.py │ ├── feedback.py │ ├── executions.py │ ├── datasets.py │ ├── artifacts.py │ ├── hub.py │ ├── exceptions.py │ ├── client.py │ ├── agent.py │ ├── files.py │ └── base_requestor.py ├── constants.py └── common │ ├── pdf.py │ ├── webhook.py │ ├── pydantic.py │ ├── image.py │ ├── utils.py │ └── video.py ├── .env.template ├── requirements └── requirements.txt ├── .pre-commit-config.yaml ├── makefiles └── Makefile.admin.mk ├── Makefile ├── .github └── workflows │ ├── ci.yml │ └── python-publish.yml ├── pyproject.toml ├── .gitignore └── README.md /MANIFEST.in: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/cli/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vlmrun/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.10" 2 | -------------------------------------------------------------------------------- /vlmrun/cli/__init__.py: -------------------------------------------------------------------------------- 1 | """CLI package for vlmrun.""" 2 | -------------------------------------------------------------------------------- /vlmrun/types/__init__.py: -------------------------------------------------------------------------------- 1 | """Type definitions for VLM Run API.""" 2 | -------------------------------------------------------------------------------- /vlmrun/cli/_cli/__init__.py: -------------------------------------------------------------------------------- 1 | """Subcommands package for vlmrun CLI.""" 2 | -------------------------------------------------------------------------------- /vlmrun/client/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import VLMRun # noqa: F401 2 | -------------------------------------------------------------------------------- /.env.template: -------------------------------------------------------------------------------- 1 | VLMRUN_BASE_URL=https://api.vlm.run/v1 2 | VLMRUN_API_KEY= 3 | -------------------------------------------------------------------------------- /tests/test_data/test.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlm-run/vlmrun-python-sdk/main/tests/test_data/test.mp4 -------------------------------------------------------------------------------- /tests/test_data/image_dataset/test1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vlm-run/vlmrun-python-sdk/main/tests/test_data/image_dataset/test1.jpg -------------------------------------------------------------------------------- /tests/test_data/image_dataset/test2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlm-run/vlmrun-python-sdk/main/tests/test_data/image_dataset/test2.png -------------------------------------------------------------------------------- /tests/test_data/image_dataset/test3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlm-run/vlmrun-python-sdk/main/tests/test_data/image_dataset/test3.jpg -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | cachetools 2 | IPython 3 | loguru 4 | opencv-python>=4.8.0 5 | pandas 6 | Pillow>=10.2.0 7 | pydantic>=2.5,<3 8 | pydantic_core>=2.23.4 9 | requests 10 | rich 11 | tabulate 12 | tenacity 13 | tqdm 14 | typer>=0.9.0 15 | vlmrun-hub>=0.1.28 16 | -------------------------------------------------------------------------------- /vlmrun/constants.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | DEFAULT_BASE_URL = "https://api.vlm.run/v1" 4 | 5 | VLMRUN_CACHE_DIR = Path.home() / ".vlmrun" / "cache" 6 | VLMRUN_CACHE_DIR.mkdir(parents=True, exist_ok=True) 7 | 8 | VLMRUN_TMP_DIR = Path.home() / ".vlmrun" / "tmp" 9 | VLMRUN_TMP_DIR.mkdir(parents=True, exist_ok=True) 10 | 11 | SUPPORTED_VIDEO_FILETYPES = [".mp4", ".mov", ".avi", ".mkv", ".webm"] 12 | SUPPORTED_IMAGE_FILETYPES = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"] 13 | -------------------------------------------------------------------------------- /vlmrun/types/refs.py: 
-------------------------------------------------------------------------------- 1 | from pydantic import Field 2 | from typing import Annotated 3 | 4 | 5 | ImageRef = Annotated[str, Field(..., pattern=r"^img_\w{6}$")] 6 | AudioRef = Annotated[str, Field(..., pattern=r"^aud_\w{6}$")] 7 | VideoRef = Annotated[str, Field(..., pattern=r"^vid_\w{6}$")] 8 | DocumentRef = Annotated[str, Field(..., pattern=r"^doc_\w{6}$")] 9 | ReconRef = Annotated[str, Field(..., pattern=r"^recon_\w{6}$")] 10 | UrlRef = Annotated[str, Field(..., pattern=r"^url_\w{6}$")] 11 | ArrayRef = Annotated[str, Field(..., pattern=r"^arr_\w{6}$")] 12 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | """Tests for models operations.""" 2 | 3 | from vlmrun.client.types import ModelInfo 4 | 5 | 6 | def test_list_models(mock_client): 7 | """Test listing available models.""" 8 | response = mock_client.models.list() 9 | assert isinstance(response, list) 10 | assert all(isinstance(model, ModelInfo) for model in response) 11 | assert len(response) > 0 12 | model = response[0] 13 | assert isinstance(model.model, str) 14 | assert isinstance(model.domain, str) 15 | assert model.model == "model1" 16 | assert model.domain == "test-domain" 17 | -------------------------------------------------------------------------------- /tests/test_artifacts.py: -------------------------------------------------------------------------------- 1 | """Tests for artifacts operations.""" 2 | 3 | import pytest 4 | 5 | 6 | def test_get_artifact(mock_client): 7 | """Test getting an artifact by session_id and object_id.""" 8 | response = mock_client.artifacts.get( 9 | session_id="550e8400-e29b-41d4-a716-446655440000", object_id="test-object-456" 10 | ) 11 | assert isinstance(response, bytes) 12 | assert response == b"mock artifact content" 13 | 14 | 15 | def 
test_list_artifacts_not_implemented(mock_client): 16 | """Test that list() raises NotImplementedError.""" 17 | with pytest.raises(NotImplementedError) as exc_info: 18 | mock_client.artifacts.list(session_id="550e8400-e29b-41d4-a716-446655440000") 19 | assert "not yet implemented" in str(exc_info.value) 20 | -------------------------------------------------------------------------------- /vlmrun/types/abstract.py: -------------------------------------------------------------------------------- 1 | """Abstract type definitions for VLM Run API.""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass 6 | from typing import Protocol, Any, TypeVar, Optional 7 | 8 | T = TypeVar("T") 9 | 10 | 11 | @dataclass 12 | class VLMRunProtocol(Protocol): 13 | """VLM Run API interface.""" 14 | 15 | api_key: Optional[str] 16 | base_url: str 17 | timeout: float 18 | files: Any 19 | datasets: Any 20 | models: Any 21 | hub: Any 22 | image: Any 23 | video: Any 24 | document: Any 25 | fine_tuning: Any 26 | feedback: Any 27 | agent: Any 28 | requestor: Any 29 | artifacts: Any 30 | 31 | def __post_init__(self) -> None: 32 | """Initialize after dataclass initialization.""" 33 | ... 
34 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: true 3 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit autofix suggestions' 4 | 5 | repos: 6 | - repo: https://github.com/charliermarsh/ruff-pre-commit 7 | rev: 'v0.5.5' 8 | hooks: 9 | - id: ruff 10 | args: ['--fix', '--exit-non-zero-on-fix'] 11 | 12 | - repo: https://github.com/psf/black 13 | rev: 24.3.0 14 | hooks: 15 | - id: black 16 | exclude: notebooks|^tests/test_data$ 17 | args: ['--config=./pyproject.toml'] 18 | 19 | - repo: https://github.com/pre-commit/pre-commit-hooks 20 | rev: v3.1.0 21 | hooks: 22 | - id: check-ast 23 | - id: check-docstring-first 24 | - id: check-json 25 | - id: check-merge-conflict 26 | - id: debug-statements 27 | - id: detect-private-key 28 | - id: end-of-file-fixer 29 | - id: pretty-format-json 30 | - id: trailing-whitespace 31 | - id: check-added-large-files 32 | args: ['--maxkb=100'] 33 | - id: requirements-txt-fixer 34 | -------------------------------------------------------------------------------- /tests/common/test_pdf.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import pytest 3 | from PIL import Image 4 | from vlmrun.common.pdf import pdf_images 5 | from loguru import logger 6 | 7 | 8 | # URL of the PDF to be tested 9 | PDF_URL = "https://storage.googleapis.com/vlm-data-public-prod/hub/examples/document.bank-statement/lending_bankstatement.pdf" 10 | 11 | 12 | @pytest.fixture 13 | def pdf_file(tmp_path): 14 | """Download the PDF and save it to a temporary file.""" 15 | response = requests.get(PDF_URL) 16 | response.raise_for_status() 17 | pdf_path = tmp_path / "test_document.pdf" 18 | pdf_path.write_bytes(response.content) 19 | return pdf_path 20 | 21 | 22 | def test_pdf_images(pdf_file): 23 | """Test that pdf_images yields an iterator of 
valid images.""" 24 | pages = pdf_images(pdf_file, dpi=72) 25 | for page in pages: 26 | assert isinstance(page.image, Image.Image) 27 | assert page.image.mode == "RGB" 28 | assert page.image.size[0] > 0 and page.image.size[1] > 0 29 | logger.debug(f"page={page}") 30 | -------------------------------------------------------------------------------- /makefiles/Makefile.admin.mk: -------------------------------------------------------------------------------- 1 | VLMRUN_SDK_VERSION := $(shell python -c 'from vlmrun.version import __version__; print(__version__.replace("-", "."))') 2 | PYPI_USERNAME := 3 | PYPI_PASSWORD := 4 | 5 | WHL_GREP_PATTERN := .*\$(VLMRUN_SDK_VERSION).*\.whl 6 | 7 | create-pypi-release-test: 8 | @echo "looking for vlmrun whl file..." 9 | @for file in dist/*; do \ 10 | echo "examining file: $$file"; \ 11 | if [ -f "$$file" ] && echo "$$file" | grep -qE "$(WHL_GREP_PATTERN)"; then \ 12 | echo "Uploading: $$file"; \ 13 | twine upload --repository testpypi "$$file"; \ 14 | fi; \ 15 | done 16 | @echo "Upload completed" 17 | 18 | 19 | create-pypi-release: 20 | @echo "looking for vlmrun whl file..." 
21 | @for file in dist/*; do \ 22 | echo "examining file: $$file"; \ 23 | if [ -f "$$file" ] && echo "$$file" | grep -qE "$(WHL_GREP_PATTERN)"; then \ 24 | echo "Uploading: $$file"; \ 25 | twine upload "$$file"; \ 26 | fi; \ 27 | done 28 | @echo "Upload completed" 29 | 30 | create-tag: 31 | git tag -a ${VLMRUN_SDK_VERSION} -m "Release ${VLMRUN_SDK_VERSION}" 32 | git push origin ${VLMRUN_SDK_VERSION} 33 | -------------------------------------------------------------------------------- /tests/test_feedback.py: -------------------------------------------------------------------------------- 1 | """Tests for feedback operations.""" 2 | 3 | from pydantic import BaseModel 4 | from vlmrun.client.types import FeedbackListResponse, FeedbackSubmitResponse 5 | 6 | 7 | class TestLabel(BaseModel): 8 | """Test label model.""" 9 | 10 | score: int 11 | comment: str 12 | 13 | 14 | def test_submit_feedback(mock_client): 15 | """Test submitting feedback for a prediction.""" 16 | response = mock_client.feedback.submit( 17 | request_id="prediction1", 18 | response={"score": 5, "comment": "Great prediction!"}, 19 | notes="Test feedback", 20 | ) 21 | assert isinstance(response, FeedbackSubmitResponse) 22 | assert response.request_id == "prediction1" 23 | assert response.id 24 | 25 | 26 | def test_get_feedback(mock_client): 27 | """Test getting feedback for a prediction.""" 28 | response = mock_client.feedback.get("prediction1", limit=5, offset=0) 29 | assert isinstance(response, FeedbackListResponse) 30 | assert response.request_id == "prediction1" 31 | assert len(response.items) >= 0 32 | assert response.items[0].id 33 | -------------------------------------------------------------------------------- /tests/common/test_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from vlmrun.common.utils import download_artifact, create_archive 4 | 5 | PDF_URL = 
"https://storage.googleapis.com/vlm-data-public-prod/hub/examples/document.bank-statement/lending_bankstatement.pdf" 6 | 7 | 8 | def test_download_artifact(): 9 | """Test that download_artifact can download a PDF.""" 10 | pdf = download_artifact(PDF_URL, "file") 11 | assert pdf.exists() 12 | 13 | 14 | def test_create_archive(): 15 | """Test that create_archive can create a tar.gz file.""" 16 | import tarfile 17 | 18 | archive_path: Path = create_archive( 19 | Path(__file__).parent.parent / "test_data/image_dataset", "test_image_dataset" 20 | ) 21 | assert archive_path.exists() 22 | assert archive_path.name.endswith(".tar.gz") 23 | 24 | # Unzip the archive and check if there is a folder with the same name as the stem 25 | stem = archive_path.name.replace(".tar.gz", "") 26 | with tarfile.open(archive_path, "r:gz") as tar: 27 | assert len(tar.getmembers()) == 4 # basedir + 3 images 28 | assert tar.getmembers()[0].name == stem 29 | -------------------------------------------------------------------------------- /tests/cli/test_cli_generate.py: -------------------------------------------------------------------------------- 1 | """Test generate subcommand.""" 2 | 3 | from pathlib import Path 4 | 5 | from vlmrun.cli.cli import app 6 | from vlmrun.common.utils import download_artifact 7 | 8 | 9 | def test_generate_image(runner, mock_client, config_file, tmp_path): 10 | """Test generate image command.""" 11 | path: Path = download_artifact( 12 | "https://storage.googleapis.com/vlm-data-public-prod/hub/examples/document.invoice/invoice_1.jpg", 13 | format="file", 14 | ) 15 | result = runner.invoke( 16 | app, ["generate", "image", str(path), "--domain", "document.invoice"] 17 | ) 18 | assert result.exit_code == 0 19 | 20 | 21 | def test_generate_document(runner, mock_client, config_file, tmp_path): 22 | """Test generate document command.""" 23 | path: Path = download_artifact( 24 | 
"https://storage.googleapis.com/vlm-data-public-prod/hub/examples/document.bank-statement/lending_bankstatement.pdf", 25 | format="file", 26 | ) 27 | result = runner.invoke( 28 | app, ["generate", "document", str(path), "--domain", "document.bank-statement"] 29 | ) 30 | assert result.exit_code == 0 31 | -------------------------------------------------------------------------------- /tests/cli/test_cli_files.py: -------------------------------------------------------------------------------- 1 | """Test files subcommand.""" 2 | 3 | from vlmrun.cli.cli import app 4 | 5 | 6 | def test_list_files(runner, mock_client, config_file): 7 | """Test list files command.""" 8 | result = runner.invoke(app, ["files", "list"]) 9 | assert result.exit_code == 0 10 | assert "file1" in result.stdout 11 | assert "test.txt" in result.stdout 12 | 13 | 14 | def test_upload_file(runner, mock_client, config_file, tmp_path): 15 | """Test upload file command.""" 16 | test_file = tmp_path / "test.txt" 17 | test_file.write_text("test content") 18 | result = runner.invoke(app, ["files", "upload", str(test_file)]) 19 | assert result.exit_code == 0 20 | assert "file1" in result.stdout 21 | 22 | 23 | def test_delete_file(runner, mock_client, config_file): 24 | """Test delete file command.""" 25 | result = runner.invoke(app, ["files", "delete", "file1"]) 26 | assert result.exit_code == 0 27 | 28 | 29 | def test_get_file(runner, mock_client, config_file): 30 | """Test get file command.""" 31 | result = runner.invoke(app, ["files", "get", "file1"]) 32 | assert result.exit_code == 0 33 | assert "test content" in result.stdout 34 | -------------------------------------------------------------------------------- /vlmrun/client/models.py: -------------------------------------------------------------------------------- 1 | """VLM Run API Models resource.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import List 6 | 7 | from vlmrun.client.base_requestor import APIRequestor 8 | from 
vlmrun.types.abstract import VLMRunProtocol 9 | from vlmrun.client.types import ModelInfo 10 | 11 | 12 | class Models: 13 | """Models resource for VLM Run API.""" 14 | 15 | def __init__(self, client: "VLMRunProtocol") -> None: 16 | """Initialize Models resource with VLMRun instance. 17 | 18 | Args: 19 | client: VLM Run API instance 20 | """ 21 | self._client = client 22 | self._requestor = APIRequestor(client) 23 | 24 | def list(self) -> List[ModelInfo]: 25 | """List available models. 26 | 27 | Returns: 28 | List[ModelResponse]: List of model objects with their capabilities 29 | """ 30 | response, status_code, headers = self._requestor.request( 31 | method="GET", 32 | url="models", 33 | ) 34 | 35 | if not isinstance(response, list): 36 | raise TypeError("Expected list response") 37 | return [ModelInfo(**model) for model in response] 38 | -------------------------------------------------------------------------------- /tests/test_hub.py: -------------------------------------------------------------------------------- 1 | """Tests for hub operations.""" 2 | 3 | from vlmrun.client.types import ( 4 | HubInfoResponse, 5 | HubSchemaResponse, 6 | HubDomainInfo, 7 | ) 8 | 9 | 10 | def test_hub_info(mock_client): 11 | """Test hub info retrieval.""" 12 | client = mock_client 13 | response = client.hub.info() 14 | assert isinstance(response, HubInfoResponse) 15 | assert isinstance(response.version, str) 16 | assert response.version == "0.1.0" 17 | 18 | 19 | def test_hub_list_domains(mock_client): 20 | """Test listing hub domains.""" 21 | client = mock_client 22 | response = client.hub.list_domains() 23 | assert isinstance(response, list) 24 | assert all(isinstance(domain, HubDomainInfo) for domain in response) 25 | assert "document.invoice" in [domain.domain for domain in response] 26 | 27 | 28 | def test_hub_get_schema(mock_client): 29 | """Test schema retrieval for a domain.""" 30 | client = mock_client 31 | domain = "document.invoice" 32 | response = 
client.hub.get_schema(domain) 33 | assert isinstance(response, HubSchemaResponse) 34 | assert isinstance(response.json_schema, dict) 35 | assert isinstance(response.schema_version, str) 36 | assert isinstance(response.schema_hash, str) 37 | -------------------------------------------------------------------------------- /tests/cli/test_cli_models.py: -------------------------------------------------------------------------------- 1 | """Test models subcommand.""" 2 | 3 | from vlmrun.cli.cli import app 4 | 5 | 6 | def test_list_models(runner, mock_client, config_file): 7 | """Test list models command.""" 8 | result = runner.invoke(app, ["models", "list"]) 9 | assert result.exit_code == 0 10 | assert "model1" in result.stdout 11 | assert "test-domain" in result.stdout 12 | assert "Available Models" in result.stdout 13 | 14 | 15 | def test_list_models_with_filter(runner, mock_client, config_file): 16 | """Test list models command with domain filter.""" 17 | result = runner.invoke(app, ["models", "list", "--domain", "test-domain"]) 18 | assert result.exit_code == 0 19 | assert "model1" in result.stdout 20 | assert "test-domain" in result.stdout 21 | 22 | result = runner.invoke(app, ["models", "list", "--domain", "nonexistent"]) 23 | assert result.exit_code == 0 24 | assert "model1" not in result.stdout 25 | 26 | 27 | def test_list_models_formatting(runner, mock_client, config_file): 28 | """Test that list models output is properly formatted.""" 29 | result = runner.invoke(app, ["models", "list"]) 30 | assert result.exit_code == 0 31 | assert "Category" in result.stdout 32 | assert "Model" in result.stdout 33 | assert "Domain" in result.stdout 34 | -------------------------------------------------------------------------------- /tests/cli/test_cli_hub.py: -------------------------------------------------------------------------------- 1 | """Test hub subcommand.""" 2 | 3 | from vlmrun.cli.cli import app 4 | 5 | 6 | def test_hub_version(runner, mock_client, config_file): 
7 | """Test hub version command.""" 8 | result = runner.invoke(app, ["hub", "version"]) 9 | assert result.exit_code == 0 10 | assert "0.1.0" in result.stdout 11 | assert "github.com/vlm-run/vlmrun-hub" in result.stdout 12 | 13 | 14 | def test_hub_list_domains(runner, mock_client, config_file): 15 | """Test listing hub domains.""" 16 | result = runner.invoke(app, ["hub", "list"]) 17 | assert result.exit_code == 0 18 | assert "document.invoice" in result.stdout 19 | assert "document.receipt" in result.stdout 20 | assert "document.utility_bill" in result.stdout 21 | 22 | 23 | def test_hub_list_domains_with_filter(runner, mock_client, config_file): 24 | """Test listing hub domains with filter.""" 25 | result = runner.invoke(app, ["hub", "list", "--domain", "document.invoice"]) 26 | assert result.exit_code == 0 27 | assert "document.invoice" in result.stdout 28 | assert "document.receipt" not in result.stdout 29 | 30 | 31 | def test_hub_schema(runner, mock_client, config_file): 32 | """Test getting schema for a domain.""" 33 | result = runner.invoke(app, ["hub", "schema", "document.invoice"]) 34 | assert result.exit_code == 0 35 | assert "Schema Information" in result.stdout 36 | assert "invoice_number" in result.stdout 37 | assert "total_amount" in result.stdout 38 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | """Tests for dataset operations.""" 2 | 3 | import pytest 4 | from datetime import datetime 5 | from vlmrun.client.types import DatasetResponse 6 | 7 | 8 | def test_dataset_create(mock_client): 9 | """Test dataset creation.""" 10 | response = mock_client.dataset.create( 11 | file_id="file1", 12 | domain="test-domain", 13 | dataset_name="test-dataset", 14 | dataset_type="images", 15 | ) 16 | assert isinstance(response, DatasetResponse) 17 | assert response.id == "dataset1" 18 | assert response.domain == "test-domain" 19 
| assert response.dataset_type == "images" 20 | assert isinstance(response.created_at, datetime) 21 | 22 | 23 | def test_dataset_get(mock_client): 24 | """Test dataset retrieval.""" 25 | response = mock_client.dataset.get("dataset1") 26 | assert isinstance(response, DatasetResponse) 27 | assert response.id == "dataset1" 28 | assert response.domain == "test-domain" 29 | assert response.dataset_type == "images" 30 | assert isinstance(response.created_at, datetime) 31 | 32 | 33 | def test_dataset_invalid_type(mock_client): 34 | """Test dataset creation with invalid type.""" 35 | with pytest.raises(ValueError) as exc_info: 36 | mock_client.dataset.create( 37 | file_id="file1", 38 | domain="test-domain", 39 | dataset_name="test-dataset", 40 | dataset_type="invalid", 41 | ) 42 | assert "dataset_type must be one of: images, videos, documents" in str( 43 | exc_info.value 44 | ) 45 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | default: help; 2 | 3 | help: 4 | @echo "🔥 Official VLM Run Python SDK" 5 | @echo "" 6 | @echo "Usage: make " 7 | @echo "" 8 | @echo "Targets:" 9 | @echo " clean Remove all build, test, coverage and Python artifacts" 10 | @echo " clean-build Remove build artifacts" 11 | @echo " clean-pyc Remove Python file artifacts" 12 | @echo " clean-test Remove test and coverage artifacts" 13 | @echo " lint Format source code automatically" 14 | @echo " test Basic testing" 15 | @echo " dist Builds source and wheel package" 16 | @echo "" 17 | 18 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 19 | 20 | clean-build: ## remove build artifacts 21 | rm -fr build/ 22 | rm -fr dist/ 23 | rm -fr .eggs/ 24 | rm -fr site/ 25 | find . -name '*.egg-info' -exec rm -fr {} + 26 | find . -name '*.egg' -exec rm -f {} + 27 | 28 | 29 | clean-pyc: ## remove Python file artifacts 30 | find . 
-name '*.pyc' -exec rm -f {} + 31 | find . -name '*.pyo' -exec rm -f {} + 32 | find . -name '*~' -exec rm -f {} + 33 | find . -name '__pycache__' -exec rm -fr {} + 34 | 35 | clean-test: ## remove test and coverage artifacts 36 | rm -fr .tox/ 37 | rm -f .coverage 38 | rm -fr htmlcov/ 39 | rm -fr .pytest_cache 40 | 41 | lint: ## Format source code automatically 42 | pre-commit run --all-files # Uses pyproject.toml 43 | 44 | test: ## Basic CPU testing 45 | pytest -sv tests 46 | 47 | dist: clean ## builds source and wheel package 48 | python -m build --sdist --wheel 49 | ls -lh dist 50 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Main CI 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | workflow_call: 8 | 9 | jobs: 10 | test: 11 | name: Test 12 | runs-on: ubuntu-latest 13 | timeout-minutes: 20 14 | environment: dev 15 | strategy: 16 | matrix: 17 | python-version: ['3.9', '3.10', '3.11', '3.12'] 18 | defaults: 19 | run: 20 | shell: bash -el {0} 21 | env: 22 | CACHE_NUMBER: 0 23 | 24 | steps: 25 | - name: Checkout git repo 26 | uses: actions/checkout@v3 27 | 28 | - uses: actions/setup-python@v5 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | 32 | - uses: actions/cache@v4 33 | with: 34 | path: | 35 | ~/.cache/uv 36 | ~/.cache/pip 37 | key: uv-${{ matrix.python-version }}-${{ hashFiles('requirements/requirements*.txt') }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('MANIFEST.in') }}-${{ env.CACHE_NUMBER }} 38 | restore-keys: | 39 | uv-${{ matrix.python-version }}- 40 | 41 | - name: Install base dependencies 42 | run: | 43 | which python 44 | pip install uv 45 | uv pip install --system -e '.[test,all]' 46 | 47 | - name: Quality Check 48 | uses: pre-commit/action@v3.0.1 49 | continue-on-error: true 50 | 51 | - name: Run all tests 52 | run: | 53 | # Install all extras and run complete 
test suite 54 | uv pip install --system -e '.[all]' 55 | pytest tests/common/test_dependencies.py -v 56 | make test 57 | -------------------------------------------------------------------------------- /vlmrun/cli/_cli/generate.py: -------------------------------------------------------------------------------- 1 | """Generation API commands.""" 2 | 3 | from pathlib import Path 4 | 5 | import typer 6 | from PIL import Image 7 | from rich import print as rprint 8 | 9 | from vlmrun.client import VLMRun 10 | from vlmrun.client.types import PredictionResponse 11 | from vlmrun.common.image import _open_image_with_exif 12 | 13 | app = typer.Typer(help="Generation operations", no_args_is_help=True) 14 | 15 | 16 | @app.command() 17 | def image( 18 | ctx: typer.Context, 19 | image: Path = typer.Argument( 20 | ..., help="Input image file", exists=True, readable=True 21 | ), 22 | domain: str = typer.Option( 23 | ..., help="Domain to use for generation (e.g. `document.invoice`)" 24 | ), 25 | ) -> None: 26 | """Generate an image.""" 27 | client: VLMRun = ctx.obj 28 | if not Path(image).is_file(): 29 | raise typer.Abort(f"Image file does not exist: {image}") 30 | 31 | img: Image.Image = _open_image_with_exif(image) 32 | response: PredictionResponse = client.image.generate(images=[img], domain=domain) 33 | rprint(response) 34 | 35 | 36 | @app.command() 37 | def document( 38 | ctx: typer.Context, 39 | path: Path = typer.Argument( 40 | ..., help="Path to the document file", exists=True, readable=True 41 | ), 42 | domain: str = typer.Option( 43 | ..., help="Domain to use for generation (e.g. 
`document.invoice`)" 44 | ), 45 | ) -> None: 46 | """Generate a document.""" 47 | client: VLMRun = ctx.obj 48 | if not Path(path).is_file(): 49 | raise typer.Abort(f"Document file does not exist: {path}") 50 | 51 | response = client.document.generate(file=path, domain=domain) 52 | rprint(response) 53 | -------------------------------------------------------------------------------- /vlmrun/common/pdf.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Iterable, Literal, Tuple, Optional, Union 4 | 5 | from PIL import Image 6 | 7 | from loguru import logger 8 | 9 | 10 | @dataclass 11 | class Page: 12 | image: Image.Image 13 | """The image of the page.""" 14 | page_number: int 15 | """The page number of the page.""" 16 | 17 | 18 | def pdf_images( 19 | path: Path, 20 | dpi: int = 150, 21 | pages: Optional[Tuple[int]] = None, 22 | backend: Literal["pypdfium2"] = "pypdfium2", 23 | ) -> Iterable[Union[Image.Image, Page]]: 24 | """Extract images from a PDF file.""" 25 | if path.suffix not in (".pdf",): 26 | raise ValueError( 27 | f"Unsupported file type: {path.suffix}. 
Supported types are .pdf" 28 | ) 29 | 30 | logger.debug( 31 | f"Extracting images from PDF [path={path}, dpi={dpi}, size={path.stat().st_size / 1024 / 1024:.2f} MB]" 32 | ) 33 | 34 | if backend == "pypdfium2": 35 | import pypdfium2 as pdfium 36 | 37 | logger.debug(f"Opening PDF document [path={path}]") 38 | doc = pdfium.PdfDocument(str(path)) 39 | 40 | if pages is None: 41 | pages = range(len(doc)) 42 | 43 | def iterator(): 44 | for idx in pages: 45 | page = doc[idx] 46 | bitmap = page.render(scale=dpi / 72) 47 | image: Image.Image = bitmap.to_pil() 48 | yield Page(image=image, page_number=idx) 49 | 50 | yield from iterator() 51 | 52 | logger.debug(f"Closing PDF document [path={path}]") 53 | doc.close() 54 | else: 55 | raise ValueError(f"Unsupported backend: {backend}") 56 | -------------------------------------------------------------------------------- /tests/cli/test_cli_fine_tuning.py: -------------------------------------------------------------------------------- 1 | """Test fine-tuning subcommand.""" 2 | 3 | import pytest 4 | from vlmrun.cli.cli import app 5 | 6 | pytestmark = pytest.mark.skip(reason="Not implemented") 7 | 8 | 9 | def test_create_finetune(runner, mock_client): 10 | """Test create fine-tuning command.""" 11 | result = runner.invoke(app, ["fine-tuning", "create", "file1", "test-model"]) 12 | assert result.exit_code == 0 13 | assert "job1" in result.stdout 14 | 15 | 16 | def test_list_finetune(runner, mock_client): 17 | """Test list fine-tuning command.""" 18 | result = runner.invoke(app, ["fine-tuning", "list"]) 19 | assert result.exit_code == 0 20 | assert "job1" in result.stdout 21 | assert "test-model" in result.stdout 22 | 23 | 24 | def test_provision_finetune(runner, mock_client): 25 | """Test provision fine-tuning command.""" 26 | result = runner.invoke(app, ["fine-tuning", "provision", "test-model"]) 27 | assert result.exit_code == 0 28 | assert "provisioned" in result.stdout 29 | 30 | 31 | def test_get_finetune(runner, mock_client): 32 
| """Test get fine-tuning command.""" 33 | result = runner.invoke(app, ["fine-tuning", "get", "job1"]) 34 | assert result.exit_code == 0 35 | assert "running" in result.stdout 36 | 37 | 38 | def test_cancel_finetune(runner, mock_client): 39 | """Test cancel fine-tuning command.""" 40 | result = runner.invoke(app, ["fine-tuning", "cancel", "job1"]) 41 | assert result.exit_code == 0 42 | 43 | 44 | def test_status_finetune(runner, mock_client): 45 | """Test status fine-tuning command.""" 46 | result = runner.invoke(app, ["fine-tuning", "status", "job1"]) 47 | assert result.exit_code == 0 48 | assert "running" in result.stdout 49 | -------------------------------------------------------------------------------- /tests/common/test_dependencies.py: -------------------------------------------------------------------------------- 1 | """Tests for verifying correct installation of optional dependencies.""" 2 | 3 | import pytest 4 | 5 | 6 | @pytest.mark.skip(reason="Temporarily skipped as requested") 7 | def test_base_dependencies(): 8 | """Verify base installation has no optional dependencies.""" 9 | with pytest.raises(ImportError): 10 | import cv2 # noqa: F401 11 | 12 | with pytest.raises(ImportError): 13 | import pypdfium2 # noqa: F401 14 | 15 | 16 | @pytest.mark.skip(reason="Temporarily skipped as requested") 17 | def test_video_dependencies(): 18 | """Verify video dependencies are available.""" 19 | import cv2 # noqa: F401 20 | import numpy as np # noqa: F401 21 | 22 | # Verify we can import and get versions 23 | assert cv2.__version__, "cv2 version should be available" 24 | assert np.__version__, "numpy version should be available" 25 | 26 | 27 | @pytest.mark.skip(reason="Temporarily skipped as requested") 28 | def test_doc_dependencies(): 29 | """Verify doc dependencies are available.""" 30 | import pypdfium2 # noqa: F401 31 | 32 | # Verify we can import and get version 33 | assert pypdfium2.__version__, "pypdfium2 version should be available" 34 | 35 | 36 | 
@pytest.mark.skip(reason="Temporarily skipped as requested")
def test_all_dependencies():
    """Verify every optional dependency is importable and reports a version."""
    import cv2  # noqa: F401
    import numpy as np  # noqa: F401
    import pypdfium2  # noqa: F401

    # Each optional package must expose a non-empty version string.
    for module, label in ((cv2, "cv2"), (np, "numpy"), (pypdfium2, "pypdfium2")):
        assert module.__version__, f"{label} version should be available"
["openai>=1.0.0"] 36 | video = [ 37 | "numpy>=1.24.0", 38 | ] 39 | doc = [ 40 | "pypdfium2>=4.30.0" 41 | ] 42 | all = [ 43 | "numpy>=1.24.0", 44 | "pypdfium2>=4.30.0", 45 | "openai>=1.0.0", 46 | ] 47 | 48 | [tool.setuptools.dynamic] 49 | version = {attr = "vlmrun.version.__version__"} 50 | dependencies = {file = ["requirements/requirements.txt"]} 51 | 52 | [tool.setuptools.packages.find] 53 | include = ["vlmrun*"] 54 | namespaces = true 55 | 56 | [tool.pytest.ini_options] 57 | testpaths = ["tests"] 58 | python_files = "test_*.py" 59 | addopts = "-v --tb=short" 60 | 61 | [project.scripts] 62 | vlmrun = "vlmrun.cli.cli:app" 63 | -------------------------------------------------------------------------------- /vlmrun/cli/_cli/models.py: -------------------------------------------------------------------------------- 1 | """Models API commands.""" 2 | 3 | from typing import List 4 | 5 | import typer 6 | from rich.table import Table 7 | from rich.console import Console 8 | 9 | from vlmrun.client import VLMRun 10 | from vlmrun.client.types import ModelInfo 11 | 12 | app = typer.Typer( 13 | help="Model operations", 14 | add_completion=False, 15 | no_args_is_help=True, 16 | ) 17 | 18 | 19 | @app.command() 20 | def list( 21 | ctx: typer.Context, 22 | domain: str = typer.Option( 23 | None, help="Filter domains (e.g. 
def verify_webhook(
    raw_body: Union[str, bytes], signature_header: Optional[str], secret: str
) -> bool:
    """
    Verify webhook HMAC signature.

    Confirms a request genuinely originated from VLM Run by recomputing the
    SHA256 HMAC of the raw body with the shared webhook secret and comparing
    it, in constant time, against the value carried in the
    X-VLMRun-Signature header.

    Args:
        raw_body: Raw request body as string or bytes
        signature_header: X-VLMRun-Signature header value (format: "sha256=<hex>")
        secret: Your webhook secret from VLM Run dashboard

    Returns:
        True if the signature is valid, False otherwise

    Example:
        ```python
        raw_body = await request.body()
        header = request.headers.get("X-VLMRun-Signature")
        if not verify_webhook(raw_body, header, secret):
            raise HTTPException(status_code=401, detail="Invalid signature")
        ```
    """
    prefix = "sha256="

    # A missing secret or a missing/malformed header can never validate.
    if not secret or not signature_header:
        return False
    if not signature_header.startswith(prefix):
        return False

    claimed_digest = signature_header[len(prefix):]

    # Normalize the body to bytes before hashing.
    payload = raw_body.encode("utf-8") if isinstance(raw_body, str) else raw_body

    computed_digest = hmac.new(
        secret.encode("utf-8"), payload, hashlib.sha256
    ).hexdigest()

    # Constant-time comparison guards against timing side channels.
    return hmac.compare_digest(claimed_digest, computed_digest)
def test_upload_file(mock_client, tmp_path):
    """Uploading a file returns its FileResponse metadata."""
    source = tmp_path / "test.txt"
    source.write_text("test content")

    result = mock_client.files.upload(source)
    assert isinstance(result, FileResponse)
    assert result.id == "file1"
    assert result.filename == str(source)


def test_get_file(mock_client):
    """Fetching metadata by id returns the expected FileResponse."""
    result = mock_client.files.get("file1")
    assert isinstance(result, FileResponse)
    assert result.id == "file1"
    assert result.filename == "test.txt"
    assert result.bytes == 10


def test_delete_file(mock_client):
    """Deleting a file echoes back its FileResponse."""
    result = mock_client.files.delete("file1")
    assert isinstance(result, FileResponse)
    assert result.id == "file1"


def test_generate_presigned_url(mock_client):
    """Generating a presigned URL returns id, filename and a storage URL."""
    params = PresignedUrlRequest(filename="test.pdf", purpose="assistants")
    result = mock_client.files.generate_presigned_url(params)
    assert isinstance(result, PresignedUrlResponse)
    assert result.id == "presigned1"
    assert result.filename == "test.pdf"
    assert result.url is not None
    assert "storage.example.com" in result.url


def test_generate_presigned_url_without_purpose(mock_client):
    """Purpose is optional when requesting a presigned URL."""
    params = PresignedUrlRequest(filename="document.pdf")
    result = mock_client.files.generate_presigned_url(params)
    assert isinstance(result, PresignedUrlResponse)
    assert result.filename == "document.pdf"
"""Pydantic utilities for VLMRun."""

from datetime import date, datetime, time, timedelta
from typing import Any, List, Type, Union, get_args, get_origin

from pydantic import BaseModel, create_model


def patch_response_format(response_format: Type[BaseModel]) -> Type[BaseModel]:
    """Patch a Pydantic response format for OpenAI structured outputs.

    OpenAI does not support the following field types:
    - date
    - datetime
    - time
    - timedelta

    This function returns a new model (named "<Model>_patched") in which every
    such field -- including ones nested inside Unions, Lists and sub-models --
    is re-declared as ``str``. Callers are expected to convert the string
    values back to the original types after parsing.

    Args:
        response_format: The Pydantic model class to patch.

    Returns:
        A newly created Pydantic model class with unsupported annotations
        replaced by ``str``.
    """
    # Types OpenAI cannot represent; they are declared as strings instead.
    _UNSUPPORTED = (date, datetime, time, timedelta)

    def patch_pydantic_field_annotation(annotation: Any) -> Any:
        """Recursively convert unsupported types in an annotation to ``str``."""
        if annotation in _UNSUPPORTED:
            return str
        origin = get_origin(annotation)
        if origin is Union:
            # Patch every member of the Union. (The previous implementation
            # silently dropped members beyond the first two.)
            args = get_args(annotation)
            if not args:
                return str
            return Union[tuple(patch_pydantic_field_annotation(a) for a in args)]
        elif origin is list:
            # BUG FIX: get_origin(List[x]) returns the builtin ``list``; the
            # old comparison against typing.List never matched, so List[date]
            # fields were left unpatched.
            args = get_args(annotation)
            if not args:
                # Bare ``list``/``List`` carries no element type to patch.
                return annotation
            return List[patch_pydantic_field_annotation(args[0])]
        elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
            return patch_pydantic_model(annotation)
        return annotation

    def patch_pydantic_model(model: Type[BaseModel]) -> Type[BaseModel]:
        """Create a sibling model whose field annotations are patched."""
        new_fields = {
            field_name: (patch_pydantic_field_annotation(field.annotation), field)
            for field_name, field in model.model_fields.items()
        }
        return create_model(
            f"{model.__name__}_patched", __base__=BaseModel, **new_fields
        )

    return patch_pydantic_model(response_format)
vlmrun hub list 78 | ``` 79 | 80 | View schema for a domain 81 | ```bash 82 | vlmrun hub schema document.invoice 83 | ``` 84 | 85 | 86 | ### Full Command Reference 87 | 88 | Here are the main command groups available: 89 | 90 | - `vlmrun config` - Manage configuration settings 91 | - `vlmrun datasets` - Dataset operations 92 | - `vlmrun files` - File management 93 | - `vlmrun fine-tuning` - Model fine-tuning operations 94 | - `vlmrun generate` - Generate predictions 95 | - `vlmrun hub` - Access domain schemas and information 96 | - `vlmrun models` - Model operations 97 | - `vlmrun predictions` - Manage and monitor predictions 98 | 99 | For detailed help on any command, use the `--help` flag: 100 | 101 | ```bash 102 | vlmrun --help 103 | vlmrun --help 104 | ``` 105 | -------------------------------------------------------------------------------- /vlmrun/cli/_cli/files.py: -------------------------------------------------------------------------------- 1 | """Files API commands.""" 2 | 3 | from pathlib import Path 4 | from typing import Optional, List 5 | 6 | import typer 7 | from rich.table import Table 8 | from rich.console import Console 9 | from rich import print as rprint 10 | from vlmrun.client import VLMRun 11 | from vlmrun.client.types import FileResponse 12 | 13 | app = typer.Typer( 14 | help="File operations", 15 | add_completion=False, 16 | no_args_is_help=True, 17 | ) 18 | 19 | 20 | @app.command() 21 | def list(ctx: typer.Context) -> None: 22 | """List all files.""" 23 | client: VLMRun = ctx.obj 24 | 25 | files: List[FileResponse] = client.files.list() 26 | console = Console() 27 | table = Table(show_header=True) 28 | table.add_column("id", min_width=40) 29 | table.add_column("filename") 30 | table.add_column("bytes") 31 | table.add_column("created_at") 32 | table.add_column("purpose") 33 | for file in files: 34 | table.add_row( 35 | file.id, 36 | file.filename, 37 | f"{file.bytes / 1024 / 1024:.2f} MB", 38 | file.created_at.strftime("%Y-%m-%d 
%H:%M:%S"), 39 | str(file.purpose), 40 | ) 41 | 42 | console.print(table) 43 | 44 | 45 | @app.command() 46 | def upload( 47 | ctx: typer.Context, 48 | file: Path = typer.Argument(..., help="File to upload", exists=True, readable=True), 49 | purpose: str = typer.Option( 50 | "fine-tune", 51 | help="Purpose of the file (one of: datasets, fine-tune, assistants, assistants_output, batch, batch_output, vision)", 52 | ), 53 | ) -> None: 54 | """Upload a file.""" 55 | client: VLMRun = ctx.obj 56 | result = client.files.upload(str(file), purpose=purpose) 57 | rprint(f"Uploaded file {result.filename} with ID: {result.id}") 58 | 59 | 60 | @app.command() 61 | def delete( 62 | ctx: typer.Context, 63 | file_id: str = typer.Argument(..., help="ID of the file to delete"), 64 | ) -> None: 65 | """Delete a file.""" 66 | client: VLMRun = ctx.obj 67 | client.files.delete(file_id) 68 | rprint(f"Deleted file {file_id}") 69 | 70 | 71 | @app.command() 72 | def get( 73 | ctx: typer.Context, 74 | file_id: str = typer.Argument(..., help="ID of the file to retrieve"), 75 | output: Optional[Path] = typer.Option(None, help="Output file path"), 76 | ) -> None: 77 | """Get file content.""" 78 | client: VLMRun = ctx.obj 79 | content = client.files.get_content(file_id) 80 | if output: 81 | output.write_bytes(content) 82 | rprint(f"File content written to {output}") 83 | else: 84 | rprint(content.decode()) 85 | -------------------------------------------------------------------------------- /vlmrun/client/feedback.py: -------------------------------------------------------------------------------- 1 | """VLM Run API Feedback resource.""" 2 | 3 | from typing import Optional, Dict, Any 4 | 5 | from vlmrun.client.base_requestor import APIRequestor 6 | from vlmrun.types.abstract import VLMRunProtocol 7 | from vlmrun.client.types import ( 8 | FeedbackSubmitRequest, 9 | FeedbackListResponse, 10 | FeedbackSubmitResponse, 11 | ) 12 | 13 | 14 | class Feedback: 15 | """Feedback resource for VLM Run API.""" 
16 | 17 | def __init__(self, client: "VLMRunProtocol") -> None: 18 | """Initialize Feedback resource with VLMRun instance. 19 | 20 | Args: 21 | client: VLM Run API instance 22 | """ 23 | self._client = client 24 | self._requestor = APIRequestor(client) 25 | 26 | def get( 27 | self, 28 | request_id: str, 29 | offset: int = 0, 30 | limit: int = 10, 31 | ) -> FeedbackListResponse: 32 | """Get feedback for a prediction request. 33 | 34 | Args: 35 | request_id: ID of the prediction request 36 | offset: Number of feedback items to skip 37 | limit: Maximum number of feedback items to return 38 | 39 | Returns: 40 | FeedbackListResponse: Response with list of feedback items 41 | """ 42 | response, status_code, headers = self._requestor.request( 43 | method="GET", 44 | url=f"feedback/{request_id}", 45 | params={"offset": offset, "limit": limit}, 46 | ) 47 | return FeedbackListResponse(**response) 48 | 49 | def submit( 50 | self, 51 | request_id: str, 52 | response: Optional[Dict[str, Any]] = None, 53 | notes: Optional[str] = None, 54 | ) -> FeedbackSubmitResponse: 55 | """Submit feedback for a prediction. 
56 | 57 | Args: 58 | request_id: ID of the prediction request 59 | response: Optional feedback response data 60 | notes: Optional notes about the feedback 61 | 62 | Returns: 63 | FeedbackSubmitResponse: Response with submitted feedback 64 | """ 65 | if response is None and notes is None: 66 | raise ValueError( 67 | "`response` or `notes` parameter is required and cannot be None" 68 | ) 69 | 70 | feedback_data = FeedbackSubmitRequest( 71 | request_id=request_id, response=response, notes=notes 72 | ) 73 | 74 | response_data, status_code, headers = self._requestor.request( 75 | method="POST", 76 | url="feedback/submit", 77 | data=feedback_data.model_dump(exclude_none=True), 78 | ) 79 | return FeedbackSubmitResponse(**response_data) 80 | -------------------------------------------------------------------------------- /tests/cli/test_cli_predictions.py: -------------------------------------------------------------------------------- 1 | """Test predictions subcommand.""" 2 | 3 | from vlmrun.cli.cli import app 4 | 5 | 6 | def test_list_predictions(runner, mock_client, config_file): 7 | """Test list predictions command.""" 8 | result = runner.invoke(app, ["predictions", "list"]) 9 | assert result.exit_code == 0 10 | assert "prediction1" in result.stdout 11 | assert "running" in result.stdout 12 | assert "2024-01-01" in result.stdout 13 | 14 | 15 | def test_list_predictions_with_status_filter(runner, mock_client, config_file): 16 | """Test list predictions with status filter.""" 17 | # Test with running status (should match mock data) 18 | result = runner.invoke(app, ["predictions", "list", "--status", "running"]) 19 | assert result.exit_code == 0 20 | assert "prediction1" in result.stdout 21 | assert "running" in result.stdout 22 | 23 | result = runner.invoke(app, ["predictions", "list", "--status", "completed"]) 24 | assert result.exit_code == 0 25 | assert "No predictions found" in result.stdout or result.stdout.strip() == "" 26 | 27 | 28 | def 
test_get_prediction(runner, mock_client, config_file): 29 | """Test get prediction command.""" 30 | result = runner.invoke(app, ["predictions", "get", "prediction1"]) 31 | assert result.exit_code == 0 32 | assert "prediction1" in result.stdout 33 | assert "Status" in result.stdout 34 | assert "running" in result.stdout 35 | assert "Created" in result.stdout 36 | assert "2024-01-01" in result.stdout 37 | 38 | 39 | def test_get_prediction_with_wait(runner, mock_client, config_file): 40 | """Test get prediction command with wait flag.""" 41 | result = runner.invoke( 42 | app, ["predictions", "get", "prediction1", "--wait", "--timeout", "5"] 43 | ) 44 | assert result.exit_code == 0 45 | assert "prediction1" in result.stdout 46 | assert "completed" in result.stdout 47 | assert "Credits Used: 100" in result.stdout 48 | assert "Response" in result.stdout 49 | assert "test" in result.stdout 50 | 51 | 52 | def test_get_prediction_usage_display(runner, mock_client, config_file): 53 | """Test that prediction usage information is displayed correctly.""" 54 | result = runner.invoke(app, ["predictions", "get", "prediction1", "--wait"]) 55 | assert result.exit_code == 0 56 | assert "Credits Used" in result.stdout 57 | assert "100" in result.stdout 58 | 59 | 60 | def test_list_predictions_table_format(runner, mock_client, config_file): 61 | """Test that list output is formatted correctly.""" 62 | result = runner.invoke(app, ["predictions", "list"]) 63 | assert result.exit_code == 0 64 | assert "id" in result.stdout 65 | assert "reated_at" in result.stdout 66 | assert "status" in result.stdout 67 | -------------------------------------------------------------------------------- /vlmrun/client/executions.py: -------------------------------------------------------------------------------- 1 | """VLM Run API Executions resource.""" 2 | 3 | from __future__ import annotations 4 | 5 | import time 6 | from vlmrun.client.base_requestor import APIRequestor 7 | from vlmrun.types.abstract 
class Executions:
    """Executions resource for VLM Run API."""

    def __init__(self, client: "VLMRunProtocol") -> None:
        """Initialize the Executions resource.

        Args:
            client: VLM Run API instance used for all requests.
        """
        self._client = client
        # Agent requests can be slow; use a generous per-request timeout.
        self._requestor = APIRequestor(client, timeout=120)

    def list(self, skip: int = 0, limit: int = 10) -> list[AgentExecutionResponse]:
        """List executions.

        Args:
            skip: Number of items to skip
            limit: Maximum number of items to return

        Returns:
            list[AgentExecutionResponse]: Execution objects
        """
        payload, _, _ = self._requestor.request(
            method="GET",
            url="agent/executions",
            params={"skip": skip, "limit": limit},
        )
        return [AgentExecutionResponse(**item) for item in payload]

    def get(self, id: str) -> AgentExecutionResponse:
        """Get execution by ID.

        Args:
            id: ID of execution to retrieve

        Returns:
            AgentExecutionResponse: Execution metadata

        Raises:
            TypeError: If the API returns a non-dict payload.
        """
        payload, _, _ = self._requestor.request(
            method="GET",
            url=f"agent/executions/{id}",
        )
        if not isinstance(payload, dict):
            raise TypeError("Expected dict response")
        return AgentExecutionResponse(**payload)

    def wait(
        self, id: str, timeout: int = 300, sleep: int = 5
    ) -> AgentExecutionResponse:
        """Poll an execution until it reports completion.

        Args:
            id: ID of execution to wait for
            timeout: Maximum number of seconds to wait
            sleep: Time to wait between checks in seconds (default: 5)

        Returns:
            AgentExecutionResponse: Completed execution

        Raises:
            TimeoutError: If execution does not complete within timeout
        """
        deadline = time.time() + timeout
        while True:
            execution: AgentExecutionResponse = self.get(id)
            if execution.status == "completed":
                return execution

            remaining = deadline - time.time()
            if remaining <= 0:
                raise TimeoutError(
                    f"Execution {id} did not complete within {timeout} seconds. Last status: {execution.status}"
                )
            # Never oversleep past the deadline.
            time.sleep(min(sleep, remaining))
def test_unset_config(runner, config_file):
    """Unsetting one key leaves the remaining keys intact."""
    runner.invoke(app, ["config", "set", "--api-key", "test-key"])
    runner.invoke(app, ["config", "set", "--base-url", "https://test.vlm.run"])

    outcome = runner.invoke(app, ["config", "unset", "--api-key"])
    assert outcome.exit_code == 0
    assert "Removed API key" in outcome.stdout

    outcome = runner.invoke(app, ["config", "show"])
    assert outcome.exit_code == 0
    assert "api_key" not in outcome.stdout
    assert "base_url: https://test.vlm.run" in outcome.stdout


def test_set_no_values(runner, config_file):
    """`config set` without options reports there is nothing to set."""
    outcome = runner.invoke(app, ["config", "set"])
    assert outcome.exit_code == 0
    assert "No values provided to set" in outcome.stdout


def test_unset_no_values(runner, config_file):
    """`config unset` without options reports there is nothing to unset."""
    outcome = runner.invoke(app, ["config", "unset"])
    assert outcome.exit_code == 0
    assert "No values specified to unset" in outcome.stdout


def test_config_file_permissions(runner, tmp_path, monkeypatch):
    """A read-only config file surfaces an error and exits with code 1."""
    config_dir = tmp_path / ".vlmrun"
    config_dir.mkdir()
    config_path = config_dir / "config.toml"
    # 0o444: read-only for everyone, so writes must fail.
    config_path.touch(mode=0o444)
    monkeypatch.setattr("vlmrun.cli._cli.config.CONFIG_FILE", config_path)

    outcome = runner.invoke(app, ["config", "set", "--api-key", "test-key"])
    assert outcome.exit_code == 1
    assert "Error" in outcome.stdout
"""Hub API commands."""

import json
from typing import List

import typer
from rich.console import Console
from rich.panel import Panel
from rich.syntax import Syntax
from rich.table import Table

from vlmrun.client import VLMRun
from vlmrun.client.types import HubDomainInfo, HubSchemaResponse

app = typer.Typer(
    help="Hub operations",
    add_completion=False,
    no_args_is_help=True,
)


@app.command()
def version(ctx: typer.Context) -> None:
    """Get hub version."""
    client: VLMRun = ctx.obj
    info = client.hub.info()
    console = Console()
    console.print(f"Hub version: {info.version}", style="white")
    console.print(
        "\nVisit https://github.com/vlm-run/vlmrun-hub for more information",
        style="white",
    )


@app.command("list")
def list_domains(
    ctx: typer.Context,
    domain: str = typer.Option(
        None, help="Filter domains (e.g. 'document' or 'document.invoice')"
    ),
) -> None:
    """List hub domains, grouped by top-level category."""
    client: VLMRun = ctx.obj
    domains: List[HubDomainInfo] = client.hub.list_domains()

    # Optional substring filter on the fully-qualified domain name.
    if domain:
        domains = [d for d in domains if domain in d.domain]

    console = Console()
    table = Table(
        show_header=True,
        header_style="white",
        title="Available Domains",
        title_style="white",
        border_style="white",
        expand=True,
    )

    table.add_column("Category")
    table.add_column("Domain")

    # Group by top-level category (the text before the first dot).
    # FIX: the previous loop reused the name `domain`, shadowing the filter
    # option above; it also re-sorted the category list on every iteration.
    domain_groups: dict = {}
    for info in domains:
        category = info.domain.split(".")[0]
        domain_groups.setdefault(category, []).append(info)

    categories = sorted(domain_groups)
    last_category = categories[-1] if categories else None
    for category in categories:
        first_in_category = True
        for info in sorted(domain_groups[category], key=lambda x: x.domain):
            # Show the category label only on the first row of each group.
            if first_in_category:
                table.add_row(category, info.domain, style="white")
                first_in_category = False
            else:
                table.add_row("", info.domain, style="white")
        # Blank separator row between categories (not after the last one).
        if category != last_category:
            table.add_row("", "", style="white")

    console.print(table)


@app.command()
def schema(
    ctx: typer.Context,
    domain: str = typer.Argument(
        ..., help="Domain identifier (e.g. 'document.invoice')"
    ),
) -> None:
    """Get JSON schema for a domain."""
    client: VLMRun = ctx.obj
    response: HubSchemaResponse = client.hub.get_schema(domain)

    console = Console()

    console.print("\nSchema Information:", style="white")
    meta_table = Table(show_header=False, box=None)
    meta_table.add_row("Domain:", domain, style="white")
    meta_table.add_row("Version:", response.schema_version, style="white")
    console.print(meta_table)

    console.print("\nJSON Schema:", style="white")
    json_str = json.dumps(response.json_schema, indent=2)
    syntax = Syntax(json_str, "json", theme="default", line_numbers=True)
    console.print(Panel(syntax, expand=False, border_style="white"))
sample_image(tmp_path) -> Image.Image: 37 | """Create a sample image for testing.""" 38 | img = Image.new("RGB", (100, 100), color="red") 39 | img_path = tmp_path / "test_image.png" 40 | img.save(img_path) 41 | return img 42 | 43 | 44 | @pytest.fixture 45 | def sample_image_path(sample_image, tmp_path) -> Path: 46 | """Save sample image to a file and return the path.""" 47 | path = tmp_path / "test_image.png" 48 | sample_image.save(path) 49 | return path 50 | 51 | 52 | def test_encode_image_from_pil(sample_image): 53 | """Test encoding a PIL Image.""" 54 | # Test PNG encoding 55 | png_data = encode_image(sample_image, format="PNG") 56 | _validate_base64_image(png_data, "PNG") 57 | 58 | # Test JPEG encoding 59 | jpeg_data = encode_image(sample_image, format="JPEG") 60 | _validate_base64_image(jpeg_data, "JPEG") 61 | 62 | # Test binary format 63 | binary_data = encode_image(sample_image, format="binary") 64 | assert isinstance(binary_data, bytes) 65 | assert Image.open(BytesIO(binary_data)).mode == "RGB" 66 | 67 | 68 | def test_encode_image_from_path(sample_image_path): 69 | """Test encoding an image from a file path.""" 70 | # Test with string path 71 | str_path_data = encode_image(str(sample_image_path)) 72 | _validate_base64_image(str_path_data, "PNG") 73 | 74 | # Test with Path object 75 | path_data = encode_image(sample_image_path) 76 | _validate_base64_image(path_data, "PNG") 77 | 78 | # Verify both methods produce same result 79 | assert str_path_data == path_data 80 | 81 | 82 | def test_encode_image_invalid(): 83 | """Test encoding with invalid inputs.""" 84 | with pytest.raises(FileNotFoundError): 85 | encode_image("nonexistent.jpg") # Invalid path 86 | 87 | 88 | def test_download_image(): 89 | """Test downloading an image from URL.""" 90 | # Use a reliable test image URL 91 | url = "https://raw.githubusercontent.com/python-pillow/Pillow/main/Tests/images/hopper.png" 92 | 93 | # Download and verify image 94 | img = download_image(url) 95 | assert 
isinstance(img, Image.Image) 96 | assert img.mode == "RGB" 97 | assert img.size[0] > 0 and img.size[1] > 0 98 | 99 | 100 | def test_download_image_invalid(): 101 | """Test downloading with invalid URL.""" 102 | with pytest.raises(requests.exceptions.RequestException): 103 | download_image("https://nonexistent.example.com/image.jpg") 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /vlmrun/cli/_cli/fine_tuning.py: -------------------------------------------------------------------------------- 1 | """Fine-tuning API commands.""" 2 | 3 | import typer 4 | from typing import List 5 | from rich.table import Table 6 | from rich.console import Console 7 | from rich import print as rprint 8 | 9 | from vlmrun.client import VLMRun 10 | from vlmrun.client.types import FinetuningResponse, FinetuningProvisionResponse 11 | 12 | app = typer.Typer( 13 | help="Fine-tuning operations", 14 | add_completion=False, 15 | no_args_is_help=True, 16 | ) 17 | 18 | 19 | @app.command() 20 | def create( 21 | ctx: typer.Context, 22 | model: str = typer.Argument(..., help="Base model name"), 23 | training_file: str = typer.Argument(..., help="Training file ID or URL"), 24 | validation_file: str = typer.Option(None, help="Validation file ID or URL"), 25 | num_epochs: int = typer.Option(1, help="Number of epochs"), 26 | batch_size: int = typer.Option(1, help="Batch size"), 27 | learning_rate: float = typer.Option(2e-4, help="Learning rate"), 28 | suffix: str = typer.Option(None, help="Suffix for the fine-tuned model"), 29 | wandb_project_name: str = typer.Option(None, help="Weights & Biases project name"), 30 | wandb_api_key: str = typer.Option(None, help="Weights & Biases API key"), 31 | wandb_base_url: str = typer.Option(None, help="Weights & Biases base URL"), 32 | ) -> None: 33 | """Create a fine-tuning job.""" 34 | client: VLMRun = ctx.obj 35 | result: FinetuningResponse = client.fine_tuning.create( 36 | model=model, 37 | suffix=suffix, 38 | training_file=training_file, 39 | validation_file=validation_file, 40 | num_epochs=num_epochs, 41 | batch_size=batch_size, 42 | learning_rate=learning_rate, 43 | wandb_project_name=wandb_project_name, 44 | wandb_api_key=wandb_api_key, 45 | wandb_base_url=wandb_base_url, 46 | ) 47 | rprint(f"Created fine-tuning job with ID: 
{result.id}") 48 | 49 | 50 | @app.command() 51 | def list(ctx: typer.Context) -> None: 52 | """List all fine-tuning jobs.""" 53 | client: VLMRun = ctx.obj 54 | jobs: List[FinetuningResponse] = client.fine_tuning.list() 55 | console = Console() 56 | table = Table(show_header=True, header_style="bold") 57 | table.add_column("id", min_width=40) 58 | table.add_column("model") 59 | table.add_column("status") 60 | table.add_column("created_at") 61 | table.add_column("completed_at") 62 | table.add_column("wandb_url") 63 | for job in jobs: 64 | table.add_row( 65 | job.id, 66 | job.model, 67 | job.status, 68 | job.created_at.strftime("%Y-%m-%d %H:%M:%S"), 69 | job.completed_at.strftime("%Y-%m-%d %H:%M:%S") if job.completed_at else "", 70 | job.wandb_url, 71 | ) 72 | 73 | console.print(table) 74 | 75 | 76 | @app.command() 77 | def provision( 78 | ctx: typer.Context, 79 | model: str = typer.Argument(..., help="Model to provision"), 80 | duration: int = typer.Option(10 * 60, help="Duration for the provisioned model"), 81 | concurrency: int = typer.Option(1, help="Concurrency for the provisioned model"), 82 | ) -> None: 83 | """Provision a fine-tuning model.""" 84 | client: VLMRun = ctx.obj 85 | result: FinetuningProvisionResponse = client.fine_tuning.provision( 86 | model=model, duration=duration, concurrency=concurrency 87 | ) 88 | rprint(f"Provisioned fine-tuning model\n{result}") 89 | 90 | 91 | @app.command() 92 | def get( 93 | ctx: typer.Context, 94 | job_id: str = typer.Argument(..., help="ID of the fine-tuning job"), 95 | ) -> None: 96 | """Get fine-tuning job details.""" 97 | client: VLMRun = ctx.obj 98 | job: FinetuningResponse = client.fine_tuning.get(job_id) 99 | rprint(job) 100 | 101 | 102 | @app.command() 103 | def cancel( 104 | ctx: typer.Context, 105 | job_id: str = typer.Argument(..., help="ID of the fine-tuning job to cancel"), 106 | ) -> None: 107 | """Cancel a fine-tuning job.""" 108 | client: VLMRun = ctx.obj 109 | client.fine_tuning.cancel(job_id) 110 | 
rprint(f"Cancelled fine-tuning job {job_id}") 111 | -------------------------------------------------------------------------------- /vlmrun/cli/_cli/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | import sys 4 | import typer 5 | from rich import print as rprint 6 | from pydantic import Field 7 | from pydantic.dataclasses import dataclass 8 | from dataclasses import asdict 9 | 10 | if sys.version_info >= (3, 11): 11 | import tomllib 12 | else: 13 | import tomli as tomllib 14 | 15 | CONFIG_FILE = Path.home() / ".vlmrun" / "config.toml" 16 | 17 | 18 | @dataclass 19 | class Config: 20 | """VLM Run configuration.""" 21 | 22 | api_key: Optional[str] = Field(default=None, description="VLM Run API key") 23 | base_url: Optional[str] = Field(default=None, description="VLM Run API base URL") 24 | 25 | 26 | def get_config() -> Config: 27 | """Load configuration from ~/.vlmrun/config.toml""" 28 | if not CONFIG_FILE.exists(): 29 | return Config() 30 | 31 | try: 32 | with CONFIG_FILE.open("rb") as f: 33 | data = tomllib.load(f) 34 | return Config(**data) 35 | except (PermissionError, OSError) as e: 36 | rprint(f"[red]Error reading config file:[/] {e}") 37 | return Config() 38 | except tomllib.TOMLDecodeError: 39 | rprint("[red]Error:[/] Invalid TOML format in config file") 40 | return Config() 41 | 42 | 43 | def save_config(config: Config) -> None: 44 | """Save configuration to ~/.vlmrun/config.toml""" 45 | config_dict = {k: v for k, v in asdict(config).items() if v is not None} 46 | 47 | try: 48 | CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True) 49 | 50 | with CONFIG_FILE.open("w") as f: 51 | for key, value in config_dict.items(): 52 | escaped_value = str(value).replace('"', '\\"') 53 | f.write(f'{key} = "{escaped_value}"\n') 54 | except (PermissionError, OSError) as e: 55 | rprint(f"[red]Error saving config file:[/] {e}") 56 | raise typer.Exit(1) 57 | 58 | 59 | app 
= typer.Typer( 60 | help="Configuration operations", 61 | add_completion=False, 62 | no_args_is_help=True, 63 | ) 64 | 65 | 66 | @app.command(name="show") 67 | def show_config() -> None: 68 | """Show current configuration.""" 69 | config = get_config() 70 | 71 | if not any(getattr(config, field) for field in asdict(config).keys()): 72 | rprint("[yellow]No configuration values set[/]") 73 | return 74 | 75 | display_config = asdict(config) 76 | if config.api_key: 77 | key = config.api_key 78 | display_config["api_key"] = f"{key[:4]}...{key[-4:]}" 79 | 80 | for key, value in display_config.items(): 81 | if value is not None: 82 | rprint(f"{key}: {value}") 83 | 84 | 85 | @app.command(name="set") 86 | def set_value( 87 | api_key: Optional[str] = typer.Option(None, "--api-key", help="VLM Run API key"), 88 | base_url: Optional[str] = typer.Option( 89 | None, "--base-url", help="VLM Run API base URL" 90 | ), 91 | ) -> None: 92 | """Set configuration values.""" 93 | config = get_config() 94 | 95 | if api_key: 96 | config.api_key = api_key 97 | rprint("[green]✓[/] Set API key") 98 | 99 | if base_url: 100 | config.base_url = base_url 101 | rprint(f"[green]✓[/] Set base URL: {base_url}") 102 | 103 | if not (api_key or base_url): 104 | rprint("[yellow]No values provided to set[/]") 105 | return 106 | 107 | save_config(config) 108 | 109 | 110 | @app.command() 111 | def unset( 112 | api_key: bool = typer.Option(False, "--api-key", help="Unset API key"), 113 | base_url: bool = typer.Option(False, "--base-url", help="Unset base URL"), 114 | ) -> None: 115 | """Remove configuration values.""" 116 | config = get_config() 117 | 118 | if api_key and config.api_key: 119 | config.api_key = None 120 | rprint("[green]✓[/] Removed API key") 121 | 122 | if base_url and config.base_url: 123 | config.base_url = None 124 | rprint("[green]✓[/] Removed base URL") 125 | 126 | if not (api_key or base_url): 127 | rprint("[yellow]No values specified to unset[/]") 128 | return 129 | 130 | 
save_config(config) 131 | -------------------------------------------------------------------------------- /vlmrun/common/image.py: -------------------------------------------------------------------------------- 1 | """Image utilities for VLMRun.""" 2 | 3 | from base64 import b64encode 4 | from io import BytesIO 5 | from pathlib import Path 6 | from typing import Literal, Union 7 | 8 | from PIL import Image, ImageOps 9 | 10 | from vlmrun.constants import SUPPORTED_VIDEO_FILETYPES 11 | 12 | 13 | def _open_image_with_exif(path: Union[str, Path]) -> Image.Image: 14 | """Open an image and apply EXIF orientation if available. 15 | 16 | Args: 17 | path: Path to the image file 18 | 19 | Returns: 20 | PIL Image with EXIF orientation applied and converted to RGB 21 | """ 22 | image = Image.open(str(path)) 23 | try: 24 | image = ImageOps.exif_transpose(image) 25 | except Exception: 26 | pass 27 | return image.convert("RGB") 28 | 29 | 30 | def encode_video(path: Union[Path, str]) -> str: 31 | """Convert a video file to a base64 string with data URI prefix. 32 | 33 | Args: 34 | path: Path to video file 35 | 36 | Returns: 37 | Base64 encoded string with data URI prefix 38 | 39 | Raises: 40 | FileNotFoundError: If video file doesn't exist 41 | IOError: If file is not a video or too large 42 | """ 43 | if not isinstance(path, Path): 44 | path = Path(path) 45 | if not path.exists(): 46 | raise FileNotFoundError(f"File not found {path}") 47 | if path.suffix.lower() not in SUPPORTED_VIDEO_FILETYPES: 48 | raise IOError(f"File is not a video: {path}") 49 | if path.stat().st_size > 50 * 1024 * 1024: 50 | raise IOError( 51 | f"File is too large, only videos up to 50MB are supported. 
[size={path.stat().st_size/1024/1024:.2f}MB]" 52 | ) 53 | 54 | # Read the video file 55 | with open(path, "rb") as f: 56 | video_bytes = f.read() 57 | video_b64 = b64encode(video_bytes).decode() 58 | # Get video mime type from extension 59 | mime_type = f"video/{path.suffix.lower()[1:]}" 60 | return f"data:{mime_type};base64,{video_b64}" 61 | 62 | 63 | def encode_image( 64 | image: Union[Image.Image, str, Path], 65 | format: Literal["PNG", "JPEG", "binary"] = "PNG", 66 | quality: int = 90, 67 | ) -> Union[str, bytes]: 68 | """Convert an image to a base64 string or binary format. 69 | 70 | Args: 71 | image: PIL Image, path to image, or Path object 72 | format: Output format ("PNG", "JPEG", or "binary") 73 | quality: JPEG quality (1-100, default 90). Only used for JPEG format. 74 | 75 | Returns: 76 | Base64 encoded string or binary bytes 77 | 78 | Raises: 79 | FileNotFoundError: If image path doesn't exist 80 | ValueError: If image type is invalid or quality is out of range 81 | """ 82 | if not (1 <= quality <= 100): 83 | raise ValueError(f"Quality must be between 1 and 100, got {quality}") 84 | 85 | if isinstance(image, (str, Path)): 86 | if not Path(image).exists(): 87 | raise FileNotFoundError(f"File not found {image}") 88 | image = _open_image_with_exif(str(image)) 89 | elif isinstance(image, Image.Image): 90 | image = image.convert("RGB") 91 | else: 92 | raise ValueError(f"Invalid image type: {type(image)}") 93 | 94 | buffered = BytesIO() 95 | if format == "binary": 96 | image.save(buffered, format="PNG") 97 | buffered.seek(0) 98 | return buffered.getvalue() 99 | 100 | image_format = image.format or format 101 | if image_format.upper() not in ["PNG", "JPEG"]: 102 | raise ValueError(f"Unsupported format: {image_format}") 103 | 104 | save_params = { 105 | "format": image_format, 106 | **( 107 | {"quality": quality, "subsampling": 0} 108 | if image_format.upper() == "JPEG" 109 | else {} 110 | ), 111 | } 112 | 113 | try: 114 | image.save(buffered, **save_params) 115 
| img_str = b64encode(buffered.getvalue()).decode() 116 | return f"data:image/{image_format.lower()};base64,{img_str}" 117 | except Exception as e: 118 | raise ValueError(f"Failed to save image in {image_format} format") from e 119 | -------------------------------------------------------------------------------- /vlmrun/client/datasets.py: -------------------------------------------------------------------------------- 1 | """VLM Run API Dataset resource.""" 2 | 3 | from __future__ import annotations 4 | 5 | from loguru import logger 6 | from pathlib import Path 7 | from typing import List, Literal, Optional 8 | from vlmrun.client.base_requestor import APIRequestor 9 | from vlmrun.types.abstract import VLMRunProtocol 10 | from vlmrun.client.types import DatasetResponse, FileResponse 11 | from vlmrun.common.utils import create_archive 12 | 13 | 14 | class Datasets: 15 | """Datasets resource for VLM Run API.""" 16 | 17 | def __init__(self, client: "VLMRunProtocol") -> None: 18 | """Initialize Datasets resource with VLMRun instance. 19 | 20 | Args: 21 | client: VLM Run API instance 22 | """ 23 | self._client = client 24 | self._requestor = APIRequestor(client, base_url=f"{client.base_url}") 25 | 26 | def create( 27 | self, 28 | domain: str, 29 | dataset_directory: Path, 30 | dataset_name: str, 31 | dataset_type: Literal["images", "videos", "documents"], 32 | wandb_base_url: Optional[str] = "https://api.wandb.ai", 33 | wandb_project_name: Optional[str] = None, 34 | wandb_api_key: Optional[str] = None, 35 | ) -> DatasetResponse: 36 | """Create a dataset from an uploaded file. 
37 | 38 | Args: 39 | dataset_directory: Directory of files to create dataset from 40 | domain: Domain for the dataset 41 | dataset_name: Name of the dataset 42 | dataset_type: Type of dataset (images, videos, or documents) 43 | wandb_base_url: Base URL for Weights & Biases 44 | wandb_project_name: Project name for Weights & Biases 45 | wandb_api_key: API key for Weights & Biases 46 | 47 | Returns: 48 | DatasetResponse: Dataset creation response containing dataset_id and file_id 49 | """ 50 | if dataset_type not in ["images", "videos", "documents"]: 51 | raise ValueError("dataset_type must be one of: images, videos, documents") 52 | 53 | # Create tar.gz file 54 | tar_path: Path = create_archive(dataset_directory, dataset_name) 55 | logger.debug( 56 | f"Created tar.gz file [path={tar_path}, size={tar_path.stat().st_size / 1024 / 1024:.2f} MB]" 57 | ) 58 | 59 | # Upload tar.gz file 60 | file_response: FileResponse = self._client.files.upload( 61 | tar_path, purpose="datasets" 62 | ) 63 | logger.debug( 64 | f"Uploaded tar.gz file [path={tar_path}, file_id={file_response.id}, size={file_response.bytes / 1024 / 1024:.2f} MB]" 65 | ) 66 | 67 | # Create dataset 68 | response, status_code, headers = self._requestor.request( 69 | method="POST", 70 | url="datasets/create", 71 | data={ 72 | "file_id": file_response.id, 73 | "domain": domain, 74 | "dataset_name": dataset_name, 75 | "dataset_type": dataset_type, 76 | "wandb_base_url": wandb_base_url, 77 | "wandb_project_name": wandb_project_name, 78 | "wandb_api_key": wandb_api_key, 79 | }, 80 | ) 81 | if not isinstance(response, dict): 82 | raise TypeError("Expected dict response") 83 | return DatasetResponse(**response) 84 | 85 | def get(self, dataset_id: str) -> DatasetResponse: 86 | """Get dataset information by ID. 
87 | 88 | Args: 89 | dataset_id: ID of the dataset to retrieve 90 | 91 | Returns: 92 | DatasetResponse: Dataset information including dataset_id and file_id 93 | """ 94 | response, status_code, headers = self._requestor.request( 95 | method="GET", 96 | url=f"datasets/{dataset_id}", 97 | ) 98 | if not isinstance(response, dict): 99 | raise TypeError("Expected dict response") 100 | return DatasetResponse(**response) 101 | 102 | def list(self, skip: int = 0, limit: int = 10) -> List[DatasetResponse]: 103 | """List all datasets.""" 104 | items, status_code, headers = self._requestor.request( 105 | method="GET", 106 | url="datasets", 107 | params={"skip": skip, "limit": limit}, 108 | ) 109 | if not isinstance(items, list): 110 | raise TypeError("Expected list response") 111 | return [DatasetResponse(**response) for response in items] 112 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Python Publish 2 | 3 | on: 4 | push: 5 | paths: 6 | - "vlmrun/version.py" 7 | branches: 8 | - main 9 | 10 | env: 11 | CACHE_NUMBER: 1 # increase to reset cache manually 12 | 13 | jobs: 14 | test: 15 | name: Test 16 | runs-on: ubuntu-latest 17 | timeout-minutes: 20 18 | environment: dev 19 | steps: 20 | - name: Checkout git repo 21 | uses: actions/checkout@v3 22 | 23 | - uses: actions/setup-python@v5 24 | with: 25 | python-version: "3.10" 26 | 27 | - uses: actions/cache@v4 28 | with: 29 | path: | 30 | ~/.cache/uv 31 | ~/.cache/pip 32 | key: uv-py310-${{ hashFiles('requirements/requirements*.txt') }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('MANIFEST.in') }}-${{ env.CACHE_NUMBER }} 33 | restore-keys: | 34 | uv-py310- 35 | 36 | - name: Install dependencies 37 | run: | 38 | pip install uv 39 | uv pip install --system -e '.[test,all]' 40 | 41 | - name: Quality Check 42 | uses: pre-commit/action@v3.0.1 43 | continue-on-error: 
true 44 | 45 | - name: Run tests 46 | run: | 47 | make test 48 | 49 | publish: 50 | name: Publish 51 | runs-on: ubuntu-latest 52 | timeout-minutes: 20 53 | environment: prod 54 | needs: test 55 | steps: 56 | - name: Checkout git repo 57 | uses: actions/checkout@v3 58 | with: 59 | fetch-depth: 0 60 | token: ${{ secrets.GH_TOKEN }} 61 | - uses: actions/setup-python@v5 62 | with: 63 | python-version: "3.10" 64 | 65 | - uses: actions/cache@v4 66 | with: 67 | path: | 68 | ~/.cache/uv 69 | ~/.cache/pip 70 | key: uv-${{ hashFiles('requirements/requirements*.txt') }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('MANIFEST.in') }}-${{ env.CACHE_NUMBER }} 71 | restore-keys: | 72 | uv- 73 | 74 | - name: Install dependencies 75 | run: | 76 | pip install uv 77 | uv pip install --system -e '.[test,build,all]' 78 | 79 | - name: Bump version 80 | if: success() 81 | run: | 82 | version=$(grep -oP '__version__ = "\K[^"]+' vlmrun/version.py) 83 | echo "Current version: ${version}" 84 | echo "VERSION=${version}" >> $GITHUB_ENV 85 | 86 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 87 | git config --local user.name "github-actions[bot]" 88 | 89 | git tag -a "v${version}" -m "Version ${version}" 90 | git push origin main 91 | git push origin "v${version}" 92 | 93 | - name: Build package 94 | run: | 95 | python -m build 96 | 97 | - name: Publish to PyPI 98 | uses: pypa/gh-action-pypi-publish@release/v1 99 | with: 100 | password: ${{ secrets.PYPI_TOKEN }} 101 | 102 | - name: Notify Slack on Success 103 | if: success() 104 | uses: slackapi/slack-github-action@v2.0.0 105 | with: 106 | webhook: ${{ secrets.SLACK_WEBHOOK_URL }} 107 | webhook-type: incoming-webhook 108 | payload: | 109 | { 110 | "blocks": [ 111 | { 112 | "type": "section", 113 | "text": { 114 | "type": "mrkdwn", 115 | "text": ":white_check_mark: :python: *vlmrun-python-sdk v${{ env.VERSION }} Published to PyPI*\n\n*Version:* `v${{ env.VERSION }}`\n*Commit:* <${{ github.server_url }}/${{ 
github.repository }}/commit/${{ github.sha }}|${{ github.sha }}>" 116 | } 117 | } 118 | ] 119 | } 120 | 121 | - name: Notify Slack on Failure 122 | if: failure() 123 | uses: slackapi/slack-github-action@v2.0.0 124 | with: 125 | webhook: ${{ secrets.SLACK_WEBHOOK_URL }} 126 | webhook-type: incoming-webhook 127 | payload: | 128 | { 129 | "blocks": [ 130 | { 131 | "type": "section", 132 | "text": { 133 | "type": "mrkdwn", 134 | "text": ":x: :python: *vlmrun-python-sdk Publish Failed*\n\n*Tag:* `${{ github.ref_name }}`\n*Workflow:* <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|View Run>" 135 | } 136 | } 137 | ] 138 | } 139 | -------------------------------------------------------------------------------- /vlmrun/cli/cli.py: -------------------------------------------------------------------------------- 1 | """Main CLI implementation for vlmrun.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Optional 6 | import os 7 | 8 | import typer 9 | from rich import print as rprint 10 | from rich.console import Console 11 | from rich.panel import Panel 12 | from rich.text import Text 13 | 14 | from vlmrun.client import VLMRun 15 | from vlmrun.cli._cli.files import app as files_app 16 | from vlmrun.cli._cli.fine_tuning import app as fine_tuning_app 17 | from vlmrun.cli._cli.models import app as models_app 18 | from vlmrun.cli._cli.generate import app as generate_app 19 | from vlmrun.cli._cli.hub import app as hub_app 20 | from vlmrun.cli._cli.datasets import app as dataset_app 21 | from vlmrun.cli._cli.predictions import app as predictions_app 22 | from vlmrun.cli._cli.config import app as config_app, get_config 23 | from vlmrun.constants import DEFAULT_BASE_URL 24 | 25 | app = typer.Typer( 26 | name="vlmrun", 27 | help="CLI for VLM Run (https://app.vlm.run)", 28 | add_completion=True, 29 | no_args_is_help=True, 30 | ) 31 | 32 | console = Console() 33 | 34 | 35 | def version_callback(value: bool) -> None: 36 | 
"""Print version and exit.""" 37 | if value: 38 | from vlmrun import version 39 | 40 | rprint(f"vlmrun version: {version.__version__}") 41 | raise typer.Exit() 42 | 43 | 44 | def check_credentials( 45 | ctx: typer.Context, api_key: Optional[str], base_url: str 46 | ) -> None: 47 | """Check if API key is present and show helpful message if missing.""" 48 | config = get_config() 49 | api_key = api_key or config.api_key 50 | base_url = base_url or config.base_url 51 | 52 | if not api_key: 53 | console.print("\n[red bold]Error:[/] API key not found! 🔑\n") 54 | console.print( 55 | Panel( 56 | Text.from_markup( 57 | "To use the VLM Run CLI, you need to provide an API key. You can either:\n\n" 58 | "1. Set it in your config:\n" 59 | " [green]vlmrun config set --api-key 'your-api-key'[/]\n\n" 60 | "2. Set it as an environment variable:\n" 61 | " [green]export VLMRUN_API_KEY='your-api-key'[/]\n\n" 62 | "3. Pass it directly as an argument:\n" 63 | " [green]vlmrun --api-key 'your-api-key' ...[/]\n\n" 64 | "Get your API key at: [blue]https://app.vlm.run/dashboard[/]" 65 | ), 66 | title="API Key Required", 67 | border_style="red", 68 | ) 69 | ) 70 | raise typer.Exit(1) 71 | 72 | is_default_url = base_url == DEFAULT_BASE_URL and not os.getenv("VLMRUN_BASE_URL") 73 | if is_default_url and not ctx.meta.get("has_shown_base_url_notice"): 74 | console.print( 75 | f"[yellow]Note:[/] Using default API endpoint: [blue]{base_url}[/]\n" 76 | "To use a different endpoint, set [green]VLMRUN_BASE_URL[/] or use [green]--base-url[/]" 77 | ) 78 | ctx.meta["has_shown_base_url_notice"] = True 79 | 80 | 81 | @app.callback() 82 | def main( 83 | ctx: typer.Context, 84 | api_key: Optional[str] = typer.Option( 85 | None, 86 | "--api-key", 87 | envvar="VLMRUN_API_KEY", 88 | help="VLM Run API key. Can also be set via VLMRUN_API_KEY environment variable.", 89 | ), 90 | base_url: Optional[str] = typer.Option( 91 | None, 92 | "--base-url", 93 | envvar="VLMRUN_BASE_URL", 94 | help="VLM Run API base URL. 
"""Predictions API commands."""

from datetime import datetime, timedelta

import typer
from rich import print as rprint
from rich.console import Console
from rich.table import Table

from vlmrun.client import VLMRun

app = typer.Typer(
    help="Prediction operations",
    add_completion=False,
    no_args_is_help=True,
)


def _parse_date(value: str, option: str) -> datetime:
    """Parse a YYYY-MM-DD string for *option*, exiting the CLI on bad input.

    Args:
        value: Raw date string supplied on the command line.
        option: Option name (e.g. "--since") used in the error message.

    Returns:
        datetime at midnight (00:00) of the given date.

    Raises:
        typer.Exit: If *value* is not a valid YYYY-MM-DD date.
    """
    try:
        return datetime.strptime(value, "%Y-%m-%d")
    except ValueError:
        rprint(f"[red]Error:[/] Invalid date format for {option}. Use YYYY-MM-DD")
        raise typer.Exit(1)


@app.command()
def list(
    ctx: typer.Context,
    skip: int = typer.Option(0, help="Skip the first N predictions"),
    limit: int = typer.Option(10, help="Limit the number of predictions to list"),
    status: str = typer.Option(
        None,
        help="Filter by status (enqueued, pending, running, completed, failed, paused)",
    ),
    since: str = typer.Option(None, help="Show predictions since date (YYYY-MM-DD)"),
    until: str = typer.Option(None, help="Show predictions until date (YYYY-MM-DD)"),
) -> None:
    """List predictions.

    Note: the status/date filters are applied client-side, after the server
    has already applied ``skip``/``limit``.
    """
    client: VLMRun = ctx.obj
    predictions = client.predictions.list(skip=skip, limit=limit)

    if status:
        predictions = [p for p in predictions if p.status == status]

    if since:
        since_date = _parse_date(since, "--since")
        predictions = [p for p in predictions if p.created_at >= since_date]

    if until:
        # BUGFIX: strptime yields midnight, so the previous `<= until_date`
        # comparison excluded every prediction created on the --until date
        # itself. Make the bound inclusive by comparing against the start of
        # the following day.
        until_date = _parse_date(until, "--until") + timedelta(days=1)
        predictions = [p for p in predictions if p.created_at < until_date]

    if not predictions:
        rprint("[yellow]No predictions found[/]")
        return

    console = Console()
    table = Table(show_header=True, header_style="white", border_style="white")
    table.add_column("id", min_width=40)
    table.add_column("created_at")
    table.add_column("status")
    table.add_column("usage.elements_processed")
    table.add_column("usage.element_type")
    table.add_column("usage.credits_used")
    for prediction in predictions:
        table.add_row(
            prediction.id,
            prediction.created_at.strftime("%Y-%m-%d %H:%M:%S"),
            str(prediction.status),
            str(prediction.usage.elements_processed),
            str(prediction.usage.element_type),
            str(prediction.usage.credits_used),
            style="white",
        )

    console.print(table)


@app.command()
def get(
    ctx: typer.Context,
    prediction_id: str = typer.Argument(..., help="ID of the prediction to retrieve"),
    wait: bool = typer.Option(False, help="Wait for prediction to complete"),
    timeout: int = typer.Option(60, help="Timeout in seconds when waiting"),
) -> None:
    """Get prediction details with optional wait functionality."""
    client: VLMRun = ctx.obj

    if wait:
        # Block until the prediction reaches a terminal state or the timeout.
        prediction = client.predictions.wait(prediction_id, timeout=timeout)
    else:
        prediction = client.predictions.get(prediction_id)

    console = Console()

    console.print("\nPrediction Details:\n", style="white")

    # Assemble only the fields that are actually present so output stays compact.
    details = {
        "ID": prediction.id,
        "Status": prediction.status,
    }

    if prediction.created_at:
        details["Created"] = prediction.created_at.strftime("%Y-%m-%d %H:%M:%S")
    if prediction.completed_at:
        details["Completed"] = prediction.completed_at.strftime("%Y-%m-%d %H:%M:%S")

    if prediction.usage:
        if prediction.usage.elements_processed:
            details["Elements Processed"] = prediction.usage.elements_processed
        if prediction.usage.element_type:
            details["Element Type"] = prediction.usage.element_type
        if prediction.usage.credits_used:
            details["Credits Used"] = prediction.usage.credits_used

    for key, value in details.items():
        console.print(f"{key}: {value}", style="white")

    if prediction.response:
        console.print("\nResponse:", style="white")
        console.print(prediction.response, style="white")

    if prediction.status == "failed" and prediction.error:
        console.print("\nError:", style="white")
        console.print(prediction.error, style="white")
2 |

3 | VLM Run Logo
4 |

5 |

VLM Run Python SDK

6 |

Website | Platform | Docs | Blog | Discord 7 |

8 |

9 | PyPI Version 10 | PyPI Version 11 | PyPI Downloads
12 | License 13 | Discord 14 | Twitter Follow 15 |

16 |
17 | 18 | The [VLM Run Python SDK](https://pypi.org/project/vlmrun/) is the official Python SDK for [VLM Run API platform](https://docs.vlm.run), providing a convenient way to interact with our REST APIs. 19 | 20 | 21 | ## 🚀 Getting Started 22 | 23 | ### Installation 24 | 25 | ```bash 26 | pip install vlmrun 27 | ``` 28 | 29 | ### Installation with Optional Features 30 | 31 | The package provides optional features that can be installed based on your needs: 32 | 33 | - Video processing features (numpy, opencv-python): 34 | ```bash 35 | pip install "vlmrun[video]" 36 | ``` 37 | 38 | - Document processing features (pypdfium2): 39 | ```bash 40 | pip install "vlmrun[doc]" 41 | ``` 42 | 43 | - OpenAI SDK integration (for chat completions API): 44 | ```bash 45 | pip install "vlmrun[openai]" 46 | ``` 47 | 48 | - All optional features: 49 | ```bash 50 | pip install "vlmrun[all]" 51 | ``` 52 | 53 | ### Basic Usage 54 | 55 | ```python 56 | from PIL import Image 57 | from vlmrun.client import VLMRun 58 | from vlmrun.common.utils import remote_image 59 | 60 | # Initialize the client 61 | client = VLMRun(api_key="") 62 | 63 | # Process an image using local file or remote URL 64 | image: Image.Image = remote_image("https://storage.googleapis.com/vlm-data-public-prod/hub/examples/document.invoice/invoice_1.jpg") 65 | response = client.image.generate( 66 | images=[image], 67 | domain="document.invoice" 68 | ) 69 | print(response) 70 | 71 | # Or process an image directly from URL 72 | response = client.image.generate( 73 | urls=["https://storage.googleapis.com/vlm-data-public-prod/hub/examples/document.invoice/invoice_1.jpg"], 74 | domain="document.invoice" 75 | ) 76 | print(response) 77 | ``` 78 | 79 | ### OpenAI-Compatible Chat Completions 80 | 81 | The VLM Run SDK provides OpenAI-compatible chat completions through the agent endpoint. This allows you to use the familiar OpenAI API with VLM Run's powerful vision-language models. 
82 | 83 | ```python 84 | from vlmrun.client import VLMRun 85 | 86 | client = VLMRun( 87 | api_key="your-key", 88 | base_url="https://agent.vlm.run/v1" 89 | ) 90 | 91 | response = client.agent.completions.create( 92 | model="vlmrun-orion-1", 93 | messages=[ 94 | {"role": "user", "content": "Hello!"} 95 | ] 96 | ) 97 | print(response.choices[0].message.content) 98 | ``` 99 | 100 | For async support: 101 | 102 | ```python 103 | import asyncio 104 | from vlmrun.client import VLMRun 105 | 106 | client = VLMRun(api_key="your-key", base_url="https://agent.vlm.run/v1") 107 | 108 | async def main(): 109 | response = await client.agent.async_completions.create( 110 | model="vlmrun-orion-1", 111 | messages=[{"role": "user", "content": "Hello!"}] 112 | ) 113 | print(response.choices[0].message.content) 114 | 115 | asyncio.run(main()) 116 | ``` 117 | 118 | **Installation**: Install with OpenAI support using `pip install vlmrun[openai]` 119 | 120 | ## 🔗 Quick Links 121 | 122 | * 💬 Need help? Email us at [support@vlm.run](mailto:support@vlm.run) or join our [Discord](https://discord.gg/AMApC2UzVY) 123 | * 📚 Check out our [Documentation](https://docs.vlm.run/) 124 | * 📣 Follow us on [Twitter](https://x.com/vlmrun) and [LinkedIn](https://www.linkedin.com/company/vlm-run) 125 | -------------------------------------------------------------------------------- /vlmrun/client/artifacts.py: -------------------------------------------------------------------------------- 1 | """VLM Run API Artifacts resource.""" 2 | 3 | from __future__ import annotations 4 | 5 | import io 6 | from email.parser import HeaderParser 7 | from pathlib import Path 8 | from typing import TYPE_CHECKING, Dict, Union 9 | 10 | from PIL import Image 11 | from pydantic import AnyHttpUrl 12 | 13 | from vlmrun.client.base_requestor import APIRequestor 14 | from vlmrun.constants import VLMRUN_CACHE_DIR 15 | 16 | 17 | def _get_disposition_params(disposition: str) -> Dict[str, str]: 18 | """Parse Content-Disposition header 
and return parameters. 19 | 20 | Args: 21 | disposition: Content-Disposition header value 22 | 23 | Returns: 24 | Dictionary of parameters from the header 25 | """ 26 | parser = HeaderParser() 27 | msg = parser.parsestr(f"Content-Disposition: {disposition}") 28 | params = msg.get_params(header="Content-Disposition", unquote=True) 29 | if not params: 30 | return {} 31 | # First element is the main value (e.g., "attachment"), rest are params 32 | return dict(params[1:]) 33 | 34 | 35 | if TYPE_CHECKING: 36 | from vlmrun.types.abstract import VLMRunProtocol 37 | 38 | 39 | class Artifacts: 40 | """Artifacts resource for VLM Run API.""" 41 | 42 | def __init__(self, client: "VLMRunProtocol") -> None: 43 | """Initialize Artifacts resource with VLMRun instance. 44 | 45 | Args: 46 | client: VLM Run API instance 47 | """ 48 | self._client = client 49 | self._requestor = APIRequestor(client) 50 | 51 | def get( 52 | self, session_id: str, object_id: str, raw_response: bool = False 53 | ) -> Union[bytes, Image.Image, AnyHttpUrl, Path]: 54 | """Get an artifact by session ID and object ID. 
55 | 56 | Args: 57 | session_id: Session ID for the artifact 58 | object_id: Object ID for the artifact 59 | raw_response: Whether to return the raw response or not 60 | 61 | Returns: 62 | bytes: The artifact content if raw_response is True, otherwise 63 | converted to the appropriate type based on the content type 64 | """ 65 | response, status_code, headers = self._requestor.request( 66 | method="GET", 67 | url=f"artifacts/{session_id}/{object_id}", 68 | raw_response=True, 69 | ) 70 | 71 | if not isinstance(response, bytes): 72 | raise TypeError("Expected bytes response") 73 | 74 | # If raw response is requested, return the raw response as bytes 75 | if raw_response: 76 | return response 77 | 78 | # Otherwise, return the appropriate type based on the content type 79 | obj_type, _obj_id = object_id.split("_") 80 | if len(_obj_id) != 6: 81 | raise ValueError( 82 | f"Invalid object ID: {object_id}, expected format: _<6-digit-hex-string>" 83 | ) 84 | 85 | if obj_type == "img": 86 | assert ( 87 | headers["Content-Type"] == "image/jpeg" 88 | ), f"Expected image/jpeg, got {headers['Content-Type']}" 89 | return Image.open(io.BytesIO(response)).convert("RGB") 90 | elif obj_type == "url": 91 | return AnyHttpUrl(response.decode("utf-8")) 92 | elif obj_type == "vid": 93 | # Read the binary response as a video file and write it to a temporary file 94 | assert ( 95 | headers["Content-Type"] == "video/mp4" 96 | ), f"Expected video/mp4, got {headers['Content-Type']}" 97 | 98 | # Parse Content-Disposition header to get file extension 99 | ext = "mp4" # default extension 100 | disposition = headers.get("Content-Disposition") or headers.get( 101 | "content-disposition" 102 | ) 103 | if disposition: 104 | params = _get_disposition_params(disposition) 105 | filename = params.get("filename") 106 | if filename: 107 | suffix = Path(filename).suffix 108 | if suffix: 109 | ext = suffix.lstrip(".") 110 | 111 | # Build cache path with session_id and object_id 112 | safe_session_id = 
session_id.replace("-", "") 113 | tmp_path: Path = VLMRUN_CACHE_DIR / f"{safe_session_id}_{object_id}.{ext}" 114 | 115 | # Return cached version if it exists 116 | if tmp_path.exists(): 117 | return tmp_path 118 | 119 | with tmp_path.open("wb") as f: 120 | f.write(response) 121 | return tmp_path 122 | else: 123 | return response 124 | 125 | def list(self, session_id: str) -> None: 126 | """List artifacts for a session. 127 | 128 | Args: 129 | session_id: Session ID to list artifacts for 130 | 131 | Raises: 132 | NotImplementedError: This method is not yet implemented 133 | """ 134 | raise NotImplementedError("Artifacts.list() is not yet implemented") 135 | -------------------------------------------------------------------------------- /vlmrun/cli/_cli/datasets.py: -------------------------------------------------------------------------------- 1 | import typer 2 | from typing import List 3 | from pathlib import Path 4 | from rich.table import Table 5 | from rich.console import Console 6 | 7 | from loguru import logger 8 | from vlmrun.client import VLMRun 9 | from vlmrun.client.types import DatasetResponse 10 | from vlmrun.constants import VLMRUN_TMP_DIR 11 | 12 | app = typer.Typer( 13 | help="Dataset operations", 14 | add_completion=False, 15 | no_args_is_help=True, 16 | ) 17 | 18 | 19 | @app.command() 20 | def create( 21 | ctx: typer.Context, 22 | directory: Path = typer.Option( 23 | ..., 24 | help="Directory containing images, PDFs or videos", 25 | exists=True, 26 | dir_okay=True, 27 | file_okay=False, 28 | ), 29 | domain: str = typer.Option(..., help="Domain for the dataset"), 30 | dataset_name: str = typer.Option(..., help="Name of the dataset"), 31 | dataset_type: str = typer.Option( 32 | ..., help="Type of dataset ('images', 'documents', 'videos')" 33 | ), 34 | wandb_base_url: str = typer.Option( 35 | "https://api.wandb.ai", help="Base URL for Weights & Biases" 36 | ), 37 | wandb_project_name: str = typer.Option( 38 | None, help="Project name for Weights & 
"""Dataset CLI commands."""

import typer
from typing import List
from pathlib import Path
from rich.table import Table
from rich.console import Console

from loguru import logger
from vlmrun.client import VLMRun
from vlmrun.client.types import DatasetResponse
from vlmrun.constants import VLMRUN_TMP_DIR

app = typer.Typer(
    help="Dataset operations",
    add_completion=False,
    no_args_is_help=True,
)

# Extension sets used to discover dataset files on disk.
_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
_VIDEO_EXTENSIONS = {".mp4", ".webm"}


def _collect_files(directory: Path, dataset_type: str) -> List[Path]:
    """Recursively collect files of the given canonical dataset type.

    Args:
        directory: Root directory to scan.
        dataset_type: Canonical type: 'images', 'documents' or 'videos'.

    Returns:
        All matching files (by extension, case-insensitive) under *directory*.

    Raises:
        ValueError: If *dataset_type* is not a canonical type.
    """
    if dataset_type == "images":
        extensions = _IMAGE_EXTENSIONS
    elif dataset_type == "documents":
        extensions = {".pdf"}
    elif dataset_type == "videos":
        extensions = _VIDEO_EXTENSIONS
    else:
        raise ValueError(f"Invalid dataset type: {dataset_type}")
    return [
        p
        for p in directory.rglob("*")
        if p.is_file() and p.suffix.lower() in extensions
    ]


@app.command()
def create(
    ctx: typer.Context,
    directory: Path = typer.Option(
        ...,
        help="Directory containing images, PDFs or videos",
        exists=True,
        dir_okay=True,
        file_okay=False,
    ),
    domain: str = typer.Option(..., help="Domain for the dataset"),
    dataset_name: str = typer.Option(..., help="Name of the dataset"),
    dataset_type: str = typer.Option(
        ..., help="Type of dataset ('images', 'documents', 'videos')"
    ),
    wandb_base_url: str = typer.Option(
        "https://api.wandb.ai", help="Base URL for Weights & Biases"
    ),
    wandb_project_name: str = typer.Option(
        None, help="Project name for Weights & Biases"
    ),
    wandb_api_key: str = typer.Option(None, help="API key for Weights & Biases"),
) -> None:
    """Create a dataset from a directory of images, PDFs or videos.

    This command will:
    1. Create a tar.gz archive from the image directory
    2. Upload the archive
    3. Create a dataset using the uploaded file
    """
    client: VLMRun = ctx.obj

    # BUGFIX: the help text advertises 'documents', but only the legacy key
    # 'pdfs' was accepted, so the documented value was rejected. Accept both;
    # canonical values pass through unchanged.
    valid_types = {
        "images": "images",
        "pdfs": "documents",
        "documents": "documents",
        "videos": "videos",
    }
    if dataset_type not in valid_types:
        typer.echo("Error: Invalid dataset type", err=True)
        raise typer.Exit(1)
    dataset_type = valid_types[dataset_type]

    files = _collect_files(directory, dataset_type)
    if not files:
        # BUGFIX: removed the stray ']' from the original message.
        typer.echo(
            f"Error: No files of type={dataset_type} found in directory", err=True
        )
        raise typer.Exit(1)

    typer.echo(f"Available files [n={len(files)}, dataset_type={dataset_type}]")

    # Create dataset
    dataset_response = None
    try:
        dataset_response = client.datasets.create(
            domain=domain,
            dataset_directory=directory,
            dataset_name=dataset_name,
            dataset_type=dataset_type,
            wandb_base_url=wandb_base_url,
            wandb_project_name=wandb_project_name,
            wandb_api_key=wandb_api_key,
        )
        typer.echo(f"Dataset created successfully! ID: {dataset_response.id}")

        # Verify dataset creation
        dataset_info = client.datasets.get(dataset_response.id)
        typer.echo(f"Dataset created with ID: {dataset_info.id}")

    finally:
        # BUGFIX: if create() raised, `dataset_response` was unbound and this
        # logging line raised NameError, masking the original exception.
        logger.debug(
            f"Generated dataset [dataset_id={dataset_response.id if dataset_response else None}, tmp_dir={VLMRUN_TMP_DIR}]"
        )


@app.command()
def list(
    ctx: typer.Context,
    skip: int = typer.Option(0, help="Skip the first N datasets"),
    limit: int = typer.Option(10, help="Limit the number of datasets to list"),
) -> None:
    """List datasets as a table."""
    # NOTE: the function intentionally shadows the builtin `list` — typer
    # derives the CLI command name from the function name.
    client: VLMRun = ctx.obj
    datasets: List[DatasetResponse] = client.datasets.list(skip=skip, limit=limit)

    console = Console()
    table = Table(show_header=True)
    for column in (
        "id",
        "dataset_name",
        "dataset_type",
        "domain",
        "created_at",
        "completed_at",
        "status",
        "wandb_url",
    ):
        table.add_column(column)

    for dataset in datasets:
        table.add_row(
            dataset.id,
            dataset.dataset_name,
            dataset.dataset_type,
            dataset.domain,
            dataset.created_at.strftime("%Y-%m-%d %H:%M:%S"),
            (
                dataset.completed_at.strftime("%Y-%m-%d %H:%M:%S")
                if dataset.completed_at
                else ""
            ),
            dataset.status,
            dataset.wandb_url,
        )
    console.print(table)
"""VLM Run Hub API implementation."""

from typing import TYPE_CHECKING, List, Type, Optional
from pydantic import BaseModel

from vlmrun.client.base_requestor import APIError
from vlmrun.client.types import (
    HubSchemaResponse,
    HubInfoResponse,
    HubDomainInfo,
)
import cachetools
from cachetools.keys import hashkey


if TYPE_CHECKING:
    from vlmrun.types.abstract import VLMRunProtocol


@cachetools.cached(
    cache=cachetools.TTLCache(maxsize=100, ttl=3600),
    key=lambda _client, domain: hashkey(domain),  # noqa: B007
)
def get_response_model(client, domain: str) -> Type[BaseModel]:
    """Get the schema type for a hub domain.

    Note: This function is cached (per domain, 1-hour TTL) to avoid
    re-fetching the schema from the API.
    """
    schema_response: HubSchemaResponse = client.hub.get_schema(domain)
    return schema_response.response_model


class Hub:
    """Hub API for VLM Run.

    This module provides access to the hub routes for managing schemas and domains.
    """

    def __init__(self, client: "VLMRunProtocol") -> None:
        """Initialize Hub resource with VLMRun instance.

        Args:
            client: VLM Run API instance
        """
        self._client = client

    def info(self) -> HubInfoResponse:
        """Get the hub info.

        Returns:
            HubInfoResponse containing hub version

        Example:
            >>> client.hub.info()
            {
                "version": "1.0.0"
            }

        Raises:
            APIError: If the request fails
        """
        try:
            response, _, _ = self._client.requestor.request(
                method="GET", url="/hub/info", raw_response=False
            )
            if not isinstance(response, dict):
                raise APIError("Expected dict response from health check")
            return HubInfoResponse(**response)
        except APIError:
            # BUGFIX: don't re-wrap an APIError in another APIError — the
            # previous blanket `except Exception` discarded its structured
            # fields (http_status, request_id, ...).
            raise
        except Exception as e:
            # BUGFIX: chain the cause so the original traceback is preserved.
            raise APIError(f"Failed to check hub health: {str(e)}") from e

    def list_domains(self) -> List[HubDomainInfo]:
        """Get the list of supported domains.

        Returns:
            List of domain strings

        Example:
            >>> client.hub.list_domains()
            [
                "document.invoice",
                "document.receipt",
                "document.utility_bill"
            ]

        Raises:
            APIError: If the request fails
        """
        try:
            response, _, _ = self._client.requestor.request(
                method="GET", url="/hub/domains", raw_response=False
            )
            if not isinstance(response, list):
                raise TypeError("Expected list response")
            return [HubDomainInfo(**domain) for domain in response]
        except APIError:
            raise
        except Exception as e:
            raise APIError(f"Failed to list domains: {str(e)}") from e

    def get_schema(
        self, domain: str, gql_stmt: Optional[str] = None
    ) -> HubSchemaResponse:
        """Get the JSON schema for a given domain.

        Args:
            domain: Domain identifier (e.g. "document.invoice")
            gql_stmt: GraphQL statement for the domain

        Returns:
            HubSchemaQueryResponse containing:
                - schema_json: The JSON schema for the domain
                - schema_version: Schema version string
                - schema_hash: First 8 characters of schema hash

        Example:
            >>> response = client.hub.get_schema("document.invoice")
            >>> print(response.schema_version)
            "1.0.0"

        Raises:
            APIError: If the request fails or domain is not found
        """
        try:
            response, _, _ = self._client.requestor.request(
                method="POST",
                url="/hub/schema",
                data={"domain": domain, "gql_stmt": gql_stmt},
                raw_response=False,
            )
            if not isinstance(response, dict):
                raise APIError("Expected dict response from schema query")
            return HubSchemaResponse(**response)
        except APIError:
            raise
        except Exception as e:
            raise APIError(f"Failed to get schema for domain {domain}: {str(e)}") from e

    def get_pydantic_model(self, domain: str) -> Type[BaseModel]:
        """Get the Pydantic model for a given domain.

        Args:
            domain: Domain identifier (e.g. "document.invoice")

        Returns:
            Type[BaseModel]: The Pydantic model class for the domain

        Raises:
            APIError: If the domain is not found
        """
        try:
            return get_response_model(self._client, domain)
        except KeyError as e:
            # NOTE(review): get_response_model surfaces APIError for unknown
            # domains; this KeyError branch is kept as a defensive fallback.
            raise APIError(f"Domain not found: {domain}") from e
"""General utilities for VLMRun."""

from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Union, Literal, Dict, Any
from loguru import logger

import tarfile
import requests
from PIL import Image, ImageOps
from vlmrun.constants import VLMRUN_TMP_DIR, VLMRUN_CACHE_DIR
from vlmrun.common.image import _open_image_with_exif


# Browser-like HTTP request headers (some hosts reject generic clients).
_HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:106.0) Gecko/20100101 Firefox/106.0",
    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
    "Accept-Language": "en-US,en;q=0.5",
    "Accept-Encoding": "gzip, deflate",
    "Connection": "keep-alive",
    "Upgrade-Insecure-Requests": "1",
    "Sec-Fetch-Dest": "document",
    "Sec-Fetch-Mode": "navigate",
    "Sec-Fetch-Site": "none",
    "Sec-Fetch-User": "?1",
    "Cache-Control": "max-age=0",
}


def remote_image(url: Union[str, Path]) -> Image.Image:
    """Load an image from a URL or local path.

    Args:
        url: URL or Path to the image

    Returns:
        PIL Image in RGB format

    Raises:
        ValueError: If URL/path is invalid
        requests.exceptions.RequestException: If there's an error downloading the image
        IOError: If there's an error processing the image
    """
    if isinstance(url, Path) or (isinstance(url, str) and not url.startswith("http")):
        try:
            return _open_image_with_exif(url)
        except Exception as e:
            raise ValueError(f"Failed to open image from path={url}") from e

    try:
        response = requests.get(url, headers=_HEADERS, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        try:
            # Best-effort EXIF orientation fix; some images lack EXIF data.
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass
        return image.convert("RGB")
    except requests.exceptions.RequestException:
        # Let request exceptions propagate through
        raise
    except Exception as e:
        raise ValueError(f"Failed to process image from {url}") from e


def download_image(url: str) -> Image.Image:
    """Download an image from a URL.

    Args:
        url: URL of the image to download

    Returns:
        PIL Image in RGB format

    Raises:
        ValueError: If URL format is invalid
        requests.exceptions.RequestException: If there's an error downloading the image
    """
    if not isinstance(url, str):
        raise ValueError(f"URL must be a string, got {type(url)}")
    if not url.startswith(("http://", "https://")):
        raise ValueError(f"Invalid URL format: {url}")
    return remote_image(url)


def create_archive(directory: Path, name: str) -> Path:
    """Create a tar.gz archive from a directory.

    Note: archives are cached per name and calendar day — if the directory
    contents change within the same day, the stale archive is returned.

    Args:
        directory: Path to directory to archive
        name: Name of the archive

    Returns:
        Path: Path to created archive file

    Raises:
        ValueError: If directory does not exist
    """
    if isinstance(directory, str):
        directory = Path(directory)

    if not directory.is_dir():
        raise ValueError(f"Directory does not exist: {directory}")

    # Create archive in temp directory
    logger.debug(
        f"Creating tar.gz file from directory [path={directory}, n_files={len(list(directory.iterdir()))}]"
    )
    date_str = datetime.now().strftime("%Y%m%d")
    archive_name = f"{name}_{date_str}"
    archive_path = VLMRUN_TMP_DIR / f"{archive_name}.tar.gz"

    # Check if tar.gz file already exists
    if archive_path.exists():
        logger.debug(f"Tar.gz file already exists [path={archive_path}]")
        return archive_path

    with tarfile.open(str(archive_path), "w:gz") as tar:
        tar.add(directory, arcname=archive_name)
    logger.debug(f"Created tar.gz file [path={archive_path}]")
    return archive_path


def download_artifact(
    url: str, format: Literal["image", "json", "file"]
) -> Union[Image.Image, Dict[str, Any], Path]:
    """Download an artifact from a URL as an image, JSON dict, or cached file.

    Args:
        url: HTTP(S) URL of the artifact.
        format: How to interpret the payload: "image" returns a PIL RGB image,
            "json" returns the parsed JSON body, "file" streams the payload to
            the local cache and returns the path.

    Returns:
        PIL Image, parsed JSON dict, or local file Path depending on *format*.

    Raises:
        ValueError: If the URL or format is invalid.
        requests.exceptions.RequestException: On download/HTTP errors.
    """
    if not url.startswith("http"):
        raise ValueError(f"Invalid URL: {url}")
    if format == "image":
        # Delegate to remote_image for consistent header/EXIF/RGB handling.
        # The previous inline copy shadowed the builtin `bytes` and skipped
        # both raise_for_status and the request timeout.
        return remote_image(url)
    elif format == "json":
        # BUGFIX: check HTTP status and bound the request, consistent with the
        # "file" branch below (previously a 4xx/5xx body was fed to .json()).
        response = requests.get(url, headers=_HEADERS, timeout=30)
        response.raise_for_status()
        return response.json()
    elif format == "file":
        path = VLMRUN_CACHE_DIR / "downloads" / Path(url).name
        path.parent.mkdir(parents=True, exist_ok=True)
        if not path.exists():
            with requests.get(url, headers=_HEADERS, stream=True, timeout=30) as r:
                r.raise_for_status()
                with path.open("wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
        else:
            logger.debug(f"File already exists [path={path}]")
        return path
    else:
        raise ValueError(f"Invalid format: {format}")
@dataclass
class ValidationError(APIError):
    """Exception raised when request validation fails."""

    message: str = "Validation failed"
    http_status: int = 400
    headers: Dict[str, str] = field(default_factory=dict)
    request_id: Optional[str] = None
    error_type: str = "validation_error"
    suggestion: str = "Check your request parameters"


@dataclass
class RateLimitError(APIError):
    """Exception raised when rate limit is exceeded."""

    message: str = "Rate limit exceeded"
    http_status: int = 429
    headers: Dict[str, str] = field(default_factory=dict)
    request_id: Optional[str] = None
    error_type: str = "rate_limit_error"
    suggestion: str = (
        "Reduce request frequency or contact support to increase your rate limit"
    )


@dataclass
class ServerError(APIError):
    """Exception raised when server returns 5xx error."""

    message: str = "Server error"
    http_status: int = 500
    headers: Dict[str, str] = field(default_factory=dict)
    request_id: Optional[str] = None
    error_type: str = "server_error"
    suggestion: str = "Please try again later or contact support if the issue persists"


@dataclass
class ResourceNotFoundError(APIError):
    """Exception raised when resource is not found."""

    message: str = "Resource not found"
    http_status: int = 404
    headers: Dict[str, str] = field(default_factory=dict)
    request_id: Optional[str] = None
    error_type: str = "not_found_error"
    suggestion: str = "Check the resource ID or path"


@dataclass
class ClientError(VLMRunError):
    """Base exception for client-side errors."""

    message: str
    error_type: Optional[str] = None
    suggestion: Optional[str] = None

    def __str__(self) -> str:
        """Return string representation of error."""
        parts = []
        if self.error_type:
            parts.append(f"type={self.error_type}")

        # BUGFIX: when error_type is unset the old code produced a spurious
        # "[] message" prefix; only emit the bracketed section when non-empty.
        formatted = f"[{', '.join(parts)}] {self.message}" if parts else self.message
        if self.suggestion:
            formatted += f" (Suggestion: {self.suggestion})"
        return formatted


@dataclass
class ConfigurationError(ClientError):
    """Exception raised when client configuration is invalid."""

    message: str = "Invalid configuration"
    error_type: str = "configuration_error"
    suggestion: str = "Check your client configuration"


@dataclass
class DependencyError(ClientError):
    """Exception raised when a required dependency is missing."""

    message: str = "Missing dependency"
    error_type: str = "dependency_error"
    suggestion: str = "Install the required dependency"


@dataclass
class InputError(ClientError):
    """Exception raised when input is invalid."""

    message: str = "Invalid input"
    error_type: str = "input_error"
    suggestion: str = "Check your input parameters"


@dataclass
class RequestTimeoutError(APIError):
    """Exception raised when a request times out."""

    message: str = "Request timed out"
    http_status: int = 408
    headers: Dict[str, str] = field(default_factory=dict)
    request_id: Optional[str] = None
    error_type: str = "timeout_error"
    suggestion: str = "Try again later or increase the timeout"


@dataclass
class NetworkError(APIError):
    """Exception raised when a network error occurs."""

    message: str = "Network error"
    http_status: Optional[int] = None
    headers: Dict[str, str] = field(default_factory=dict)
    request_id: Optional[str] = None
    error_type: str = "network_error"
    suggestion: str = "Check your internet connection and try again"
"""Tests for video utilities."""

from pathlib import Path

import numpy as np
import pytest

from vlmrun.common.video import VideoReader, VideoWriter

# Constants for real test video properties (tests/test_data/test.mp4)
REAL_VIDEO_WIDTH = 294
REAL_VIDEO_HEIGHT = 240
REAL_VIDEO_FRAMES = 168
# NOTE(review): REAL_VIDEO_FPS is never asserted anywhere in this module.
REAL_VIDEO_FPS = 25


@pytest.fixture
def sample_video() -> Path:
    """Return path to the test video file.

    Returns:
        Path: Path to the test video file.
    """
    return Path(__file__).parent.parent / "test_data" / "test.mp4"


def test_video_reader_basic(sample_video):
    """Test basic VideoReader functionality."""
    with VideoReader(sample_video) as reader:
        # Test length
        assert len(reader) == REAL_VIDEO_FRAMES

        # Test iteration: every frame is an HxWx3 uint8 ndarray
        frame_count = 0
        for frame in reader:
            assert isinstance(frame, np.ndarray)
            assert frame.shape == (REAL_VIDEO_HEIGHT, REAL_VIDEO_WIDTH, 3)
            assert frame.dtype == np.uint8
            frame_count += 1
        assert frame_count == REAL_VIDEO_FRAMES

        # Test position: after full iteration the cursor sits past the last frame
        assert reader.pos() == REAL_VIDEO_FRAMES


def test_video_reader_seeking(sample_video):
    """Test VideoReader seeking functionality."""
    with VideoReader(sample_video) as reader:
        # Test seeking to specific frame
        mid_frame = REAL_VIDEO_FRAMES // 2
        reader.seek(mid_frame)
        assert reader.pos() == mid_frame

        # Test reading frame after seeking
        frame = next(reader)
        assert isinstance(frame, np.ndarray)
        assert frame.shape == (REAL_VIDEO_HEIGHT, REAL_VIDEO_WIDTH, 3)

        # Test seeking to start
        reader.reset()
        assert reader.pos() == 0

        # Test seeking to last frame; reading past it must exhaust the iterator
        reader.seek(len(reader) - 1)
        frame = next(reader)
        assert isinstance(frame, np.ndarray)
        with pytest.raises(StopIteration):
            next(reader)


def test_video_reader_getitem(sample_video):
    """Test VideoReader indexing functionality."""
    with VideoReader(sample_video) as reader:
        # Test single frame access
        frame = reader[0]
        assert isinstance(frame, np.ndarray)
        assert frame.shape == (REAL_VIDEO_HEIGHT, REAL_VIDEO_WIDTH, 3)

        # Test multiple frame access: a list of indices returns a list of frames
        frames = reader[[0, REAL_VIDEO_FRAMES // 2, REAL_VIDEO_FRAMES - 1]]
        assert isinstance(frames, list)
        assert len(frames) == 3
        for frame in frames:
            assert isinstance(frame, np.ndarray)
            assert frame.shape == (REAL_VIDEO_HEIGHT, REAL_VIDEO_WIDTH, 3)

        # Test invalid index (one past the end)
        with pytest.raises(IndexError):
            reader[len(reader)]

        # Test invalid index type
        with pytest.raises(TypeError):
            reader["invalid"]


def test_video_reader_errors():
    """Test VideoReader error handling."""
    # Test non-existent file: must raise at construction, not at first read
    with pytest.raises(FileNotFoundError):
        VideoReader("nonexistent.mp4")


def test_video_writer_basic(tmp_path):
    """Test basic VideoWriter functionality."""
    output_path = tmp_path / "output.mp4"
    frame = np.zeros((480, 640, 3), dtype=np.uint8)

    # Test writing frames
    with VideoWriter(output_path, fps=30.0) as writer:
        for _ in range(10):
            writer.write(frame)

    # Verify output file exists
    assert output_path.exists()

    # Verify video can be read back with the same frame count/shape/dtype
    with VideoReader(output_path) as reader:
        assert len(reader) == 10
        for frame in reader:
            assert isinstance(frame, np.ndarray)
            assert frame.shape == (480, 640, 3)
            assert frame.dtype == np.uint8


def test_video_writer_errors(tmp_path):
    """Test VideoWriter error handling."""
    output_path = tmp_path / "output.mp4"

    # Create a video file (VideoWriter is used here with its default fps)
    with VideoWriter(output_path) as writer:
        frame = np.zeros((480, 640, 3), dtype=np.uint8)
        writer.write(frame)

    # Test file exists error: opening a writer on an existing path must fail
    with pytest.raises(FileExistsError):
        VideoWriter(output_path)


def test_video_context_managers(sample_video, tmp_path):
    """Test context manager functionality for both VideoReader and VideoWriter."""
    # Test VideoReader context manager
    with VideoReader(sample_video) as reader:
        assert len(reader) > 0
        frame = next(reader)
        assert isinstance(frame, np.ndarray)

    # Verify reader is closed
    # NOTE(review): this asserts on the private attribute `_video`; it will
    # break if VideoReader's internals are renamed.
    assert reader._video is None

    # Test VideoWriter context manager
    output_path = tmp_path / "output.mp4"
    frame = np.zeros((480, 640, 3), dtype=np.uint8)

    with VideoWriter(output_path) as writer:
        writer.write(frame)

    # Verify writer is closed (again via an implementation attribute)
    assert writer.writer is None
    assert output_path.exists()


@pytest.fixture
def real_video_path() -> Path:
    """Return path to the real test video file.

    NOTE(review): duplicates the `sample_video` fixture above (same path);
    consider consolidating.

    Returns:
        Path: Path to the test video file.
    """
    return Path(__file__).parent.parent / "test_data" / "test.mp4"


def test_video_reader_real_video(real_video_path):
    """Test VideoReader with a real video file."""
    with VideoReader(real_video_path) as reader:
        # Test basic properties
        assert len(reader) == REAL_VIDEO_FRAMES

        # Test frame dimensions
        frame = next(reader)
        assert isinstance(frame, np.ndarray)
        assert frame.shape == (REAL_VIDEO_HEIGHT, REAL_VIDEO_WIDTH, 3)
        assert frame.dtype == np.uint8

        # Test seeking
        reader.seek(REAL_VIDEO_FRAMES // 2)
        mid_frame = next(reader)
        assert isinstance(mid_frame, np.ndarray)
        assert mid_frame.shape == (REAL_VIDEO_HEIGHT, REAL_VIDEO_WIDTH, 3)

        # Test reading all frames from the start
        reader.reset()
        frame_count = sum(1 for _ in reader)
        assert frame_count == REAL_VIDEO_FRAMES
from vlmrun.client.artifacts import Artifacts
from vlmrun.constants import DEFAULT_BASE_URL
from vlmrun.client.types import SchemaResponse, DomainInfo, GenerationConfig
from vlmrun.client.exceptions import (
    ConfigurationError,
    AuthenticationError,
)


@dataclass
class VLMRun:
    """VLM Run API client.

    Constructing a client performs a live ``GET /health`` request to
    validate the API key; construction raises ``ConfigurationError`` if no
    key can be resolved and ``AuthenticationError`` if validation fails.

    Attributes:
        api_key: API key for authentication. Can be provided through constructor
            or VLMRUN_API_KEY environment variable.
        base_url: Base URL for API. Defaults to None, which falls back to
            VLMRUN_BASE_URL environment variable or https://api.vlm.run/v1.
        timeout: Request timeout in seconds. Defaults to 120.0.
        max_retries: Maximum number of retry attempts for failed requests. Defaults to 5.
        files: Files resource for managing files
        models: Models resource for accessing available models
        finetune: Fine-tuning resource for model fine-tuning
    """

    api_key: Optional[str] = None
    base_url: Optional[str] = None
    timeout: float = 120.0
    max_retries: int = 5

    def __post_init__(self) -> None:
        """Initialize the client after dataclass initialization.

        This method handles environment variable fallbacks:
        - api_key: Falls back to VLMRUN_API_KEY environment variable
        - base_url: Can be overridden by constructor or VLMRUN_BASE_URL environment variable

        Raises:
            ConfigurationError: If no API key is provided or found in the environment.
            AuthenticationError: If the health-check request does not return 200.
        """
        # Handle API key first
        if not self.api_key:  # Handle both None and empty string
            self.api_key = os.getenv("VLMRUN_API_KEY", None)
            if not self.api_key:  # Still None or empty after env check
                raise ConfigurationError(
                    message="Missing API key",
                    error_type="missing_api_key",
                    suggestion="Please provide your VLM Run API key:\n\n"
                    "1. Set it in your code:\n"
                    " client = VLMRun(api_key='your-api-key')\n\n"
                    "2. Or set the environment variable:\n"
                    " export VLMRUN_API_KEY='your-api-key'\n\n"
                    "Get your API key at https://app.vlm.run/dashboard",
                )

        # Handle base URL (constructor value wins over the environment)
        if self.base_url is None:
            self.base_url = os.getenv("VLMRUN_BASE_URL", DEFAULT_BASE_URL)

        # Initialize requestor for API key validation
        requestor = APIRequestor(
            self, timeout=self.timeout, max_retries=self.max_retries
        )

        # Validate API key by making a health check request
        try:
            _, status_code, _ = requestor.request(
                method="GET", url="/health", raw_response=True
            )
            if status_code != 200:
                raise AuthenticationError(
                    message="Invalid API key",
                    error_type="invalid_api_key",
                    suggestion="Please check your API key and ensure it is valid. You can get your API key at https://app.vlm.run/dashboard",
                )
        except AuthenticationError as e:
            # Normalizes any AuthenticationError — whether raised just above
            # on a non-200 status or raised by the requestor itself — into a
            # single, consistent message; the original is chained via `from e`.
            raise AuthenticationError(
                message="Invalid API key",
                error_type="invalid_api_key",
                suggestion="Please check your API key and ensure it is valid. You can get your API key at https://app.vlm.run/dashboard",
            ) from e

        # Initialize resources (each takes this client for auth/config)
        self.datasets = Datasets(self)
        self.files = Files(self)
        self.hub = Hub(self)
        self.models = Models(self)
        self.fine_tuning = Finetuning(self)
        self.predictions = Predictions(self)
        self.image = ImagePredictions(self)
        self.document = DocumentPredictions(self)
        # NOTE(review): these overrides reach into the private
        # `_requestor._timeout` and set it to 120.0, which equals the
        # dataclass default `timeout` — they only matter when a caller
        # constructed the client with a different timeout. Confirm intent.
        self.document._requestor._timeout = 120.0
        self.audio = AudioPredictions(self)
        self.audio._requestor._timeout = 120.0
        self.video = VideoPredictions(self)
        self.video._requestor._timeout = 120.0
        self.feedback = Feedback(self)
        self.agent = Agent(self)
        self.executions = Executions(self)
        self.artifacts = Artifacts(self)

    def __repr__(self) -> str:
        # Only the first 8 characters of the API key are shown, to avoid
        # leaking the full secret into logs.
        return f"VLMRun(base_url={self.base_url}, api_key={f'{self.api_key[:8]}...' if self.api_key else 'None'}, version={self.version})"

    @property
    def version(self) -> str:
        """SDK version string (from vlmrun.version)."""
        return __version__

    @cached_property
    def requestor(self):
        """Requestor for the API (created lazily, cached per client)."""
        return APIRequestor(self, timeout=self.timeout, max_retries=self.max_retries)

    def healthcheck(self) -> bool:
        """Check the health of the API.

        Returns:
            bool: True if GET /health returned HTTP 200.
        """
        _, status_code, _ = self.requestor.request(
            method="GET", url="/health", raw_response=True
        )
        return status_code == 200

    def get_type(
        self, domain: str, config: Optional[GenerationConfig] = None
    ) -> Type[BaseModel]:
        """Get the pydantic response-model type for a domain.

        Convenience wrapper around :meth:`get_schema`.
        """
        return self.get_schema(domain, config=config).response_model

    def get_schema(
        self, domain: str, config: Optional[GenerationConfig] = None
    ) -> SchemaResponse:
        """Get the schema for a domain.

        Args:
            domain: Domain name (e.g. "document.invoice")
            config: Generation config (a default GenerationConfig is used if None)

        Returns:
            Schema response containing GraphQL schema and metadata
        """
        if config is None:
            config = GenerationConfig()
        # status_code/headers are unpacked but unused; errors are expected to
        # surface from the requestor itself.
        response, status_code, headers = self.requestor.request(
            method="POST",
            url="/schema",
            data={"domain": domain, "config": config.model_dump()},
        )
        return SchemaResponse(**response)

    def list_domains(self) -> List[DomainInfo]:
        """List all available domains.

        Returns:
            List of DomainInfo objects (one per available domain).
        """
        response, status_code, headers = self.requestor.request(
            method="GET",
            url="/domains",
        )
        return [DomainInfo(**domain) for domain in response]
"""Tests for VLMRun client construction, configuration, and health checks."""

import os
import requests
import pytest
from functools import lru_cache
from unittest.mock import patch
from vlmrun.client import VLMRun
from vlmrun.client.exceptions import ConfigurationError, AuthenticationError


def test_client_with_api_key(monkeypatch):
    """Test client initialization with API key provided in constructor."""
    monkeypatch.delenv("VLMRUN_API_KEY", raising=False)
    monkeypatch.delenv("VLMRUN_BASE_URL", raising=False)  # Ensure clean environment

    # Mock the health check request
    with patch("vlmrun.client.base_requestor.APIRequestor.request") as mock_request:
        mock_request.return_value = (None, 200, {})
        client = VLMRun(api_key="test-key")
        assert client.api_key == "test-key"
        assert (
            client.base_url == "https://api.vlm.run/v1"
        )  # Default URL with /v1 suffix


def test_client_with_env_api_key(monkeypatch):
    """Test client initialization with API key from environment variable."""
    monkeypatch.setenv("VLMRUN_API_KEY", "env-key")
    monkeypatch.delenv("VLMRUN_BASE_URL", raising=False)  # Ensure clean environment

    # Mock the health check request
    with patch("vlmrun.client.base_requestor.APIRequestor.request") as mock_request:
        mock_request.return_value = (None, 200, {})
        client = VLMRun()
        assert client.api_key == "env-key"
        assert (
            client.base_url == "https://api.vlm.run/v1"
        )  # Default URL with /v1 suffix


def test_client_with_base_url():
    """Test client initialization with custom base URL."""
    # Mock the health check request
    with patch("vlmrun.client.base_requestor.APIRequestor.request") as mock_request:
        mock_request.return_value = (None, 200, {})
        client = VLMRun(api_key="test-key", base_url="https://custom.api")
        assert client.api_key == "test-key"
        assert client.base_url == "https://custom.api"


def test_client_with_env_base_url(monkeypatch):
    """Test client initialization with base URL from environment variable."""
    monkeypatch.setenv("VLMRUN_BASE_URL", "https://env.api")

    # Mock the health check request
    with patch("vlmrun.client.base_requestor.APIRequestor.request") as mock_request:
        mock_request.return_value = (None, 200, {})
        client = VLMRun(api_key="test-key")
        assert client.api_key == "test-key"
        assert client.base_url == "https://env.api"


def test_client_missing_api_key(monkeypatch):
    """Test client initialization fails when API key is missing."""
    monkeypatch.delenv("VLMRUN_API_KEY", raising=False)  # Ensure no API key in env
    monkeypatch.delenv("VLMRUN_BASE_URL", raising=False)  # Ensure clean environment

    with pytest.raises(ConfigurationError) as exc_info:
        VLMRun()
    assert "Missing API key" in str(exc_info.value)


def test_client_env_precedence(monkeypatch):
    """Test that constructor values take precedence over environment variables."""
    monkeypatch.setenv("VLMRUN_API_KEY", "env-key")
    monkeypatch.setenv("VLMRUN_BASE_URL", "https://env.api")

    # Mock the health check request
    with patch("vlmrun.client.base_requestor.APIRequestor.request") as mock_request:
        mock_request.return_value = (None, 200, {})
        client = VLMRun(api_key="test-key", base_url="https://custom.api")
        assert client.api_key == "test-key"  # Constructor value
        assert client.base_url == "https://custom.api"  # Constructor value


@lru_cache  # needs to be checked just once
def _healthcheck() -> bool:
    """Return True if the API /health endpoint responds with HTTP 200.

    This runs at *collection time* (it is evaluated inside the skipif
    decorators below), so it must never raise or hang: a bounded timeout
    is used, and any transport failure is treated as "not healthy" so the
    dependent tests are skipped instead of aborting the whole session.
    """
    try:
        return (
            requests.get(
                os.getenv("VLMRUN_BASE_URL", "https://api.vlm.run/v1") + "/health",
                timeout=10,
            ).status_code
            == 200
        )
    except requests.RequestException:
        # Network down / DNS failure / timeout: treat as unhealthy, skip tests.
        return False


@pytest.mark.skipif(
    os.getenv("VLMRUN_API_KEY", None) is None
    and os.getenv("VLMRUN_BASE_URL", None) is None,
    reason="No VLMRUN_API_KEY and VLMRUN_BASE_URL in environment",
)
@pytest.mark.skipif(not _healthcheck(), reason="API is not healthy")
def test_client_health():
    """Test client health check against a live API."""
    client = VLMRun()
    assert client.healthcheck()
    assert len(client.models.list()) > 0, "No models found"


@pytest.mark.skipif(
    os.getenv("VLMRUN_API_KEY", None) is None
    and os.getenv("VLMRUN_BASE_URL", None) is None,
    reason="No VLMRUN_API_KEY and VLMRUN_BASE_URL in environment",
)
@pytest.mark.skipif(not _healthcheck(), reason="API is not healthy")
def test_client_openai():
    """Test client OpenAI integration.

    NOTE(review): `client.openai` is not defined in the visible VLMRun
    implementation — confirm this attribute still exists, or this (normally
    skipped) test will fail with AttributeError when the env vars are set.
    """
    client = VLMRun()
    assert client.openai is not None
    assert client.openai.models.list() is not None


def test_client_invalid_api_key(monkeypatch):
    """Test client initialization fails when API key is invalid."""
    monkeypatch.delenv("VLMRUN_API_KEY", raising=False)
    monkeypatch.delenv("VLMRUN_BASE_URL", raising=False)  # Ensure clean environment

    # Mock the health check request to simulate an invalid API key
    with patch("vlmrun.client.base_requestor.APIRequestor.request") as mock_request:
        mock_request.side_effect = AuthenticationError(
            message="Invalid API key",
            error_type="invalid_api_key",
            suggestion="Please check your API key and ensure it is valid. You can get your API key at https://app.vlm.run/dashboard",
        )

        with pytest.raises(AuthenticationError) as exc_info:
            VLMRun(api_key="invalid-key")
        assert "Invalid API key" in str(exc_info.value)
        assert "Please check your API key" in exc_info.value.suggestion


def test_client_valid_api_key(monkeypatch):
    """Test client initialization succeeds with valid API key."""
    monkeypatch.delenv("VLMRUN_API_KEY", raising=False)
    monkeypatch.delenv("VLMRUN_BASE_URL", raising=False)  # Ensure clean environment

    # Mock the health check request to simulate a valid API key
    with patch("vlmrun.client.base_requestor.APIRequestor.request") as mock_request:
        mock_request.return_value = (None, 200, {})

        client = VLMRun(api_key="valid-key")
        assert client.api_key == "valid-key"
        assert (
            client.base_url == "https://api.vlm.run/v1"
        )  # Default URL with /v1 suffix


def test_client_health_check_failure(monkeypatch):
    """Test client initialization fails when health check returns non-200 status."""
    monkeypatch.delenv("VLMRUN_API_KEY", raising=False)
    monkeypatch.delenv("VLMRUN_BASE_URL", raising=False)  # Ensure clean environment

    # Mock the health check request to return a non-200 status
    with patch("vlmrun.client.base_requestor.APIRequestor.request") as mock_request:
        mock_request.return_value = (None, 401, {})

        with pytest.raises(AuthenticationError) as exc_info:
            VLMRun(api_key="test-key")
        assert "Invalid API key" in str(exc_info.value)
        assert "Please check your API key" in exc_info.value.suggestion
from vlmrun.client.exceptions import (
    VLMRunError,
    APIError,
    AuthenticationError,
    ValidationError,
    RateLimitError,
    ServerError,
    ResourceNotFoundError,
    ClientError,
    ConfigurationError,
    DependencyError,
    InputError,
    RequestTimeoutError,
    NetworkError,
)
from vlmrun.client.base_requestor import APIRequestor


def test_vlmrun_error():
    """Test VLMRunError."""
    error = VLMRunError("Test error")
    assert str(error) == "Test error"
    assert error.message == "Test error"


def test_api_error():
    """Test APIError.

    The str() assertions below pin the bracketed format
    "[status=..., type=..., id=...] message (Suggestion: ...)".
    """
    error = APIError(
        message="API error",
        http_status=400,
        request_id="req-123",
        error_type="test_error",
        suggestion="Try again",
    )
    assert error.message == "API error"
    assert error.http_status == 400
    assert error.request_id == "req-123"
    assert error.headers == {}
    assert error.error_type == "test_error"
    assert error.suggestion == "Try again"
    assert "API error" in str(error)
    assert "status=400" in str(error)
    assert "type=test_error" in str(error)
    assert "id=req-123" in str(error)
    assert "Try again" in str(error)


def test_authentication_error():
    """Test AuthenticationError (defaults, then message override)."""
    error = AuthenticationError()
    assert error.message == "Authentication failed"
    assert error.http_status == 401
    assert error.error_type == "authentication_error"
    assert "Check your API key" in error.suggestion

    error = AuthenticationError(message="Invalid API key")
    assert error.message == "Invalid API key"
    assert error.http_status == 401


def test_validation_error():
    """Test ValidationError (defaults, then message override)."""
    error = ValidationError()
    assert error.message == "Validation failed"
    assert error.http_status == 400
    assert error.error_type == "validation_error"
    assert "Check your request parameters" in error.suggestion

    error = ValidationError(message="Invalid parameter")
    assert error.message == "Invalid parameter"
    assert error.http_status == 400


def test_rate_limit_error():
    """Test RateLimitError (defaults, then message override)."""
    error = RateLimitError()
    assert error.message == "Rate limit exceeded"
    assert error.http_status == 429
    assert error.error_type == "rate_limit_error"
    assert "Reduce request frequency" in error.suggestion

    error = RateLimitError(message="Too many requests")
    assert error.message == "Too many requests"
    assert error.http_status == 429


def test_server_error():
    """Test ServerError (defaults, then message override)."""
    error = ServerError()
    assert error.message == "Server error"
    assert error.http_status == 500
    assert error.error_type == "server_error"
    assert "try again later" in error.suggestion.lower()

    error = ServerError(message="Internal server error")
    assert error.message == "Internal server error"
    assert error.http_status == 500


def test_resource_not_found_error():
    """Test ResourceNotFoundError (defaults, then message override)."""
    error = ResourceNotFoundError()
    assert error.message == "Resource not found"
    assert error.http_status == 404
    assert error.error_type == "not_found_error"
    assert "Check the resource ID" in error.suggestion

    error = ResourceNotFoundError(message="Domain not found")
    assert error.message == "Domain not found"
    assert error.http_status == 404


def test_client_error():
    """Test ClientError (no http_status — client-side errors only)."""
    error = ClientError(
        message="Client error",
        error_type="test_error",
        suggestion="Fix your client",
    )
    assert error.message == "Client error"
    assert error.error_type == "test_error"
    assert error.suggestion == "Fix your client"
    assert "Client error" in str(error)
    assert "type=test_error" in str(error)
    assert "Fix your client" in str(error)


def test_configuration_error():
    """Test ConfigurationError (defaults, then message override)."""
    error = ConfigurationError()
    assert error.message == "Invalid configuration"
    assert error.error_type == "configuration_error"
    assert "Check your client configuration" in error.suggestion

    error = ConfigurationError(message="Missing configuration")
    assert error.message == "Missing configuration"


def test_dependency_error():
    """Test DependencyError (defaults, then message override)."""
    error = DependencyError()
    assert error.message == "Missing dependency"
    assert error.error_type == "dependency_error"
    assert "Install the required dependency" in error.suggestion

    error = DependencyError(message="Missing OpenAI package")
    assert error.message == "Missing OpenAI package"


def test_input_error():
    """Test InputError (defaults, then message override)."""
    error = InputError()
    assert error.message == "Invalid input"
    assert error.error_type == "input_error"
    assert "Check your input parameters" in error.suggestion

    error = InputError(message="Invalid image format")
    assert error.message == "Invalid image format"


def test_timeout_error():
    """Test RequestTimeoutError (defaults, override, and APIError subclassing)."""
    error = RequestTimeoutError()
    assert error.message == "Request timed out"
    assert error.http_status == 408
    assert error.error_type == "timeout_error"
    assert "Try again later" in error.suggestion

    error = RequestTimeoutError(message="Operation timed out")
    assert error.message == "Operation timed out"
    assert error.http_status == 408

    # Timeout errors must participate in the APIError hierarchy
    assert isinstance(error, APIError)


def test_network_error():
    """Test NetworkError (defaults, override, and APIError subclassing)."""
    error = NetworkError()
    assert error.message == "Network error"
    assert error.error_type == "network_error"
    assert "Check your internet connection" in error.suggestion

    error = NetworkError(message="Connection failed")
    assert error.message == "Connection failed"

    assert isinstance(error, APIError)


def test_retry_error_handling():
    """Test that RetryError is properly handled and converts to appropriate exceptions."""

    # Create a mock client (only the attributes APIRequestor reads)
    class MockClient:
        def __init__(self):
            self.api_key = "test-key"
            self.base_url = "https://api.test"
            self.max_retries = 3

    client = MockClient()
    requestor = APIRequestor(client)

    # Test all our custom error types are preserved
    error_types = [
        (AuthenticationError("Invalid API key"), "Invalid API key"),
        (ValidationError("Invalid parameter"), "Invalid parameter"),
        (RateLimitError("Too many requests"), "Too many requests"),
        (ResourceNotFoundError("Resource not found"), "Resource not found"),
        (ServerError("Server is down"), "Server is down"),
        (RequestTimeoutError("Request timed out"), "Request timed out"),
        (NetworkError("Connection failed"), "Connection failed"),
    ]

    for error, expected_message in error_types:
        # A fake tenacity attempt: the *class* (not an instance) is passed,
        # so `.exception` resolves to the bare zero-arg lambda. The lambda
        # closes over the loop variable `error`, which is safe only because
        # it is consumed within the same iteration.
        retry_error = RetryError(
            last_attempt=type("obj", (object,), {"exception": lambda: error})
        )

        with pytest.raises(type(error)) as exc_info:
            requestor._handle_retry_error(retry_error)
        assert exc_info.value.message == expected_message

    # Test unknown error wrapping: non-custom exceptions get wrapped in APIError
    unknown_error = Exception("Unknown error")
    retry_error = RetryError(
        last_attempt=type("obj", (object,), {"exception": lambda: unknown_error})
    )

    with pytest.raises(APIError) as exc_info:
        requestor._handle_retry_error(retry_error)
    assert "Request failed after 3 retries" in exc_info.value.message
    assert "Unknown error" in exc_info.value.message
"""Tests for the Agent resource."""

import pytest
from datetime import datetime
from vlmrun.client.types import (
    AgentCreationResponse,
    AgentCreationConfig,
    AgentExecutionConfig,
    AgentExecutionResponse,
    AgentInfo,
    CreditUsage,
)


class TestAgentCreationResponse:
    """Test the AgentCreationResponse model."""

    def test_agent_creation_response_creation(self):
        """Test creating an AgentCreationResponse instance."""
        response_data = {
            "id": "test-agent-123",
            "name": "test-agent:1.0.0",
            "created_at": datetime.now(),
            "updated_at": datetime.now(),
            "status": "completed",
        }

        response = AgentCreationResponse(**response_data)

        assert response.id == "test-agent-123"
        assert response.name == "test-agent:1.0.0"
        assert response.status == "completed"


class TestAgentMethods:
    """Test the Agent API methods.

    NOTE(review): `mock_client` is not defined in this module — presumably a
    conftest.py fixture returning a VLMRun client with mocked transport;
    the canned values asserted below must match that fixture.
    """

    def test_agent_get_by_name_and_version(self, mock_client):
        """Test getting an agent by name and version."""
        client = mock_client
        response = client.agent.get(name="test-agent:1.0.0")

        assert isinstance(response, AgentInfo)
        assert response.name == "test-agent:1.0.0"
        assert response.description == "Test agent description"
        assert response.prompt == "Test agent prompt"
        assert response.status == "completed"

    def test_agent_get_by_id(self, mock_client):
        """Test getting an agent by ID."""
        client = mock_client
        response = client.agent.get(id="agent-123")

        assert isinstance(response, AgentInfo)
        assert response.id == "agent-123"
        assert response.name == "agent-agent-123"

    def test_agent_get_validation_error(self, mock_client):
        """Test that get method validates input parameters."""
        client = mock_client

        # id and name are mutually exclusive
        with pytest.raises(
            ValueError, match="Only one of `id` or `name` or `prompt` can be provided."
        ):
            client.agent.get(id="agent-123", name="test-agent")

        # ...but at least one identifier is required
        with pytest.raises(
            ValueError, match="Either `id` or `name` or `prompt` must be provided."
        ):
            client.agent.get()

    def test_agent_list(self, mock_client):
        """Test listing all agents."""
        client = mock_client
        response = client.agent.list()

        assert isinstance(response, list)
        assert len(response) == 2
        assert all(isinstance(agent, AgentInfo) for agent in response)
        assert response[0].name.startswith("test-agent-1")
        assert response[1].name.startswith("test-agent-2")

    def test_agent_create(self, mock_client):
        """Test creating an agent."""
        client = mock_client
        config = AgentCreationConfig(
            prompt="Create a test agent",
            json_schema={
                "type": "object",
                "properties": {"result": {"type": "string"}},
            },
        )

        response = client.agent.create(
            config=config, name="new-agent", inputs={"test": "input"}
        )

        assert isinstance(response, AgentCreationResponse)
        assert response.name == "new-agent"
        # Creation is asynchronous: the initial status is pending
        assert response.status == "pending"

    def test_agent_create_validation_error(self, mock_client):
        """Test that create method validates config has prompt."""
        client = mock_client
        config = AgentCreationConfig()

        with pytest.raises(
            ValueError,
            match="Prompt is not provided as a request parameter, please provide a prompt.",
        ):
            client.agent.create(config=config)

    def test_agent_execute(self, mock_client):
        """Test executing an agent."""
        client = mock_client
        config = AgentExecutionConfig(
            prompt="Execute this agent",
            json_schema={
                "type": "object",
                "properties": {"output": {"type": "string"}},
            },
        )

        response = client.agent.execute(
            name="test-agent:1.0.0",
            inputs={"input": "test data"},
            config=config,
        )

        assert isinstance(response, AgentExecutionResponse)
        assert response.name == "test-agent:1.0.0"
        assert response.status == "completed"
        assert response.response == {"result": "execution result"}
        assert isinstance(response.usage, CreditUsage)
        assert response.usage.credits_used == 50

    def test_agent_execute_without_version(self, mock_client):
        """Test executing an agent without specifying version."""
        client = mock_client

        response = client.agent.execute(name="test-agent")

        assert isinstance(response, AgentExecutionResponse)
        assert response.name == "test-agent"

    def test_agent_execute_batch_mode_required(self, mock_client):
        """Test that execute method requires batch mode."""
        client = mock_client

        with pytest.raises(
            NotImplementedError, match="Batch mode is required for agent execution"
        ):
            client.agent.execute(name="test-agent", batch=False)


class TestAgentConfigModels:
    """Test the Agent configuration models."""

    def test_agent_creation_config(self):
        """Test AgentCreationConfig model."""
        config = AgentCreationConfig(
            prompt="Test prompt",
            json_schema={
                "type": "object",
                "properties": {"result": {"type": "string"}},
            },
        )

        assert config.prompt == "Test prompt"
        assert config.json_schema is not None
        assert config.response_model is None

    def test_agent_execution_config(self):
        """Test AgentExecutionConfig model."""
        config = AgentExecutionConfig(
            prompt="Execute prompt",
            json_schema={
                "type": "object",
                "properties": {"output": {"type": "string"}},
            },
        )

        assert config.prompt == "Execute prompt"
        assert config.json_schema is not None
        assert config.response_model is None

    def test_config_validation_error(self):
        """Test that config validates response_model and json_schema are not both provided."""
        from pydantic import BaseModel

        class TestModel(BaseModel):
            result: str

        # response_model and json_schema are mutually exclusive ways of
        # declaring the output shape
        with pytest.raises(ValueError, match="cannot be used together"):
            AgentCreationConfig(
                prompt="Test", response_model=TestModel, json_schema={"type": "object"}
            )


class TestAgentCompletions:
    """Test the Agent OpenAI completions integration."""

    def test_agent_completions_property_exists(self, mock_client):
        """Test that agent.completions property exists and returns OpenAI Completions object."""
        client = mock_client

        completions = client.agent.completions

        # Duck-typed check: we only assert the OpenAI-compatible surface
        assert hasattr(completions, "create")
        assert hasattr(completions, "stream")
        assert callable(completions.create)

    def test_agent_async_completions_property_exists(self, mock_client):
        """Test that agent.async_completions property exists and returns AsyncCompletions object."""
        client = mock_client

        async_completions = client.agent.async_completions

        assert hasattr(async_completions, "create")
        assert callable(async_completions.create)

    def test_agent_completions_cached_property(self, mock_client):
        """Test that completions property is cached."""
        client = mock_client

        completions1 = client.agent.completions
        completions2 = client.agent.completions

        # Identity (is), not equality: the property must return the same object
        assert completions1 is completions2

    def test_agent_async_completions_cached_property(self, mock_client):
        """Test that async_completions property is cached."""
        client = mock_client

        async_completions1 = client.agent.async_completions
        async_completions2 = client.agent.async_completions

        assert async_completions1 is async_completions2
"""Video utilities for reading and writing video files using OpenCV."""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Union

import cv2
import numpy as np


# Frame type alias: a single video frame as an RGB numpy array.
T = np.ndarray


class BaseVideoReader(ABC):
    """Abstract base class for video readers."""

    def __init__(self, filename: Union[str, Path]):
        """Initialize the base video reader.

        Args:
            filename (Union[str, Path]): Path to the video file.
        """
        self.filename = Path(str(filename))
        self._video = None

    def __repr__(self) -> str:
        """Return a string representation of the video reader.

        Returns:
            str: A string representation of the video reader.
        """
        return f"{self.__class__.__name__}(filename={self.filename})"

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of frames in the video.

        Returns:
            int: The number of frames in the video.
        """
        raise NotImplementedError()

    @abstractmethod
    def __iter__(self) -> Iterator[T]:
        """Return an iterator over the video frames.

        Returns:
            Iterator[T]: An iterator over the video frames.
        """
        raise NotImplementedError()

    @abstractmethod
    def __next__(self) -> T:
        """Return the next frame in the video.

        Returns:
            T: The next frame in the video.

        Raises:
            StopIteration: If there are no more frames in the video.
        """
        raise NotImplementedError()

    @abstractmethod
    def __getitem__(self, idx: Union[int, List[int]]) -> Union[T, List[T]]:
        """Return the frame(s) at the given index/indices.

        Args:
            idx (Union[int, List[int]]): The index or list of indices to retrieve.

        Returns:
            Union[T, List[T]]: The frame or list of frames at the given index/indices.

        Raises:
            IndexError: If any index is out of bounds.
            TypeError: If the index type is not supported.
        """
        raise NotImplementedError()

    @abstractmethod
    def open(self) -> None:
        """Open the video file.

        Raises:
            FileNotFoundError: If the video file does not exist.
            RuntimeError: If the video file cannot be opened.
        """
        raise NotImplementedError()

    @abstractmethod
    def close(self) -> None:
        """Close the video file."""
        raise NotImplementedError()

    @abstractmethod
    def pos(self) -> Optional[int]:
        """Return the current position in the video.

        Returns:
            Optional[int]: The current frame position, or None if position cannot be determined.
        """
        raise NotImplementedError()

    @abstractmethod
    def seek(self, idx: int) -> None:
        """Seek to the given frame index in the video.

        Args:
            idx (int): The frame index to seek to.

        Raises:
            IndexError: If the index is out of bounds.
        """
        raise NotImplementedError()

    def reset(self) -> None:
        """Reset the video reader to the beginning."""
        self.seek(0)

    def __enter__(self):
        """Enter the context manager."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the context manager, releasing the underlying capture."""
        self.close()


class VideoReader(BaseVideoReader):
    """Video reader implementation using OpenCV.

    Frames are returned as RGB arrays (OpenCV's native BGR is converted on read).
    """

    def __init__(
        self, filename: Union[str, Path], transform: Optional[Callable] = None
    ):
        """Initialize the video reader.

        Args:
            filename (Union[str, Path]): Path to the video file.
            transform (Optional[Callable], optional): Optional transform to apply to each frame.
                Defaults to None.

        Raises:
            FileNotFoundError: If the video file does not exist.
        """
        super().__init__(filename)
        if not self.filename.exists():
            raise FileNotFoundError(f"{self.filename} does not exist")
        self.transform = transform
        self._video = self.open()

    def __len__(self) -> int:
        """Return the number of frames in the video (0 if the video is closed).

        Note: CAP_PROP_FRAME_COUNT is a container-reported value and may be an
        estimate for some codecs/containers.

        Returns:
            int: The number of frames in the video.
        """
        if self._video is None:
            return 0
        return int(self._video.get(cv2.CAP_PROP_FRAME_COUNT))

    def __iter__(self) -> Iterator[T]:
        """Return an iterator over the video frames.

        Returns:
            Iterator[T]: An iterator over the video frames.
        """
        return self

    def __next__(self) -> T:
        """Return the next frame in the video, as RGB.

        Returns:
            T: The next frame in the video.

        Raises:
            StopIteration: If there are no more frames in the video.
            RuntimeError: If the video is not opened.
        """
        if self._video is None:
            raise RuntimeError("Video is not opened")
        ret, frame = self._video.read()
        if not ret:
            raise StopIteration()
        # OpenCV decodes to BGR; expose RGB to callers.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if self.transform:
            frame = self.transform(frame)
        return frame

    def __getitem__(self, idx: Union[int, List[int]]) -> Union[T, List[T]]:
        """Return the frame(s) at the given index/indices.

        Negative indices are supported with Python sequence semantics
        (``-1`` is the last frame). Note that random access seeks the
        underlying capture, so the read position is moved as a side effect.

        Args:
            idx (Union[int, List[int]]): The index or list of indices to retrieve.

        Returns:
            Union[T, List[T]]: The frame or list of frames at the given index/indices.

        Raises:
            IndexError: If any index is out of bounds.
            TypeError: If the index type is not supported.
        """
        if isinstance(idx, int):
            self.seek(idx)
            return next(self)
        elif isinstance(idx, list):
            return [self.__getitem__(i) for i in idx]
        else:
            raise TypeError(f"Unsupported index type: {type(idx)}")

    def open(self) -> cv2.VideoCapture:
        """Open the video file.

        Returns:
            cv2.VideoCapture: The opened video capture object.

        Raises:
            RuntimeError: If the video file cannot be opened.
        """
        video = cv2.VideoCapture(str(self.filename))
        if not video.isOpened():
            raise RuntimeError(f"Failed to open video file: {self.filename}")
        return video

    def close(self) -> None:
        """Close the video file and release the capture."""
        if self._video is not None:
            self._video.release()
            self._video = None

    def pos(self) -> Optional[int]:
        """Return the current position in the video.

        Returns:
            Optional[int]: The current frame position, or None if position cannot be determined.
        """
        if self._video is None:
            return None
        try:
            return int(self._video.get(cv2.CAP_PROP_POS_FRAMES))
        except Exception:
            return None

    def seek(self, idx: int) -> None:
        """Seek to the given frame index in the video.

        Args:
            idx (int): The frame index to seek to. Negative indices are
                interpreted relative to the end of the video (``-1`` is the
                last frame); previously any negative index raised IndexError,
                so this is a backward-compatible generalization.

        Raises:
            IndexError: If the index is out of bounds.
            RuntimeError: If the video is not opened.
        """
        if self._video is None:
            raise RuntimeError("Video is not opened")
        n_frames = len(self)
        if idx < 0:
            idx += n_frames
        if idx < 0 or idx >= n_frames:
            raise IndexError(f"Frame index out of bounds: {idx}")
        self._video.set(cv2.CAP_PROP_POS_FRAMES, idx)


class VideoWriter:
    """Video writer implementation using OpenCV (mp4v fourcc)."""

    def __init__(self, filename: Union[str, Path], fps: float = 30.0):
        """Initialize the video writer.

        The underlying cv2.VideoWriter is created lazily on the first
        ``write`` call, once the frame size is known.

        Args:
            filename (Union[str, Path]): Path to the output video file.
            fps (float, optional): Frames per second. Defaults to 30.0.

        Raises:
            FileExistsError: If the output file already exists.
        """
        self.filename = Path(str(filename))
        if self.filename.exists():
            raise FileExistsError(f"Output file already exists: {self.filename}")
        self.fps = fps
        self.writer = None

    def write(self, frame: np.ndarray) -> None:
        """Write a frame to the video.

        Args:
            frame (np.ndarray): The frame to write. Should be an RGB image
                (H x W x 3) or a single-channel grayscale image (H x W).
        """
        if self.writer is None:
            # Lazily initialize the writer from the first frame's shape;
            # the final positional argument is `isColor`.
            height, width = frame.shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            self.writer = cv2.VideoWriter(
                str(self.filename), fourcc, self.fps, (width, height), frame.ndim == 3
            )

        if frame.ndim == 3:
            # Convert RGB -> BGR with cvtColor so OpenCV receives a
            # contiguous buffer. The previous reversed slice
            # (`frame[..., ::-1]`) produced a negative-stride view, which
            # some OpenCV builds reject.
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        self.writer.write(frame)

    def close(self) -> None:
        """Close the video writer and flush the output file."""
        if self.writer is not None:
            self.writer.release()
            self.writer = None

    def __enter__(self):
        """Enter the context manager."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the context manager, releasing the writer."""
        self.close()
AgentCreationResponse, 16 | ) 17 | from vlmrun.client.exceptions import DependencyError 18 | 19 | 20 | class Agent: 21 | """Agent resource for VLM Run API.""" 22 | 23 | def __init__(self, client: "VLMRunProtocol") -> None: 24 | """Initialize Agent resource with VLMRun instance. 25 | 26 | Args: 27 | client: VLM Run API instance 28 | """ 29 | self._client = client 30 | self._requestor = APIRequestor(client) 31 | 32 | def get( 33 | self, 34 | name: str | None = None, 35 | id: str | None = None, 36 | prompt: str | None = None, 37 | ) -> AgentInfo: 38 | """Get an agent by name, id, or prompt. Only one of `name`, `id`, or `prompt` can be provided. 39 | 40 | Args: 41 | name: Name of the agent 42 | id: ID of the agent 43 | prompt: Prompt of the agent 44 | 45 | Raises: 46 | APIError: If the agent is not found (404) or the agent name is invalid (400) 47 | 48 | Returns: 49 | AgentInfo: Agent information response 50 | """ 51 | if id: 52 | if name or prompt: 53 | raise ValueError( 54 | "Only one of `id` or `name` or `prompt` can be provided." 55 | ) 56 | data = {"id": id} 57 | elif name: 58 | if id or prompt: 59 | raise ValueError( 60 | "Only one of `id` or `name` or `prompt` can be provided." 61 | ) 62 | data = {"name": name} 63 | elif prompt: 64 | if id or name: 65 | raise ValueError( 66 | "Only one of `id` or `name` or `prompt` can be provided." 
67 | ) 68 | data = {"prompt": prompt} 69 | else: 70 | raise ValueError("Either `id` or `name` or `prompt` must be provided.") 71 | 72 | response, status_code, headers = self._requestor.request( 73 | method="POST", 74 | url="agent/lookup", 75 | data=data, 76 | ) 77 | 78 | if not isinstance(response, dict): 79 | raise TypeError("Expected dict response") 80 | 81 | return AgentInfo(**response) 82 | 83 | def list(self) -> list[AgentInfo]: 84 | """List all agents.""" 85 | response, status_code, headers = self._requestor.request( 86 | method="GET", 87 | url="agent", 88 | ) 89 | 90 | if not isinstance(response, list): 91 | raise TypeError("Expected list response") 92 | 93 | return [AgentInfo(**agent) for agent in response] 94 | 95 | def create( 96 | self, 97 | config: AgentCreationConfig, 98 | name: str | None = None, 99 | inputs: Optional[dict[str, Any]] = None, 100 | callback_url: Optional[str] = None, 101 | ) -> AgentCreationResponse: 102 | """Create an agent. 103 | 104 | Args: 105 | config: Agent creation configuration 106 | name: Optional name of the agent to create 107 | inputs: Optional inputs to the agent (e.g. {"image": "https://..."}) 108 | callback_url: Optional URL to call when creation is complete 109 | 110 | Returns: 111 | AgentCreationResponse: Agent creation response 112 | """ 113 | if config.prompt is None: 114 | raise ValueError( 115 | "Prompt is not provided as a request parameter, please provide a prompt." 
116 | ) 117 | 118 | data = { 119 | "name": name, 120 | "inputs": inputs, 121 | "config": config.model_dump(), 122 | } 123 | 124 | if callback_url: 125 | data["callback_url"] = callback_url 126 | 127 | response, status_code, headers = self._requestor.request( 128 | method="POST", 129 | url="agent/create", 130 | data=data, 131 | ) 132 | 133 | if not isinstance(response, dict): 134 | raise TypeError("Expected dict response") 135 | 136 | return AgentCreationResponse(**response) 137 | 138 | def execute( 139 | self, 140 | name: str | None = None, 141 | inputs: Optional[dict[str, Any]] = None, 142 | batch: bool = True, 143 | config: Optional[AgentExecutionConfig] = None, 144 | metadata: Optional[RequestMetadata] = None, 145 | callback_url: Optional[str] = None, 146 | ) -> AgentExecutionResponse: 147 | """Execute an agent with the given arguments. 148 | 149 | Args: 150 | name: Name of the agent to execute. If not provided, we use the prompt to identify the unique agent. 151 | inputs: Optional inputs to the agent 152 | batch: Whether to process in batch mode (async) 153 | config: Optional agent execution configuration 154 | metadata: Optional request metadata 155 | callback_url: Optional URL to call when execution is complete 156 | 157 | Returns: 158 | AgentExecutionResponse: Agent execution response 159 | """ 160 | if not batch: 161 | raise NotImplementedError("Batch mode is required for agent execution") 162 | 163 | data = { 164 | "name": name, 165 | "batch": batch, 166 | "inputs": inputs, 167 | } 168 | 169 | if config: 170 | data["config"] = config.model_dump() 171 | 172 | if metadata: 173 | data["metadata"] = metadata.model_dump() 174 | 175 | if callback_url: 176 | data["callback_url"] = callback_url 177 | 178 | response, status_code, headers = self._requestor.request( 179 | method="POST", 180 | url="agent/execute", 181 | data=data, 182 | ) 183 | 184 | if not isinstance(response, dict): 185 | raise TypeError("Expected dict response") 186 | 187 | return 
AgentExecutionResponse(**response) 188 | 189 | def get_by_id(self, agent_id: str) -> AgentInfo: 190 | """Get agent information by ID. 191 | 192 | Args: 193 | agent_id: The ID of the agent to retrieve 194 | 195 | Returns: 196 | AgentInfo: Information about the agent 197 | """ 198 | response, status_code, headers = self._requestor.request( 199 | method="GET", 200 | url=f"agents/{agent_id}", 201 | ) 202 | 203 | if not isinstance(response, dict): 204 | raise TypeError("Expected dict response") 205 | 206 | return AgentInfo(**response) 207 | 208 | @cached_property 209 | def completions(self): 210 | """OpenAI-compatible chat completions interface (synchronous). 211 | 212 | Returns an OpenAI Completions object configured to use the VLMRun 213 | agent endpoint. This allows you to use the familiar OpenAI API 214 | for chat completions. 215 | 216 | Example: 217 | ```python 218 | from vlmrun import VLMRun 219 | 220 | client = VLMRun(api_key="your-key", base_url="https://agent.vlm.run/v1") 221 | 222 | response = client.agent.completions.create( 223 | model="vlmrun-orion-1", 224 | messages=[ 225 | {"role": "user", "content": "Hello!"} 226 | ] 227 | ) 228 | ``` 229 | 230 | Raises: 231 | DependencyError: If openai package is not installed 232 | 233 | Returns: 234 | OpenAI Completions object configured for VLMRun agent endpoint 235 | """ 236 | try: 237 | from openai import OpenAI 238 | except ImportError: 239 | raise DependencyError( 240 | message="OpenAI SDK is not installed", 241 | suggestion="Install it with `pip install vlmrun[openai]` or `pip install openai`", 242 | error_type="missing_dependency", 243 | ) 244 | 245 | base_url = f"{self._client.base_url}/openai" 246 | openai_client = OpenAI( 247 | api_key=self._client.api_key, 248 | base_url=base_url, 249 | timeout=self._client.timeout, 250 | max_retries=self._client.max_retries, 251 | ) 252 | 253 | return openai_client.chat.completions 254 | 255 | @cached_property 256 | def async_completions(self): 257 | """OpenAI-compatible 
chat completions interface (asynchronous). 258 | 259 | Returns an OpenAI AsyncCompletions object configured to use the VLMRun 260 | agent endpoint. This allows you to use the familiar OpenAI async API 261 | for chat completions. 262 | 263 | Example: 264 | ```python 265 | from vlmrun import VLMRun 266 | import asyncio 267 | 268 | client = VLMRun(api_key="your-key", base_url="https://agent.vlm.run/v1") 269 | 270 | async def main(): 271 | response = await client.agent.async_completions.create( 272 | model="vlmrun-orion-1", 273 | messages=[ 274 | {"role": "user", "content": "Hello!"} 275 | ] 276 | ) 277 | return response 278 | 279 | asyncio.run(main()) 280 | ``` 281 | 282 | Raises: 283 | DependencyError: If openai package is not installed 284 | 285 | Returns: 286 | OpenAI AsyncCompletions object configured for VLMRun agent endpoint 287 | """ 288 | try: 289 | from openai import AsyncOpenAI 290 | except ImportError: 291 | raise DependencyError( 292 | message="OpenAI SDK is not installed", 293 | suggestion="Install it with `pip install vlmrun[openai]` or `pip install openai`", 294 | error_type="missing_dependency", 295 | ) 296 | 297 | base_url = f"{self._client.base_url}/openai" 298 | async_openai_client = AsyncOpenAI( 299 | api_key=self._client.api_key, 300 | base_url=base_url, 301 | timeout=self._client.timeout, 302 | max_retries=self._client.max_retries, 303 | ) 304 | 305 | return async_openai_client.chat.completions 306 | -------------------------------------------------------------------------------- /vlmrun/client/files.py: -------------------------------------------------------------------------------- 1 | """VLM Run API Files resource.""" 2 | 3 | from __future__ import annotations 4 | 5 | import requests 6 | import time 7 | import hashlib 8 | from pathlib import Path 9 | from typing import Optional, Union, List, Literal 10 | 11 | from loguru import logger 12 | from vlmrun.client.base_requestor import APIRequestor 13 | from vlmrun.types.abstract import 
VLMRunProtocol 14 | from vlmrun.client.types import FileResponse, PresignedUrlRequest, PresignedUrlResponse 15 | 16 | 17 | class Files: 18 | """Files resource for VLM Run API.""" 19 | 20 | def __init__(self, client: "VLMRunProtocol") -> None: 21 | """Initialize Files resource with VLMRun instance. 22 | 23 | Args: 24 | client: VLM Run API instance 25 | """ 26 | self._client = client 27 | self._requestor = APIRequestor(client) 28 | 29 | def list(self, skip: int = 0, limit: int = 10) -> List[FileResponse]: 30 | """List all files. 31 | 32 | Args: 33 | skip: Number of items to skip 34 | limit: Maximum number of items to return 35 | 36 | Returns: 37 | List[FileResponse]: List of file objects 38 | """ 39 | response, status_code, headers = self._requestor.request( 40 | method="GET", 41 | url="files", 42 | params={"skip": skip, "limit": limit}, 43 | ) 44 | return [FileResponse(**file) for file in response] 45 | 46 | def get_cached_file(self, file: Union[Path, str]) -> Optional[FileResponse]: 47 | """Get a cached file from the database, if it exists. 
48 | 49 | Args: 50 | file: Path to file to check 51 | 52 | Returns: 53 | FileResponse: File object if it exists, None otherwise 54 | """ 55 | if isinstance(file, str): 56 | file = Path(file) 57 | # Compute the md5 hash of the file 58 | logger.debug(f"Computing md5 hash for file [file={file}]") 59 | file_hash = hashlib.md5() 60 | with file.open("rb") as f: 61 | while chunk := f.read(4 * 1024 * 1024): 62 | file_hash.update(chunk) 63 | file_hash = file_hash.hexdigest() 64 | logger.debug(f"Computed md5 hash for file [file={file}, hash={file_hash}]") 65 | 66 | # Check if the file exists in the database 67 | logger.debug( 68 | f"Checking if file exists in the database [file={file}, hash={file_hash}]" 69 | ) 70 | try: 71 | response, status_code, headers = self._requestor.request( 72 | method="GET", 73 | url=f"files/hash/{file_hash}", 74 | ) 75 | if status_code == 200: 76 | file_response = FileResponse(**response) 77 | # Check if the file exists by checking if id is empty 78 | logger.debug(f"File response [file_response={file_response.id}]") 79 | if file_response.id: 80 | return file_response 81 | else: 82 | return None 83 | else: 84 | return None 85 | except Exception as exc: 86 | logger.debug( 87 | f"File hash does not exist in the database [file={file}, hash={file_hash}, exc={exc}]" 88 | ) 89 | return None 90 | 91 | def upload( 92 | self, 93 | file: Union[Path, str], 94 | purpose: str = "assistants", 95 | method: Literal["auto", "direct", "presigned-url"] = "auto", 96 | timeout: int = 3 * 60, 97 | expiration: int = 24 * 60 * 60, 98 | force: bool = False, 99 | ) -> FileResponse: 100 | """Upload a file. 
101 | 102 | Args: 103 | file: Path to file to upload 104 | purpose: Purpose of file (default: fine-tune) 105 | timeout: Timeout for upload (default: 3 minutes) 106 | method: Method to use for upload (default: "auto") 107 | expiration: Expiration time for presigned URL (default: 24 hours) 108 | force: Force upload even if file already exists (default: False) 109 | 110 | Returns: 111 | FileResponse: Uploaded file object 112 | """ 113 | if isinstance(file, str): 114 | try: 115 | file = Path(file) 116 | except Exception as exc: 117 | raise ValueError(f"Invalid file path: {file}") from exc 118 | elif not isinstance(file, Path): 119 | raise ValueError(f"Invalid file path: {file}") 120 | 121 | # Check if the file exists 122 | if not file.exists(): 123 | raise FileNotFoundError(f"File does not exist: {file}") 124 | 125 | # Check if the file already exists in the database 126 | cached_response: Optional[FileResponse] = self.get_cached_file(file) 127 | if cached_response and not force: 128 | return cached_response 129 | 130 | # If method is "auto", check if the file is too large to upload directly 131 | if method == "auto": 132 | if file.stat().st_size > 32 * 1024 * 1024: # 32MB 133 | method = "presigned-url" 134 | else: 135 | method = "direct" 136 | logger.debug(f"Upload method [file={file}, method={method}]") 137 | 138 | # If method is "presigned-url", generate a presigned URL for the file 139 | if method == "presigned-url": 140 | # Generate a presigned URL for the file 141 | response, status_code, headers = self._requestor.request( 142 | method="POST", 143 | url="files/presigned-url", 144 | data={ 145 | "filename": file.name, 146 | "purpose": purpose, 147 | "expiration": expiration, 148 | }, 149 | ) 150 | if not isinstance(response, dict): 151 | raise TypeError("Expected dict response") 152 | response = PresignedUrlResponse(**response) 153 | 154 | # PUT the file to the presigned URL 155 | start_t = time.time() 156 | logger.debug( 157 | f"Uploading file to presigned URL 
[file={file}, id={response.id}, url={response.url}]" 158 | ) 159 | with file.open("rb") as f: 160 | put_response = requests.put( 161 | response.url, 162 | headers={"Content-Type": response.content_type}, 163 | data=f, 164 | ) 165 | status_code = put_response.status_code 166 | end_t = time.time() 167 | logger.debug( 168 | f"Uploaded file to presigned URL [file={file}, url={response.url}, time={end_t - start_t:.1f}s]" 169 | ) 170 | if status_code == 200: 171 | # Verify the file upload 172 | verify_response, status_code, headers = self._requestor.request( 173 | method="GET", 174 | url=f"files/verify-upload/{response.id}", 175 | ) 176 | if status_code == 200: 177 | return FileResponse(**verify_response) 178 | else: 179 | raise Exception(f"Failed to verify file upload: {verify_response}") 180 | else: 181 | raise Exception(f"Failed to upload file to presigned URL: {response}") 182 | 183 | # If method is "direct", upload the file directly 184 | elif method == "direct": 185 | logger.debug(f"Uploading file directly [file={file}]") 186 | # Upload the file 187 | with open(file, "rb") as f: 188 | files = {"file": (file.name, f)} 189 | response, status_code, headers = self._requestor.request( 190 | method="POST", 191 | url="files", 192 | params={"purpose": purpose}, 193 | files=files, 194 | timeout=timeout, 195 | ) 196 | if status_code == 201: 197 | if not isinstance(response, dict): 198 | raise TypeError("Expected dict response") 199 | return FileResponse(**response) 200 | else: 201 | raise Exception(f"Failed to upload file directly: {response}") 202 | 203 | else: 204 | raise ValueError(f"Invalid upload method: {method}") 205 | 206 | def get(self, file_id: str) -> FileResponse: 207 | """Get file metadata. 
208 | 209 | Args: 210 | file_id: ID of file to retrieve 211 | 212 | Returns: 213 | FileResponse: File metadata 214 | """ 215 | response, status_code, headers = self._requestor.request( 216 | method="GET", 217 | url=f"files/{file_id}", 218 | ) 219 | 220 | if not isinstance(response, dict): 221 | raise TypeError("Expected dict response") 222 | return FileResponse(**response) 223 | 224 | def get_content(self, file_id: str) -> bytes: 225 | """Get file content. 226 | 227 | Args: 228 | file_id: ID of file to retrieve content for 229 | 230 | Returns: 231 | bytes: File content 232 | """ 233 | raise NotImplementedError("Not implemented") 234 | 235 | def generate_presigned_url( 236 | self, params: PresignedUrlRequest 237 | ) -> PresignedUrlResponse: 238 | """Generate a presigned URL for file upload. 239 | 240 | Args: 241 | params: PresignedUrlRequest containing filename and optional purpose 242 | 243 | Returns: 244 | PresignedUrlResponse: Presigned URL response with upload URL and metadata 245 | """ 246 | response, status_code, headers = self._requestor.request( 247 | method="POST", 248 | url="files/presigned-url", 249 | data={ 250 | "filename": params.filename, 251 | "purpose": params.purpose, 252 | }, 253 | ) 254 | 255 | if not isinstance(response, dict): 256 | raise TypeError("Expected dict response") 257 | return PresignedUrlResponse(**response) 258 | 259 | def delete(self, file_id: str) -> FileResponse: 260 | """Delete a file. 
261 | 262 | Args: 263 | file_id: ID of file to delete 264 | 265 | Returns: 266 | FileResponse: Deletion confirmation 267 | """ 268 | response, status_code, headers = self._requestor.request( 269 | method="DELETE", 270 | url=f"files/{file_id}", 271 | ) 272 | 273 | if not isinstance(response, dict): 274 | raise TypeError("Expected dict response") 275 | return FileResponse(**response) 276 | -------------------------------------------------------------------------------- /vlmrun/client/base_requestor.py: -------------------------------------------------------------------------------- 1 | """VLM Run API requestor implementation.""" 2 | 3 | from typing import Any, Dict, Tuple, TYPE_CHECKING, Union, Optional 4 | from urllib.parse import urljoin 5 | 6 | if TYPE_CHECKING: 7 | from vlmrun.types.abstract import VLMRunProtocol 8 | 9 | from vlmrun.version import __version__ 10 | 11 | import requests 12 | from tenacity import ( 13 | retry, 14 | retry_if_exception_type, 15 | stop_after_attempt, 16 | wait_exponential, 17 | RetryError, 18 | ) 19 | 20 | from vlmrun.client.exceptions import ( 21 | APIError, 22 | AuthenticationError, 23 | ValidationError, 24 | RateLimitError, 25 | ServerError, 26 | ResourceNotFoundError, 27 | RequestTimeoutError, 28 | NetworkError, 29 | ) 30 | 31 | # Constants 32 | DEFAULT_TIMEOUT = 30.0 # seconds 33 | DEFAULT_MAX_RETRIES = 5 34 | INITIAL_RETRY_DELAY = 1 # seconds 35 | MAX_RETRY_DELAY = 10 # seconds 36 | 37 | 38 | class APIRequestor: 39 | """Handles API requests with retry logic.""" 40 | 41 | def __init__( 42 | self, 43 | client: "VLMRunProtocol", 44 | base_url: Optional[str] = None, 45 | timeout: float = DEFAULT_TIMEOUT, 46 | max_retries: Optional[int] = None, 47 | ) -> None: 48 | """Initialize API requestor. 
49 | 50 | Args: 51 | client: VLMRun API instance 52 | base_url: Base URL for API 53 | timeout: Request timeout in seconds 54 | max_retries: Maximum number of retry attempts 55 | """ 56 | self._client = client 57 | self._base_url = base_url or client.base_url 58 | self._timeout = timeout 59 | self._max_retries = ( 60 | max_retries 61 | if max_retries is not None 62 | else getattr(client, "max_retries", DEFAULT_MAX_RETRIES) 63 | ) 64 | self._session = requests.Session() 65 | 66 | def request( 67 | self, 68 | method: str, 69 | url: str, 70 | params: Optional[Dict[str, Any]] = None, 71 | data: Optional[Dict[str, Any]] = None, 72 | files: Optional[Dict[str, Any]] = None, 73 | headers: Optional[Dict[str, str]] = None, 74 | raw_response: bool = False, 75 | timeout: Optional[float] = None, 76 | ) -> Union[ 77 | Tuple[Dict[str, Any], int, Dict[str, str]], Tuple[bytes, int, Dict[str, str]] 78 | ]: 79 | """Make an API request with retry logic. 80 | 81 | Args: 82 | method: HTTP method 83 | url: API endpoint 84 | params: Query parameters 85 | data: Request body 86 | files: Files to upload 87 | headers: Request headers 88 | raw_response: Whether to return raw response content 89 | timeout: Request timeout in seconds 90 | 91 | Returns: 92 | Tuple of (response_data, status_code, response_headers) 93 | 94 | Raises: 95 | AuthenticationError: If authentication fails 96 | ValidationError: If request validation fails 97 | RateLimitError: If rate limit is exceeded 98 | ResourceNotFoundError: If resource is not found 99 | ServerError: If server returns 5xx error 100 | APIError: For other API errors 101 | RequestTimeoutError: If request times out 102 | NetworkError: If a network error occurs 103 | """ 104 | retry_decorator = retry( 105 | retry=retry_if_exception_type( 106 | ( 107 | requests.exceptions.Timeout, 108 | requests.exceptions.ConnectionError, 109 | ServerError, 110 | RequestTimeoutError, 111 | NetworkError, 112 | RateLimitError, 113 | ) 114 | ), 115 | wait=wait_exponential( 116 
| multiplier=INITIAL_RETRY_DELAY, 117 | min=INITIAL_RETRY_DELAY, 118 | max=MAX_RETRY_DELAY, 119 | ), 120 | stop=stop_after_attempt(self._max_retries), 121 | ) 122 | 123 | _headers = {} if headers is None else headers.copy() 124 | 125 | @retry_decorator 126 | def _request_with_retry(): 127 | # Add authorization 128 | if self._client.api_key: 129 | _headers["Authorization"] = f"Bearer {self._client.api_key}" 130 | 131 | if "X-Client-Id" not in _headers: 132 | _headers["X-Client-Id"] = f"python-sdk-{__version__}" 133 | 134 | # Build full URL 135 | full_url = urljoin(self._base_url.rstrip("/") + "/", url.lstrip("/")) 136 | 137 | try: 138 | response = self._session.request( 139 | method=method, 140 | url=full_url, 141 | params=params, 142 | json=data, 143 | files=files, 144 | headers=_headers, 145 | timeout=timeout or self._timeout, 146 | ) 147 | 148 | response.raise_for_status() 149 | 150 | if raw_response: 151 | return ( 152 | response.content, 153 | response.status_code, 154 | dict(response.headers), 155 | ) 156 | return response.json(), response.status_code, dict(response.headers) 157 | 158 | except requests.exceptions.RequestException as e: 159 | if isinstance(e, requests.exceptions.HTTPError): 160 | # Extract error details from response 161 | try: 162 | error_data = e.response.json() 163 | # First try to get error from error object 164 | error_obj = error_data.get("error", {}) 165 | message = error_obj.get("message") 166 | # If not found, try to get detail directly 167 | if message is None: 168 | message = error_data.get("detail", str(e)) 169 | error_type = error_obj.get("type") 170 | request_id = error_obj.get("id") 171 | except Exception: 172 | message = str(e) 173 | error_type = None 174 | request_id = None 175 | 176 | status_code = e.response.status_code 177 | headers = dict(e.response.headers) 178 | 179 | if status_code == 401: 180 | raise AuthenticationError( 181 | message=message, 182 | http_status=status_code, 183 | headers=headers, 184 | 
request_id=request_id, 185 | error_type=error_type, 186 | ) from e 187 | elif status_code == 400: 188 | raise ValidationError( 189 | message=message, 190 | http_status=status_code, 191 | headers=headers, 192 | request_id=request_id, 193 | error_type=error_type, 194 | ) from e 195 | elif status_code == 404: 196 | raise ResourceNotFoundError( 197 | message=message, 198 | http_status=status_code, 199 | headers=headers, 200 | request_id=request_id, 201 | error_type=error_type, 202 | ) from e 203 | elif status_code == 429: 204 | raise RateLimitError( 205 | message=message, 206 | http_status=status_code, 207 | headers=headers, 208 | request_id=request_id, 209 | error_type=error_type, 210 | ) from e 211 | elif 500 <= status_code < 600: 212 | raise ServerError( 213 | message=message, 214 | http_status=status_code, 215 | headers=headers, 216 | request_id=request_id, 217 | error_type=error_type, 218 | ) from e 219 | else: 220 | raise APIError( 221 | message=message, 222 | http_status=status_code, 223 | headers=headers, 224 | request_id=request_id, 225 | error_type=error_type, 226 | ) from e 227 | elif isinstance(e, requests.exceptions.Timeout): 228 | raise RequestTimeoutError(f"Request timed out: {str(e)}") from e 229 | elif isinstance(e, requests.exceptions.ConnectionError): 230 | raise NetworkError(f"Connection error: {str(e)}") from e 231 | else: 232 | raise APIError(str(e)) from e 233 | 234 | try: 235 | return _request_with_retry() 236 | except RetryError as e: 237 | return self._handle_retry_error(e) 238 | 239 | def _handle_retry_error(self, e: RetryError) -> None: 240 | """Handle RetryError by extracting and raising the appropriate exception. 

        Args:
            e: The RetryError to handle

        Raises:
            AuthenticationError: If authentication fails
            ValidationError: If request validation fails
            RateLimitError: If rate limit is exceeded
            ResourceNotFoundError: If resource is not found
            ServerError: If server returns 5xx error
            RequestTimeoutError: If request times out
            NetworkError: If a network error occurs
            APIError: For other API errors
        """
        # Extract the last exception from the retry error
        last_exception = e.last_attempt.exception()

        # Preserve all our custom error types
        if isinstance(
            last_exception,
            (
                AuthenticationError,
                ValidationError,
                RateLimitError,
                ResourceNotFoundError,
                ServerError,
                RequestTimeoutError,
                NetworkError,
            ),
        ):
            # Re-raise the original typed exception unchanged so callers can
            # catch the specific error class they care about.
            raise last_exception
        else:
            # Anything else is wrapped in the generic APIError, keeping the
            # original exception chained as the cause.
            raise APIError(
                f"Request failed after {self._max_retries} retries: {str(last_exception)}"
            ) from last_exception

# -------------------------------------------------------------------------- #
# tests/common/test_viz.py
# -------------------------------------------------------------------------- #
import pytest
from PIL import Image
from vlmrun.common.viz import (
    DisplayOptions,
    xywh_to_xyxy,
    extract_bbox,
    get_boxes_from_response,
    ensure_image,
    render_bbox_image,
    render_image,
    get_nested_value,
    filter_response_data,
    format_json_html,
    show_results,
)
# NOTE(review): List, Dict and BaseModel appear unused in the visible portion
# of this test module — confirm against the rest of the file before removing.
from typing import List, Dict
from pydantic import BaseModel


@pytest.fixture
def sample_image():
    """Create a simple test image."""
    # 100x100 white canvas; size is asserted by tests below.
    img = Image.new("RGB", (100, 100), color="white")
    return img


@pytest.fixture
def sample_response():
    """Create a sample response with various bounding box formats."""
    # Mimics a driver's-license extraction response: scalar fields plus
    # `*_metadata` entries carrying normalized XYWH bounding boxes.
    return {
        "issuing_state": "AL",
        "license_number": "1234567",
"full_name": "Connor Sample", 34 | "address": { 35 | "street": "10 Wonderful Drive", 36 | "city": "Montgomery", 37 | "state": "AL", 38 | "zip_code": "36110", 39 | "street_metadata": { 40 | "bbox": {"xywh": [0.349, 0.588, 0.406, 0.046]}, 41 | "bbox_content": "10 WONDERFUL DRIVE", 42 | "confidence": 0.9, 43 | }, 44 | "city_metadata": { 45 | "bbox": {"xywh": [0.347, 0.640, 0.187, 0.041]}, 46 | "bbox_content": "MONTGOMERY", 47 | "confidence": 1.0, 48 | }, 49 | }, 50 | "date_of_birth": "1948-01-05", 51 | "date_of_birth_metadata": { 52 | "bbox": {"xywh": [0.349, 0.431, 0.135, 0.049]}, 53 | "bbox_content": "01-05-1948", 54 | "confidence": 1.0, 55 | }, 56 | "full_name_metadata": { 57 | "bbox": {"xywh": [0.398, 0.783, 0.455, 0.172]}, 58 | "bbox_content": "Connor Sample", 59 | "confidence": 1.0, 60 | }, 61 | } 62 | 63 | 64 | def test_display_options_validation(): 65 | """Test DisplayOptions validation.""" 66 | opts = DisplayOptions(render_type="default") 67 | assert opts.is_valid_render_type 68 | 69 | opts = DisplayOptions(render_type="bboxes") 70 | assert opts.is_valid_render_type 71 | 72 | opts = DisplayOptions(render_type="invalid") 73 | assert not opts.is_valid_render_type 74 | 75 | opts = DisplayOptions(image_width=100) 76 | opts.validate_image_width() 77 | 78 | with pytest.raises(ValueError): 79 | DisplayOptions(image_width=-1).validate_image_width() 80 | 81 | with pytest.raises(ValueError): 82 | DisplayOptions(image_width=0).validate_image_width() 83 | 84 | 85 | def test_xywh_to_xyxy(): 86 | """Test conversion from XYWH to XYXY format.""" 87 | xywh = (0.1, 0.2, 0.3, 0.4) 88 | xyxy = xywh_to_xyxy(xywh) 89 | expected = (0.1, 0.2, 0.4, 0.6) 90 | assert all(pytest.approx(a) == b for a, b in zip(xyxy, expected)) 91 | 92 | xywh = (10, 20, 30, 40) 93 | xyxy = xywh_to_xyxy(xywh) 94 | expected = (10, 20, 40, 60) 95 | assert all(pytest.approx(a) == b for a, b in zip(xyxy, expected)) 96 | 97 | 98 | def test_extract_bbox(): 99 | """Test bounding box extraction from various 
formats.""" 100 | result = extract_bbox([0.1, 0.2, 0.3, 0.4]) 101 | expected = (0.1, 0.2, 0.3, 0.4) 102 | assert all(pytest.approx(a) == b for a, b in zip(result, expected)) 103 | 104 | result = extract_bbox({"bbox": [0.1, 0.2, 0.3, 0.4]}) 105 | expected = (0.1, 0.2, 0.3, 0.4) 106 | assert all(pytest.approx(a) == b for a, b in zip(result, expected)) 107 | 108 | result = extract_bbox({"xywh": [0.1, 0.2, 0.3, 0.4]}) 109 | expected = (0.1, 0.2, 0.4, 0.6) 110 | assert all(pytest.approx(a) == b for a, b in zip(result, expected)) 111 | 112 | result = extract_bbox({"bbox": {"xywh": [0.1, 0.2, 0.3, 0.4]}}) 113 | expected = (0.1, 0.2, 0.4, 0.6) 114 | assert all(pytest.approx(a) == b for a, b in zip(result, expected)) 115 | 116 | assert extract_bbox({"invalid": "format"}) is None 117 | assert extract_bbox([1, 2, 3]) is None 118 | 119 | 120 | def test_get_boxes_from_response(sample_response): 121 | """Test extraction of bounding boxes from response.""" 122 | boxes = get_boxes_from_response(sample_response) 123 | assert len(boxes) == 4 # street, city, dob, full_name metadata 124 | 125 | # Check street metadata extraction 126 | street_box = next(box for box in boxes if box.get("field") == "address.street") 127 | expected = (0.349, 0.588, 0.755, 0.634) # xywh converted to xyxy 128 | assert all(pytest.approx(a) == b for a, b in zip(street_box["bbox"], expected)) 129 | assert street_box["content"] == "10 WONDERFUL DRIVE" 130 | assert pytest.approx(street_box["confidence"]) == 0.9 131 | 132 | # Check city metadata extraction 133 | city_box = next(box for box in boxes if box.get("field") == "address.city") 134 | expected = (0.347, 0.640, 0.534, 0.681) # xywh converted to xyxy 135 | assert all(pytest.approx(a) == b for a, b in zip(city_box["bbox"], expected)) 136 | assert city_box["content"] == "MONTGOMERY" 137 | assert pytest.approx(city_box["confidence"]) == 1.0 138 | 139 | # Check date of birth metadata extraction 140 | dob_box = next(box for box in boxes if box.get("field") == 
"date_of_birth") 141 | expected = (0.349, 0.431, 0.484, 0.480) # xywh converted to xyxy 142 | assert all(pytest.approx(a) == b for a, b in zip(dob_box["bbox"], expected)) 143 | assert dob_box["content"] == "01-05-1948" 144 | assert pytest.approx(dob_box["confidence"]) == 1.0 145 | 146 | # Check full name metadata extraction 147 | name_box = next(box for box in boxes if box.get("field") == "full_name") 148 | expected = (0.398, 0.783, 0.853, 0.955) # xywh converted to xyxy 149 | assert all(pytest.approx(a) == b for a, b in zip(name_box["bbox"], expected)) 150 | assert name_box["content"] == "Connor Sample" 151 | assert pytest.approx(name_box["confidence"]) == 1.0 152 | 153 | 154 | def test_ensure_image(sample_image, tmp_path): 155 | """Test image loading and validation.""" 156 | assert ensure_image(sample_image) == sample_image 157 | 158 | img_path = tmp_path / "test.png" 159 | sample_image.save(img_path) 160 | loaded_img = ensure_image(img_path) 161 | assert isinstance(loaded_img, Image.Image) 162 | assert loaded_img.size == (100, 100) 163 | 164 | with pytest.raises(ValueError): 165 | ensure_image(tmp_path / "nonexistent.png") 166 | 167 | with pytest.raises(ValueError): 168 | ensure_image(123) 169 | 170 | 171 | def test_render_bbox_image(sample_image, sample_response): 172 | """Test bounding box rendering on image.""" 173 | result = render_bbox_image(sample_image, sample_response) 174 | assert isinstance(result, Image.Image) 175 | 176 | result = render_bbox_image(sample_image, sample_response, return_base64=True) 177 | assert isinstance(result, str) 178 | assert result.startswith('