├── .coverage ├── .github └── workflows │ ├── build-tag-and-publish.yml │ ├── e2e-cloud-gpu.yml │ ├── e2e-local.yml │ └── test-and-lint.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── examples ├── 1_quickstart.ipynb ├── 2_custom_evaluation_criteria.ipynb ├── 3_evaluation_strategies.ipynb ├── 4_llama_index_evaluators.ipynb ├── 5_evaluate_haystack_rag_pipeline.ipynb ├── 6_langchain_evaluators.ipynb ├── 7_baseten_async_quickstart.ipynb └── sample_data │ └── csr_assistant.json ├── flow_judge ├── __init__.py ├── eval_data_types.py ├── flow_judge.py ├── integrations │ ├── __init__.py │ ├── haystack.py │ ├── langchain.py │ └── llama_index.py ├── metrics │ ├── __init__.py │ ├── metric.py │ └── presets.py ├── models │ ├── __init__.py │ ├── adapters │ │ ├── base.py │ │ └── baseten │ │ │ ├── README.md │ │ │ ├── adapter.py │ │ │ ├── api_auth.py │ │ │ ├── data_io.py │ │ │ ├── deploy.py │ │ │ ├── deployment │ │ │ ├── config.yaml │ │ │ └── model │ │ │ │ ├── __init__.py │ │ │ │ ├── helper.py │ │ │ │ └── model.py │ │ │ ├── errors.py │ │ │ ├── gpu.py │ │ │ ├── management.py │ │ │ ├── token_bucket.py │ │ │ ├── util.py │ │ │ ├── validation.py │ │ │ └── webhook.py │ ├── baseten.py │ ├── common.py │ ├── huggingface.py │ ├── llamafile.py │ └── vllm.py └── utils │ ├── __init__.py │ ├── prompt_formatter.py │ ├── result_writer.py │ └── validators.py ├── img └── flow_judge_banner.png ├── pyproject.toml └── tests ├── README.md ├── e2e-cloud-gpu └── models │ └── adapters │ └── test_baseten_e2e.py ├── e2e-local ├── integrations │ └── test_llama_index_e2e.py └── models │ └── test_llamafile_e2e.py └── unit ├── models ├── adapters │ ├── baseten.py │ ├── gpu.py │ └── validation.py ├── test_baseten.py └── test_llamafile_unit.py ├── test_flow_judge.py ├── test_metrics.py ├── test_utils.py └── utils └── test_result_writer.py /.coverage: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flowaicom/flow-judge/e56d199db4d79f184dac1e9ab2da83992acda14d/.coverage -------------------------------------------------------------------------------- /.github/workflows/build-tag-and-publish.yml: -------------------------------------------------------------------------------- 1 | name: Build, tag and publish to PyPI and TestPyPI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - 'v[0-9]+.[0-9]+.[0-9]+' 9 | paths-ignore: 10 | - '**.md' 11 | - 'docs/**' 12 | pull_request: 13 | branches: [main] 14 | paths-ignore: 15 | - '**.md' 16 | - 'docs/**' 17 | workflow_dispatch: 18 | 19 | env: 20 | PYTHON_VERSION: '3.11' 21 | 22 | jobs: 23 | check: 24 | name: Check workflow (and have one non-dependent job) 25 | runs-on: self-hosted 26 | steps: 27 | - name: Checkout code 28 | uses: actions/checkout@v4 29 | - name: Check for necessary files 30 | run: | 31 | if [ ! 
-f "pyproject.toml" ]; then 32 | echo "pyproject.toml is missing" 33 | exit 1 34 | fi 35 | echo "Build file is present" 36 | 37 | build: 38 | name: Build distribution 📦 39 | runs-on: self-hosted 40 | needs: [check] 41 | if: ${{ always() && needs.check.result == 'success' }} 42 | 43 | steps: 44 | - uses: actions/checkout@v4 45 | - name: Set up Python ${{ env.PYTHON_VERSION }} 46 | uses: actions/setup-python@v5 47 | with: 48 | python-version: ${{ env.PYTHON_VERSION }} 49 | - name: Install pypa/build 50 | run: python3 -m pip install build --user 51 | - name: Build a binary wheel and a source tarball 52 | run: python3 -m build 53 | - name: Store the distribution packages 54 | uses: actions/upload-artifact@v4 55 | with: 56 | name: python-package-distributions 57 | path: dist/ 58 | 59 | publish-to-testpypi: 60 | name: Publish Python 🐍 distribution 📦 to TestPyPI 61 | needs: [build] 62 | runs-on: self-hosted 63 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 64 | environment: 65 | name: testpypi 66 | url: https://test.pypi.org/p/flow-judge 67 | 68 | permissions: 69 | id-token: write 70 | 71 | steps: 72 | - name: Download all the dists 73 | uses: actions/download-artifact@v4 74 | with: 75 | name: python-package-distributions 76 | path: dist/ 77 | - name: Publish distribution 📦 to TestPyPI 78 | uses: pypa/gh-action-pypi-publish@release/v1 79 | with: 80 | repository-url: https://test.pypi.org/legacy/ 81 | user: __token__ 82 | password: ${{ secrets.TESTPYPI_API_TOKEN }} 83 | 84 | publish-to-pypi: 85 | name: Publish Python 🐍 distribution 📦 to PyPI 86 | needs: [build, publish-to-testpypi] 87 | runs-on: self-hosted 88 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') 89 | environment: 90 | name: pypi 91 | url: https://pypi.org/p/flow-judge 92 | permissions: 93 | id-token: write 94 | 95 | steps: 96 | - name: Download all the dists 97 | uses: actions/download-artifact@v4 98 | with: 99 | name: python-package-distributions 100 | path: dist/ 101 | - name: Verify tag format 102 | run: | 103 | if [[ ! "${{ github.ref }}" =~ ^refs/tags/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then 104 | echo "Invalid tag format. 
Expected format: v*.*.*" 105 | exit 1 106 | fi 107 | - name: Publish distribution 📦 to PyPI 108 | uses: pypa/gh-action-pypi-publish@release/v1 109 | with: 110 | user: __token__ 111 | password: ${{ secrets.PYPI_API_TOKEN }} 112 | 113 | github-release: 114 | name: Create GitHub Release 115 | needs: [publish-to-pypi] 116 | runs-on: self-hosted 117 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') 118 | 119 | permissions: 120 | contents: write 121 | id-token: write 122 | 123 | steps: 124 | - name: Download all the dists 125 | uses: actions/download-artifact@v4 126 | with: 127 | name: python-package-distributions 128 | path: dist/ 129 | - name: Sign the dists with Sigstore 130 | uses: sigstore/gh-action-sigstore-python@v2.1.1 131 | with: 132 | inputs: ./dist/*.tar.gz ./dist/*.whl 133 | - name: Create GitHub Release 134 | env: 135 | GITHUB_TOKEN: ${{ github.token }} 136 | run: | 137 | gh release create "${{ github.ref_name }}" \ 138 | --repo "${{ github.repository }}" \ 139 | --title "Release ${{ github.ref_name }}" \ 140 | --notes "Release notes for version ${{ github.ref_name }}" 141 | - name: Upload artifact signatures to GitHub Release 142 | env: 143 | GITHUB_TOKEN: ${{ github.token }} 144 | run: gh release upload "${{ github.ref_name }}" dist/** --repo "${{ github.repository }}" 145 | -------------------------------------------------------------------------------- /.github/workflows/e2e-cloud-gpu.yml: -------------------------------------------------------------------------------- 1 | name: E2E Test (cloud engines, GPU enabled) 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * 0' # Runs at 00:00 UTC every Sunday 6 | pull_request: 7 | types: [ready_for_review] 8 | branches: [ main ] 9 | workflow_dispatch: 10 | 11 | jobs: 12 | 13 | lint: 14 | runs-on: self-hosted 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python 3.11 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: '3.11' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install ruff==0.1.3 black==23.9.1 isort==5.12.0 pyupgrade==3.10.1 25 | - name: Lint with ruff 26 | run: | 27 | ruff check . --config pyproject.toml 28 | - name: Format with black 29 | run: | 30 | black --check . --config pyproject.toml 31 | - name: Sort imports with isort 32 | run: | 33 | isort --check-only . --settings-path pyproject.toml 34 | - name: Check for Python upgrades 35 | run: | 36 | files=$(git ls-files '*.py') 37 | if pyupgrade --py310-plus $files | grep -q '^---'; then 38 | echo "pyupgrade would make changes. Please run pyupgrade locally and commit the changes." 
39 | exit 1 40 | fi 41 | 42 | test: 43 | needs: lint 44 | runs-on: self-hosted 45 | strategy: 46 | matrix: 47 | python-version: ['3.11'] 48 | 49 | steps: 50 | - uses: actions/checkout@v4 51 | - name: Set up Python ${{ matrix.python-version }} 52 | uses: actions/setup-python@v5 53 | with: 54 | python-version: ${{ matrix.python-version }} 55 | - name: Install dependencies 56 | run: | 57 | python -m pip install --upgrade pip 58 | pip install .[dev,vllm,hf,llamafile,integrations-test,baseten] 59 | - name: Verify GPU availability 60 | run: | 61 | nvidia-smi 62 | python -c "import torch; print(torch.cuda.is_available())" 63 | - name: Test with pytest and generate coverage 64 | run: | 65 | export HF_HOME=/tmp/hf_home 66 | export TRANSFORMERS_CACHE=/tmp/hf_home 67 | export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} 68 | export BASETEN_API_KEY=${{ secrets.BASETEN_API_KEY }} 69 | export BASETEN_MODEL_ID=${{ secrets.BASETEN_MODEL_ID }} 70 | export BASETEN_WEBHOOK_SECRET=${{ secrets.BASETEN_WEBHOOK_SECRET }} 71 | export BASETEN_WEBHOOK_URL=${{ secrets.BASETEN_WEBHOOK_URL }} 72 | pytest ./tests/e2e-cloud-gpu --cov=./ --junitxml=junit.xml 73 | - name: Upload coverage to Codecov 74 | uses: codecov/codecov-action@v4 75 | with: 76 | token: ${{ secrets.CODECOV_TOKEN }} 77 | fail_ci_if_error: true 78 | - name: Upload test results to Codecov 79 | if: ${{ !cancelled() }} 80 | uses: codecov/test-results-action@v1 81 | with: 82 | token: ${{ secrets.CODECOV_TOKEN }} 83 | -------------------------------------------------------------------------------- /.github/workflows/e2e-local.yml: -------------------------------------------------------------------------------- 1 | name: E2E Test (local engines, GPU enabled) 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | 11 | lint: 12 | runs-on: self-hosted 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python 3.11 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.11' 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install ruff==0.1.3 black==23.9.1 isort==5.12.0 pyupgrade==3.10.1 23 | - name: Lint with ruff 24 | run: | 25 | ruff check . --config pyproject.toml 26 | - name: Format with black 27 | run: | 28 | black --check . --config pyproject.toml 29 | - name: Sort imports with isort 30 | run: | 31 | isort --check-only . --settings-path pyproject.toml 32 | - name: Check for Python upgrades 33 | run: | 34 | files=$(git ls-files '*.py') 35 | if pyupgrade --py310-plus $files | grep -q '^---'; then 36 | echo "pyupgrade would make changes. Please run pyupgrade locally and commit the changes." 
37 | exit 1 38 | fi 39 | 40 | test: 41 | needs: lint 42 | runs-on: self-hosted 43 | strategy: 44 | matrix: 45 | python-version: ['3.10', '3.11', '3.12'] 46 | 47 | steps: 48 | - uses: actions/checkout@v4 49 | - name: Set up Python ${{ matrix.python-version }} 50 | uses: actions/setup-python@v5 51 | with: 52 | python-version: ${{ matrix.python-version }} 53 | - name: Install dependencies 54 | run: | 55 | python -m pip install --upgrade pip 56 | pip install .[dev,vllm,hf,llamafile,integrations-test,baseten] 57 | - name: Verify GPU availability 58 | run: | 59 | nvidia-smi 60 | python -c "import torch; print(torch.cuda.is_available())" 61 | - name: Test with pytest and generate coverage 62 | run: | 63 | export HF_HOME=/tmp/hf_home 64 | export TRANSFORMERS_CACHE=/tmp/hf_home 65 | export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} 66 | pytest ./tests/e2e-local --cov=./ --junitxml=junit.xml 67 | - name: Upload coverage to Codecov 68 | uses: codecov/codecov-action@v4 69 | with: 70 | token: ${{ secrets.CODECOV_TOKEN }} 71 | fail_ci_if_error: true 72 | - name: Upload test results to Codecov 73 | if: ${{ !cancelled() }} 74 | uses: codecov/test-results-action@v1 75 | with: 76 | token: ${{ secrets.CODECOV_TOKEN }} 77 | -------------------------------------------------------------------------------- /.github/workflows/test-and-lint.yml: -------------------------------------------------------------------------------- 1 | name: Test and Lint 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | 11 | lint: 12 | runs-on: self-hosted 13 | steps: 14 | - uses: actions/checkout@v4 15 | with: 16 | fetch-depth: 0 # Fetch all history for trufflehog 17 | - name: Set up Python 3.11 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: '3.11' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install ruff==0.1.3 black==23.9.1 isort==5.12.0 pyupgrade==3.10.1 25 | - name: Lint with ruff 26 | run: | 27 | ruff check . --config pyproject.toml 28 | - name: Format with black 29 | run: | 30 | black --check . --config pyproject.toml 31 | - name: Sort imports with isort 32 | run: | 33 | isort --check-only . --settings-path pyproject.toml 34 | - name: Check for Python upgrades 35 | run: | 36 | files=$(git ls-files '*.py') 37 | if pyupgrade --py310-plus $files | grep -q '^---'; then 38 | echo "pyupgrade would make changes. Please run pyupgrade locally and commit the changes." 
39 | exit 1 40 | fi 41 | - name: Run TruffleHog 42 | uses: trufflesecurity/trufflehog@v3.82.6 43 | with: 44 | path: ./ 45 | base: ${{ github.event.pull_request.base.sha }} 46 | head: ${{ github.event.pull_request.head.sha }} 47 | extra_args: --only-verified --exclude-globs=".venv/*" 48 | 49 | test: 50 | needs: lint 51 | runs-on: self-hosted 52 | strategy: 53 | matrix: 54 | python-version: ['3.10', '3.11', '3.12'] 55 | 56 | steps: 57 | - uses: actions/checkout@v4 58 | - name: Set up Python ${{ matrix.python-version }} 59 | uses: actions/setup-python@v5 60 | with: 61 | python-version: ${{ matrix.python-version }} 62 | - name: Install dependencies 63 | run: | 64 | python -m pip install --upgrade pip 65 | pip install .[dev,vllm,hf,llamafile,integrations-test,baseten] 66 | - name: Verify GPU availability 67 | run: | 68 | nvidia-smi 69 | python -c "import torch; print(torch.cuda.is_available())" 70 | - name: Test with pytest and generate coverage 71 | run: | 72 | export HF_HOME=/tmp/hf_home 73 | export TRANSFORMERS_CACHE=/tmp/hf_home 74 | export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} 75 | pytest ./tests/unit --cov=./ --junitxml=junit.xml 76 | - name: Upload coverage to Codecov 77 | uses: codecov/codecov-action@v4 78 | with: 79 | token: ${{ secrets.CODECOV_TOKEN }} 80 | fail_ci_if_error: true 81 | - name: Upload test results to Codecov 82 | if: ${{ !cancelled() }} 83 | uses: codecov/test-results-action@v1 84 | with: 85 | token: ${{ secrets.CODECOV_TOKEN }} 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python cache files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # Virtual environments 6 | venv/ 7 | env/ 8 | .venv/ 9 | 10 | # uv 11 | uv.lock 12 | 13 | # IDE settings 14 | .vscode/ 15 | .idea/ 16 | 17 | # Distribution / packaging 18 | dist/ 19 | build/ 20 | *.egg-info/ 21 | 22 | # Logs 23 | *.log 24 | 25 | # Local configuration 26 | .env 27 | 28 | # Jupyter Notebook 29 | .ipynb_checkpoints 30 | 31 | # ruff cache 32 | .ruff_cache/ 33 | 34 | # mypy cache 35 | .mypy_cache/ 36 | 37 | # pytype cache 38 | .pytype/ 39 | 40 | # pyc 41 | *.pyc 42 | 43 | # output 44 | output/ 45 | 46 | # data 47 | data/ 48 | 49 | .cache 50 | 51 | flake.nix 52 | flake.lock 53 | .direnv 54 | 55 | .hypothesis 56 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - id: check-ast 9 | - id: check-json 10 | - id: check-merge-conflict 11 | - id: detect-private-key 12 | 13 | - repo: https://github.com/astral-sh/ruff-pre-commit 14 | rev: v0.1.3 15 | hooks: 16 | - id: ruff 17 | args: [--fix, --exit-non-zero-on-fix, --config, pyproject.toml] 18 | 19 | - repo: https://github.com/psf/black 20 | rev: 23.9.1 21 | hooks: 22 | - id: black 23 | args: [--check, --config, pyproject.toml] 24 | language_version: python3.11 25 | 26 | - repo: https://github.com/PyCQA/isort 27 | rev: 5.12.0 28 | hooks: 29 | - id: isort 30 | args: [--check-only, --settings-path, pyproject.toml] 31 | 32 | - repo: https://github.com/asottile/pyupgrade 33 | rev: v3.10.1 34 | hooks: 35 | - id: pyupgrade 36 | args: [--py310-plus] 37 | 38 | - repo: https://github.com/trufflesecurity/trufflehog 39 | rev: v3.82.6 40 | hooks: 
41 | - id: trufflehog 42 | name: TruffleHog 43 | description: Detect secrets in your data with TruffleHog. 44 | entry: trufflehog git file://. --only-verified --no-update --fail --exclude-globs=".venv/*" 45 | language: system 46 | -------------------------------------------------------------------------------- /examples/7_baseten_async_quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Quickstart\n", 8 | "\n", 9 | "This tutorial demonstrates how to use the `Baseten` model class in async mode to perform language model-based evaluations using Flow-Judge-v0.1 deployed model on Baseten. For detailed instructions on how to use Baseten, visit the [Baseten readme](https://github.com/flowaicom/flow-judge/blob/main/flow_judge/models/adapters/baseten/README.md)." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Setup\n", 17 | "\n", 18 | "Let's instantiate the `Baseten` model class in async mode. The async implementation makes use of Baseten's async inference approach. See [here](https://docs.baseten.co/invoke/async).\n", 19 | "\n", 20 | "You can imagine this as *fire-and-forget* functionality. Completion requests are made to the deployed model, once data is processed and inference is complete, the output is sent to a predefined webhook. The webhook url is part of the original request. The `Flow-Judge` library then connects with the webhook and *listens* for a response. The library makes use of this approach to allow configurability for concurrent execution.\n", 21 | "\n", 22 | "Optionally Flow AI has deployed a webhook proxy that accepts this request signature and feeds-it-forward to the client. This can be found under the URL: \"https://proxy.flow-ai.dev\"\n", 23 | "\n", 24 | "### Pre-requisite\n", 25 | "\n", 26 | "1. Sign-up to [Baseten](https://www.baseten.co/)\n", 27 | "2. Generate a Baseten API Key from [here](https://app.baseten.co/settings/api_keys)\n", 28 | "3. Generate a Webhook secret from [here](https://app.baseten.co/settings/secrets)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "### Additional Requirements\n", 36 | "\n", 37 | "Set your `Baseten API key`, `Webhook secret` and `GPU` option in the environment." 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 1, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "import os\n", 47 | "\n", 48 | "os.environ[\"BASETEN_WEBHOOK_SECRET\"] = \"your_baseten_webhook_secret\"\n", 49 | "os.environ[\"BASETEN_API_KEY\"] = \"your_baseten_api_key\"\n", 50 | "\n", 51 | "# You can optionally switch the GPU to H100.\n", 52 | "# This will deploy the FlowJudge model on H100 40GB\n", 53 | "# A10G deployment is Flow-Judge-v0.1-AWQ\n", 54 | "# H100 deployment is Flow-Judge-v0.1-FP8\n", 55 | "# !! Manually changing the hardware on Baseten's UI may cause compatibility issues !!\n", 56 | "os.environ[\"BASETEN_GPU\"] = \"A10G\"" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "### Instantiate the Baseten model\n", 64 | "\n", 65 | "Set the following required options for async execution mode of the Baseten model class: \n", 66 | "1. `exec_async=True`\n", 67 | "2. 
`webhook_proxy_url=https://proxy.flow-ai.dev` (or [run the proxy locally](https://github.com/flowaicom/flow-judge/blob/main/flow_judge/models/adapters/baseten/README.md))\n", 68 | "\n", 69 | "Optionally you can set the `async_batch_size` option to a value > 0 (defaults to `128`). This is the number of concurrent requests sent to the deployed model. It is associated with the concurrency goals you want to achieve and can be actively configured in Baseten's UI. For more information, see [here](https://docs.baseten.co/performance/concurrency). Our current deployment configuration allows a concurrency target of `128` and max replica of `1` for the deployed model as the default on Baseten. This means if you have max replica set to 1 on Baseten, it can accept concurrent requests of `128`. The batch size you set for the Baseten model class should be equivalent to the number of `concurrency_target * number_of_replicas`" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "from flow_judge import Baseten, AsyncFlowJudge\n", 79 | "from flow_judge.metrics import RESPONSE_FAITHFULNESS_5POINT\n", 80 | "\n", 81 | "# Async model execution\n", 82 | "model = Baseten(\n", 83 | " webhook_proxy_url=\"https://proxy.flow-ai.dev\",\n", 84 | " exec_async=True,\n", 85 | ")\n", 86 | "\n", 87 | "# Instantiate the Async Judge with the model and a metric\n", 88 | "# The library includes multiple default metrics and you can implement your own.\n", 89 | "faithfulness_judge = AsyncFlowJudge(\n", 90 | " metric=RESPONSE_FAITHFULNESS_5POINT,\n", 91 | " model=model,\n", 92 | ")" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "## Running Evaluations\n", 100 | "\n", 101 | "Let's test batched evaluations with our example csr data on the faithfulness 5 point likert.\n", 102 | "\n", 103 | "We use the `async_batch_evaluate` method from the AsyncFlowJudge class. Underneath this uses batched processing utilizing the batch_size set with the `async_batch_size` argument of the Baseten model class. If there are failures, for example with networking, the batch will process and errors will be propagated as log outputs. The output would include the successful responses." 
104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# Read the sample data\n", 113 | "import json\n", 114 | "from flow_judge import EvalInput\n", 115 | "with open(\"sample_data/csr_assistant.json\", \"r\") as f:\n", 116 | " data = json.load(f)\n", 117 | "\n", 118 | "# Create a list of inputs and outputs\n", 119 | "inputs_batch = [\n", 120 | " [\n", 121 | " {\"query\": sample[\"query\"]},\n", 122 | " {\"context\": sample[\"context\"]},\n", 123 | " ]\n", 124 | " for sample in data\n", 125 | "]\n", 126 | "outputs_batch = [{\"response\": sample[\"response\"]} for sample in data]\n", 127 | "\n", 128 | "# Create a list of EvalInput\n", 129 | "eval_inputs_batch = [EvalInput(inputs=inputs, output=output) for inputs, output in zip(inputs_batch, outputs_batch)]\n", 130 | "\n", 131 | "# Run the batch evaluation\n", 132 | "results = await faithfulness_judge.async_batch_evaluate(eval_inputs_batch, save_results=False)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "from IPython.display import Markdown, display\n", 142 | "\n", 143 | "# Visualizing the results\n", 144 | "for i, result in enumerate(results):\n", 145 | " display(Markdown(f\"__Sample {i+1}:__\"))\n", 146 | " display(Markdown(f\"__Feedback:__\\n{result.feedback}\\n\\n__Score:__\\n{result.score}\"))\n", 147 | " display(Markdown(\"---\"))" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "Similarly you can run a single evaluation task using the `async_evaluate` method on the `AsyncFlowJudge` class. Under the hood, this will process a single async request and attach listeners to the webhook for the response." 
155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "result = await faithfulness_judge.async_evaluate(eval_inputs_batch[0], save_results=False)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "# Display the result\n", 173 | "display(Markdown(f\"__Feedback:__\\n{result.feedback}\\n\\n__Score:__\\n{result.score}\"))" 174 | ] 175 | } 176 | ], 177 | "metadata": { 178 | "kernelspec": { 179 | "display_name": "flow-judge-test", 180 | "language": "python", 181 | "name": "python3" 182 | }, 183 | "language_info": { 184 | "codemirror_mode": { 185 | "name": "ipython", 186 | "version": 3 187 | }, 188 | "file_extension": ".py", 189 | "mimetype": "text/x-python", 190 | "name": "python", 191 | "nbconvert_exporter": "python", 192 | "pygments_lexer": "ipython3", 193 | "version": "3.11.9" 194 | } 195 | }, 196 | "nbformat": 4, 197 | "nbformat_minor": 2 198 | } 199 | -------------------------------------------------------------------------------- /flow_judge/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import PackageNotFoundError, version 2 | 3 | from flow_judge.eval_data_types import EvalInput, EvalOutput 4 | from flow_judge.flow_judge import AsyncFlowJudge, FlowJudge 5 | from flow_judge.metrics import CustomMetric, Metric, RubricItem, list_all_metrics 6 | from flow_judge.models.common import BaseFlowJudgeModel 7 | from flow_judge.utils.prompt_formatter import format_rubric, format_user_prompt, format_vars 8 | 9 | try: 10 | __version__ = version("flow-judge") 11 | except PackageNotFoundError: 12 | # package is not installed 13 | __version__ = "unknown" 14 | 15 | __all__ = [ 16 | "FlowJudge", 17 | "AsyncFlowJudge", 18 | "EvalInput", 19 | "format_vars", 20 | "format_rubric", 21 | "format_user_prompt", 22 | "RubricItem", 23 | "Metric", 24 | "CustomMetric", 25 | "BaseFlowJudgeModel", 26 | "EvalOutput", 27 | ] 28 | 29 | # Conditional imports for optional dependencies 30 | try: 31 | from flow_judge.models.huggingface import Hf 32 | 33 | __all__.append("Hf") 34 | except ImportError: 35 | Hf = None 36 | 37 | try: 38 | from flow_judge.models.vllm import Vllm 39 | 40 | __all__.append("Vllm") 41 | except ImportError: 42 | Vllm = None 43 | 44 | try: 45 | from flow_judge.models.llamafile import Llamafile 46 | 47 | __all__.append("Llamafile") 48 | except ImportError: 49 | Llamafile = None 50 | 51 | try: 52 | from flow_judge.models.baseten import Baseten 53 | 54 | __all__.append("Baseten") 55 | except ImportError: 56 | Baseten = None 57 | 58 | 59 | def get_available_models(): 60 | """Return a list of available model classes based on installed extras.""" 61 | models = [BaseFlowJudgeModel] 62 | if Hf is not None: 63 | models.append(Hf) 64 | if Vllm is not None: 65 | models.append(Vllm) 66 | if Llamafile is not None: 67 | models.append(Llamafile) 68 | if Baseten is not None: 69 | models.append(Baseten) 70 | return models 71 | 72 | 73 | __all__.append("get_available_models") 74 | 75 | # Add all metric names to __all__ 76 | __all__ += list_all_metrics() 77 | -------------------------------------------------------------------------------- /flow_judge/eval_data_types.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | logger = 
logging.getLogger(__name__) 7 | 8 | 9 | class EvalInput(BaseModel): 10 | """Input for evaluation.""" 11 | 12 | inputs: list[dict[str, str]] = Field(default_factory=list) 13 | output: dict[str, str] 14 | 15 | 16 | class EvalOutput(BaseModel): 17 | """Output model for evaluation results.""" 18 | 19 | feedback: str = Field(..., description="Feedback from the evaluation") 20 | score: int | None = Field(..., description="Numeric score from the evaluation") 21 | 22 | @classmethod 23 | def parse(cls, response: str, fail_on_parse_error: bool = False) -> "EvalOutput": 24 | """Parse the evaluation response from the judge.""" 25 | try: 26 | # Compile regex patterns for the <feedback> and <score> tags emitted by the judge 27 | feedback_pattern = re.compile(r"<feedback>\s*(.*?)\s*</feedback>", re.DOTALL) 28 | score_pattern = re.compile(r"<score>\s*(\d+)\s*</score>", re.DOTALL) 29 | 30 | feedback_match = feedback_pattern.search(response) 31 | score_match = score_pattern.search(response) 32 | 33 | if not feedback_match or not score_match: 34 | raise ValueError("Failed to parse evaluation response.") 35 | 36 | feedback = feedback_match.group(1).strip() 37 | score = int(score_match.group(1).strip()) 38 | 39 | return cls(feedback=feedback, score=score) 40 | except Exception as e: 41 | if fail_on_parse_error: 42 | raise ValueError(f"Failed to parse evaluation response: {e}") from e 43 | logger.warning(f"Parsing failed for response: {response}. Error: {e}") 44 | return EvalOutput(feedback="Error", score=-1) 45 | -------------------------------------------------------------------------------- /flow_judge/flow_judge.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from flow_judge.eval_data_types import EvalInput, EvalOutput 5 | from flow_judge.metrics import CustomMetric, Metric 6 | from flow_judge.models.adapters.baseten.data_io import BatchResult 7 | from flow_judge.models.adapters.baseten.errors import FlowJudgeError 8 | from flow_judge.models.common import AsyncBaseFlowJudgeModel, BaseFlowJudgeModel 9 | from flow_judge.utils.prompt_formatter import format_rubric, format_user_prompt, format_vars 10 | from flow_judge.utils.result_writer import write_results_to_disk 11 | from flow_judge.utils.validators import validate_eval_input 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class BaseFlowJudge: 18 | """Base class for FlowJudge with common functionality.""" 19 | 20 | def __init__( 21 | self, 22 | metric: Metric | CustomMetric, 23 | model: BaseFlowJudgeModel | AsyncBaseFlowJudgeModel, 24 | output_dir: str | None = "output/", 25 | ): 26 | """Initialize BaseFlowJudge with a metric and model.""" 27 | if not isinstance(metric, (Metric, CustomMetric)): 28 | raise ValueError("Invalid metric type. 
Use Metric or CustomMetric.") 29 | self.metric = metric 30 | self.output_dir = output_dir 31 | self.model = model 32 | 33 | def _format_prompt(self, eval_input: EvalInput) -> str: 34 | """Format the prompt for a single evaluation input.""" 35 | prompt_variables = { 36 | "INPUTS": format_vars(eval_input.inputs), 37 | "OUTPUT": format_vars([eval_input.output]), 38 | "EVALUATION_CRITERIA": self.metric.criteria, 39 | "RUBRIC": format_rubric(self.metric.rubric), 40 | } 41 | return format_user_prompt(prompt_variables) 42 | 43 | def _validate_inputs(self, eval_inputs: EvalInput | list[EvalInput]): 44 | """Validate required inputs and output against the metric.""" 45 | if isinstance(eval_inputs, list): 46 | for eval_input in eval_inputs: 47 | validate_eval_input(eval_input, self.metric) 48 | else: 49 | validate_eval_input(eval_inputs, self.metric) 50 | 51 | def _save_results( 52 | self, eval_inputs: list[EvalInput], eval_outputs: list[EvalOutput], append: bool = False 53 | ): 54 | """Save results to disk.""" 55 | logger.info(f"{'Appending' if append else 'Saving'} results to {self.output_dir}") 56 | write_results_to_disk( 57 | eval_inputs, 58 | eval_outputs, 59 | self.model.metadata, 60 | self.metric.name, 61 | self.output_dir, 62 | append=append, 63 | ) 64 | 65 | 66 | class FlowJudge(BaseFlowJudge): 67 | """Synchronous FlowJudge class for evaluating AI outputs.""" 68 | 69 | def __init__( 70 | self, 71 | metric: Metric | CustomMetric, 72 | model: BaseFlowJudgeModel, 73 | output_dir: str | None = "output/", 74 | ): 75 | """Initialize FlowJudge with a metric and model.""" 76 | super().__init__(metric, model, output_dir) 77 | if not isinstance(model, BaseFlowJudgeModel): 78 | raise ValueError("Invalid model type. Use BaseFlowJudgeModel or its subclasses.") 79 | 80 | def evaluate(self, eval_input: EvalInput, save_results: bool = False) -> EvalOutput: 81 | """Evaluate a single EvalInput object.""" 82 | try: 83 | self._validate_inputs(eval_input) 84 | prompt = self._format_prompt(eval_input) 85 | response = self.model._generate(prompt) 86 | eval_output = EvalOutput.parse(response) 87 | if save_results: 88 | self._save_results([eval_input], [eval_output]) 89 | return eval_output 90 | except Exception as e: 91 | logger.error(f"Evaluation failed: {e}") 92 | raise 93 | 94 | def batch_evaluate( 95 | self, 96 | eval_inputs: list[EvalInput], 97 | use_tqdm: bool = True, 98 | save_results: bool = True, 99 | fail_on_parse_error: bool = False, 100 | ) -> list[EvalOutput]: 101 | """Batch evaluate a list of EvalInput objects.""" 102 | self._validate_inputs(eval_inputs) 103 | prompts = [self._format_prompt(eval_input) for eval_input in eval_inputs] 104 | responses = self.model._batch_generate(prompts, use_tqdm=use_tqdm) 105 | eval_outputs = [ 106 | EvalOutput.parse(response, fail_on_parse_error=fail_on_parse_error) 107 | for response in responses 108 | ] 109 | parse_failures = sum(1 for output in eval_outputs if output.score == -1) 110 | if save_results: 111 | self._save_results(eval_inputs, eval_outputs) 112 | if parse_failures > 0: 113 | logger.warning(f"Number of parsing failures: {parse_failures} out of {len(responses)}") 114 | 115 | return eval_outputs 116 | 117 | 118 | class AsyncFlowJudge(BaseFlowJudge): 119 | """Asynchronous FlowJudge class for evaluating AI outputs.""" 120 | 121 | def __init__( 122 | self, 123 | metric: Metric | CustomMetric, 124 | model: AsyncBaseFlowJudgeModel, 125 | output_dir: str | None = "output/", 126 | ): 127 | """Initialize AsyncFlowJudge with a metric and model.""" 128 | 
super().__init__(metric, model, output_dir) 129 | if not isinstance(model, AsyncBaseFlowJudgeModel): 130 | raise ValueError("Invalid model type. Use AsyncBaseFlowJudgeModel or its subclasses.") 131 | 132 | def _handle_batch_result( 133 | self, batch_result: BatchResult, batch_len: int, fail_on_parse_error: bool 134 | ) -> list[EvalOutput]: 135 | """Handle output parsing for batched results. 136 | 137 | Args: 138 | batch_result: The result of the batch from Baseten. 139 | batch_len: The initial batch size derived from the length of Eval Inputs. 140 | fail_on_parse_error: Flag to raise a parse error for the EvalOutput. 141 | 142 | Returns: 143 | list[EvalOutput]: A list of eval outputs with score and feedback. 144 | 145 | Note: 146 | There might be instances when downstream errors result in missing entries 147 | for the eval outputs. We implement retry strategies where we can, but in 148 | certain instances (such as network failures) errors are inevitable. 149 | To ascertain predictability, we 'fill-in' the errors with empty EvalOutputs. 150 | 151 | """ 152 | eval_outputs = [EvalOutput(feedback="BasetenError", score=None)] * batch_len 153 | for output in batch_result.successful_outputs: 154 | index = output.get("index") 155 | eval_outputs[index - 1] = EvalOutput.parse( 156 | response=output["response"], fail_on_parse_error=fail_on_parse_error 157 | ) 158 | 159 | # Log all downstream errors 160 | if len(batch_result.errors) > 0: 161 | logger.warning( 162 | f"Number of Baseten API errors: {len(batch_result.errors)}" 163 | f" of {batch_result.total_requests}." 164 | f" Success rate is {batch_result.success_rate}" 165 | " List of errors: " 166 | ) 167 | for error in batch_result.errors: 168 | logger.warning(f"{error.error_type}: {error.error_message}") 169 | 170 | return eval_outputs 171 | 172 | async def async_evaluate( 173 | self, eval_input: EvalInput, save_results: bool = False, append: bool = False 174 | ) -> EvalOutput | None: 175 | """Evaluate a single EvalInput object asynchronously.""" 176 | try: 177 | self._validate_inputs(eval_input) 178 | prompt = self._format_prompt(eval_input) 179 | result = await self.model._async_generate(prompt) 180 | response = result 181 | 182 | if isinstance(result, FlowJudgeError): 183 | logger.error(f" {result.error_type}: {result.error_message}") 184 | return 185 | 186 | eval_output = EvalOutput.parse(response) 187 | if save_results: 188 | logger.info(f"Saving result {'(append)' if append else '(overwrite)'}") 189 | await asyncio.to_thread( 190 | self._save_results, [eval_input], [eval_output], append=append 191 | ) 192 | return eval_output 193 | except Exception as e: 194 | logger.error(f"Asynchronous evaluation failed: {e}") 195 | raise 196 | 197 | # TODO: figure if we want to have the parser be passed the fail_on_parse_error flag 198 | async def async_batch_evaluate( 199 | self, 200 | eval_inputs: list[EvalInput], 201 | use_tqdm: bool = True, 202 | save_results: bool = True, 203 | append: bool = False, # Change default to False 204 | fail_on_parse_error: bool = False, 205 | ) -> list[EvalOutput]: 206 | """Batch evaluate a list of EvalInput objects asynchronously.""" 207 | self._validate_inputs(eval_inputs) 208 | prompts = [self._format_prompt(eval_input) for eval_input in eval_inputs] 209 | batch_result = await self.model._async_batch_generate(prompts, use_tqdm=use_tqdm) 210 | 211 | if isinstance(batch_result, BatchResult): 212 | eval_outputs = self._handle_batch_result( 213 | batch_result=batch_result, 214 | batch_len=len(eval_inputs), 215 | 
fail_on_parse_error=fail_on_parse_error, 216 | ) 217 | else: 218 | eval_outputs = [ 219 | EvalOutput.parse(response, fail_on_parse_error=fail_on_parse_error) 220 | for response in batch_result 221 | ] 222 | logger.warning(f"{eval_outputs}") 223 | parse_failures = sum(1 for output in eval_outputs if output.score and output.score == -1) 224 | 225 | if save_results: 226 | logger.info(f"Saving {len(eval_outputs)} results") 227 | for i, (eval_input, eval_output) in enumerate( 228 | zip(eval_inputs, eval_outputs, strict=True) 229 | ): 230 | await asyncio.to_thread( 231 | self._save_results, 232 | [eval_input], 233 | [eval_output], 234 | append=(append or i > 0), # Append for all but the first, unless append is True 235 | ) 236 | 237 | if parse_failures > 0: 238 | logger.warning( 239 | f"Number of parsing failures: {parse_failures} out of {len(eval_outputs)}" 240 | ) 241 | 242 | return eval_outputs 243 | -------------------------------------------------------------------------------- /flow_judge/integrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flowaicom/flow-judge/e56d199db4d79f184dac1e9ab2da83992acda14d/flow_judge/integrations/__init__.py -------------------------------------------------------------------------------- /flow_judge/integrations/haystack.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | import numpy as np 5 | from haystack import component, default_from_dict, default_to_dict 6 | from haystack.utils import deserialize_type 7 | 8 | from flow_judge.flow_judge import EvalInput, EvalOutput, FlowJudge 9 | from flow_judge.metrics.metric import CustomMetric, Metric 10 | from flow_judge.models.common import BaseFlowJudgeModel 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | # Based on https://github.com/deepset-ai/haystack/blob/d234c75168dcb49866a6714aa232f37d56f72cab/haystack/components/evaluators/llm_evaluator.py#L354 15 | 16 | 17 | @component 18 | class HaystackFlowJudge: 19 | """A component that uses FlowJudge to evaluate inputs.""" 20 | 21 | def __init__( 22 | self, 23 | metric: Metric | CustomMetric, 24 | model: BaseFlowJudgeModel, 25 | output_dir: str = "output/", 26 | progress_bar: bool = True, 27 | raise_on_failure: bool = True, 28 | save_results: bool = True, 29 | fail_on_parse_error: bool = False, 30 | ): 31 | """Construct a new FlowJudge evaluator.""" 32 | if isinstance(metric, (Metric, CustomMetric)): 33 | self.metric = metric 34 | else: 35 | raise ValueError("Invalid metric type. Use Metric or CustomMetric.") 36 | 37 | if not isinstance(model, BaseFlowJudgeModel): 38 | raise ValueError("Invalid model type. 
Use BaseFlowJudgeModel or its subclasses.") 39 | 40 | self.model = model 41 | self.output_dir = output_dir 42 | 43 | self.judge = FlowJudge(metric=self.metric, model=self.model, output_dir=self.output_dir) 44 | 45 | # extract inputs and output from the metric 46 | self.inputs, self.outputs = self._extract_vars_from_metric(self.metric) 47 | self.validate_init_parameters(self.inputs, self.outputs) 48 | self.raise_on_failure = raise_on_failure 49 | self.progress_bar = progress_bar 50 | self.save_results = save_results 51 | self.fail_on_parse_error = fail_on_parse_error 52 | 53 | component.set_input_types(self, **dict(self.inputs)) 54 | 55 | @staticmethod 56 | def _extract_vars_from_metric( 57 | metric: Metric | CustomMetric, 58 | ) -> tuple[list[tuple[str, type[list]]], list[str]]: 59 | """Extract the inputs to the component and its type from the metric. 60 | 61 | It also sets the output of the component. 62 | """ 63 | eval_inputs_keys: list[str] = metric.required_inputs 64 | eval_output_key: str = metric.required_output 65 | 66 | inputs = [(key, list[str]) for key in eval_inputs_keys + [eval_output_key]] 67 | 68 | outputs = ["feedback", "score"] 69 | 70 | return inputs, outputs 71 | 72 | @staticmethod 73 | def validate_init_parameters(inputs: list[tuple[str, type[list]]], outputs: list[str]): 74 | """Validate the init parameters.""" 75 | # Validate inputs 76 | if ( 77 | not isinstance(inputs, list) 78 | or not all(isinstance(_input, tuple) for _input in inputs) 79 | or not all( 80 | isinstance(_input[0], str) and _input[1] is not list and len(_input) == 2 81 | for _input in inputs 82 | ) 83 | ): 84 | msg = ( 85 | f"FlowJudge evaluator expects inputs to \ 86 | be a list of tuples. Each tuple must contain an input name and " 87 | f"type of list but received {inputs}." 88 | ) 89 | raise ValueError(msg) 90 | 91 | # Validate outputs 92 | if not isinstance(outputs, list) or not all(isinstance(output, str) for output in outputs): 93 | msg = f"FlowJudge evaluator expects outputs \ 94 | to be a list of str but received {outputs}." 95 | raise ValueError(msg) 96 | 97 | @component.output_types( 98 | results=list[dict[str, Any]], 99 | metadata=dict[str, Any], 100 | score=float, 101 | individual_scores=list[float], 102 | ) 103 | def run(self, **inputs) -> dict[str, Any]: 104 | """Run the FlowJudge evaluator on the provided inputs.""" 105 | self._validate_input_parameters(dict(self.inputs), inputs) 106 | eval_inputs: list[EvalInput] = self._prepare_inputs(inputs=inputs, metric=self.metric) 107 | eval_outputs: list[EvalOutput] = self.judge.batch_evaluate( 108 | eval_inputs, 109 | save_results=self.save_results, 110 | fail_on_parse_error=self.fail_on_parse_error, 111 | ) 112 | 113 | results: list[dict[str, Any] | None] = [] 114 | parsing_errors = 0 115 | for eval_output in eval_outputs: 116 | if eval_output.score != -1: 117 | result = { 118 | "feedback": eval_output.feedback, 119 | "score": eval_output.score, 120 | } 121 | 122 | results.append(result) 123 | else: 124 | results.append({"feedback": eval_output.feedback, "score": eval_output.score}) 125 | parsing_errors += 1 126 | 127 | if parsing_errors > 0: 128 | msg = ( 129 | f"FlowJudge failed to parse {parsing_errors} results out " 130 | f"of {len(eval_outputs)}. Score and Individual Scores are " 131 | "based on the successfully parsed results." 
132 | ) 133 | logger.warning(msg) 134 | 135 | metadata = self.model.metadata 136 | 137 | score = np.mean([result["score"] for result in results if result["score"] != -1]) 138 | individual_scores = [float(result["score"]) for result in results if result["score"] != -1] 139 | 140 | return { 141 | "results": results, 142 | "metadata": metadata, 143 | "score": score, 144 | "individual_scores": individual_scores, 145 | } 146 | 147 | @staticmethod 148 | def _validate_input_parameters(expected: dict[str, Any], received: dict[str, Any]) -> None: 149 | """Validate the input parameters.""" 150 | # Validate that all expected inputs are present in the received inputs 151 | for param in expected.keys(): 152 | if param not in received: 153 | msg = f"FlowJudge evaluator expected input \ 154 | parameter '{param}' but received only {received.keys()}." 155 | raise ValueError(msg) 156 | 157 | # Validate that all received inputs are lists 158 | if not all(isinstance(_input, list) for _input in received.values()): 159 | msg = ( 160 | "FlowJudge evaluator expects all input values to be lists but received " 161 | f"{[type(_input) for _input in received.values()]}." 162 | ) 163 | raise ValueError(msg) 164 | 165 | # Validate that all received inputs are of the same length 166 | inputs = received.values() 167 | length = len(next(iter(inputs))) 168 | if not all(len(_input) == length for _input in inputs): 169 | msg = ( 170 | f"FlowJudge evaluator expects all input lists\ 171 | to have the same length but received {inputs} with lengths " 172 | f"{[len(_input) for _input in inputs]}." 173 | ) 174 | raise ValueError(msg) 175 | 176 | @staticmethod 177 | def _prepare_inputs(inputs: dict[str, Any], metric: Metric | CustomMetric) -> list[EvalInput]: 178 | """Prepare the inputs for the flow judge.""" 179 | eval_inputs = [] 180 | num_samples = len(next(iter(inputs.values()))) 181 | 182 | for i in range(num_samples): 183 | input_list = [] 184 | output_dict = {} 185 | for key, value_list in inputs.items(): 186 | temp_dict = {} 187 | if key in metric.required_inputs: 188 | temp_dict[key] = value_list[i] 189 | input_list.append(temp_dict) 190 | elif key == metric.required_output: 191 | output_dict[key] = value_list[i] 192 | 193 | if not output_dict: 194 | raise ValueError(f"Required output '{metric.required_output}' not found in inputs.") 195 | 196 | eval_input = EvalInput(inputs=input_list, output=output_dict) 197 | eval_inputs.append(eval_input) 198 | 199 | return eval_inputs 200 | 201 | def to_dict(self) -> dict[str, Any]: 202 | """Serialize this component to a dictionary.""" 203 | return default_to_dict( 204 | self, 205 | metric=self.metric, 206 | model=self.model, 207 | output_dir=self.output_dir, 208 | progress_bar=self.progress_bar, 209 | raise_on_failure=self.raise_on_failure, 210 | save_results=self.save_results, 211 | fail_on_parse_error=self.fail_on_parse_error, 212 | ) 213 | 214 | @classmethod 215 | def from_dict(cls, data: dict[str, Any]) -> "HaystackFlowJudge": 216 | """Deserialize this component from a dictionary.""" 217 | data["init_parameters"]["inputs"] = [ 218 | (name, deserialize_type(type_)) for name, type_ in data["init_parameters"]["inputs"] 219 | ] 220 | 221 | return default_from_dict(cls, data) 222 | -------------------------------------------------------------------------------- /flow_judge/integrations/langchain.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from collections.abc import Sequence 3 | from typing import Any 4 | 5 | from 
langchain.evaluation import StringEvaluator 6 | 7 | from flow_judge import AsyncFlowJudge, EvalInput, FlowJudge 8 | from flow_judge.metrics import CustomMetric, Metric 9 | from flow_judge.models import AsyncBaseFlowJudgeModel, BaseFlowJudgeModel 10 | 11 | 12 | class FlowJudgeLangChainEvaluator(StringEvaluator): 13 | """FlowJudgeLangchainEvaluator is a custom evaluator for LangChain. 14 | 15 | It uses FlowJudge to evaluate the LLM outputs. 16 | """ 17 | 18 | def __init__( 19 | self, metric: Metric | CustomMetric, model: BaseFlowJudgeModel | AsyncBaseFlowJudgeModel 20 | ): 21 | """Initialize the LlamaIndexFlowJudge.""" 22 | if isinstance(metric, (Metric, CustomMetric)): 23 | self.metric = metric 24 | else: 25 | raise ValueError("Invalid metric type. Use Metric or CustomMetric.") 26 | 27 | # Validate model and choose appropriate FlowJudge class 28 | if isinstance(model, (BaseFlowJudgeModel, AsyncBaseFlowJudgeModel)): 29 | self.model = model 30 | else: 31 | raise ValueError( 32 | "The model must be an instance of BaseFlowJudgeModel or AsyncBaseFlowJudgeModel." 33 | ) 34 | 35 | # Determine if the model is async-capable 36 | self.is_async = hasattr(self.model, "exec_async") and self.model.exec_async 37 | 38 | # Initialize the appropriate judge based on async capability 39 | if self.is_async: 40 | self.judge = AsyncFlowJudge(metric=self.metric, model=self.model) 41 | else: 42 | self.judge = FlowJudge(metric=self.metric, model=self.model) 43 | 44 | def _prepare_eval_input( 45 | self, 46 | prediction: str, 47 | reference: str | None = None, 48 | input: str | None = None, 49 | **kwargs: Any, 50 | ) -> EvalInput: 51 | # Combine all inputs into a single dictionary 52 | all_inputs = {"prediction": prediction, "reference": reference, "input": input, **kwargs} 53 | 54 | # Prepare eval_inputs based on metric's required_inputs 55 | eval_inputs = [] 56 | for req_input in self.metric.required_inputs: 57 | if req_input in all_inputs: 58 | value = all_inputs[req_input] 59 | if isinstance(value, (list, Sequence)) and not isinstance(value, str): 60 | eval_inputs.extend([{req_input: v} for v in value]) 61 | else: 62 | eval_inputs.append({req_input: value}) 63 | 64 | # Prepare the output 65 | output_key = self.metric.required_output 66 | output_value = all_inputs.get( 67 | output_key, prediction 68 | ) # Default to prediction if not specified 69 | 70 | return EvalInput(inputs=eval_inputs, output={output_key: output_value}) 71 | 72 | def _evaluate_strings( 73 | self, 74 | prediction: str, 75 | reference: str | None = None, 76 | input: str | None = None, 77 | **kwargs: Any, 78 | ) -> dict[str, Any]: 79 | eval_input = self._prepare_eval_input(prediction, reference, input, **kwargs) 80 | result = self.judge.evaluate(eval_input, save_results=False) 81 | 82 | return { 83 | "score": result.score, 84 | "reasoning": result.feedback, 85 | } 86 | 87 | async def _aevaluate_strings( 88 | self, 89 | prediction: str, 90 | reference: str | None = None, 91 | input: str | None = None, 92 | sleep_time_in_seconds: int = 1, 93 | **kwargs: Any, 94 | ) -> dict[str, Any]: 95 | await asyncio.sleep(sleep_time_in_seconds) 96 | eval_input = self._prepare_eval_input(prediction, reference, input, **kwargs) 97 | result = await self.judge.async_evaluate(eval_input, save_results=False) 98 | 99 | return { 100 | "score": result.score, 101 | "reasoning": result.feedback, 102 | } 103 | 104 | @property 105 | def requires_input(self) -> bool: 106 | """Requires input.""" 107 | return "input" in self.metric.required_inputs 108 | 109 | @property 110 | def 
requires_reference(self) -> bool: 111 | """Requires reference.""" 112 | return "reference" in self.metric.required_inputs 113 | 114 | @property 115 | def evaluation_name(self) -> str: 116 | """Get metric name.""" 117 | return f"flow_judge_{self.metric.name}" 118 | 119 | def get_required_inputs(self) -> list[str]: 120 | """Get required inputs.""" 121 | return self.metric.required_inputs + [self.metric.required_output] 122 | -------------------------------------------------------------------------------- /flow_judge/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from flow_judge.metrics.metric import CustomMetric, Metric, RubricItem 2 | from flow_judge.metrics.presets import ( 3 | RESPONSE_CORRECTNESS_3POINT, 4 | RESPONSE_CORRECTNESS_5POINT, 5 | RESPONSE_CORRECTNESS_BINARY, 6 | RESPONSE_FAITHFULNESS_3POINT, 7 | RESPONSE_FAITHFULNESS_5POINT, 8 | RESPONSE_FAITHFULNESS_BINARY, 9 | RESPONSE_RELEVANCE_3POINT, 10 | RESPONSE_RELEVANCE_5POINT, 11 | RESPONSE_RELEVANCE_BINARY, 12 | list_all_metrics, 13 | ) 14 | -------------------------------------------------------------------------------- /flow_judge/metrics/metric.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class RubricItem(BaseModel): 5 | """Represents an item in the evaluation rubric.""" 6 | 7 | score: int 8 | description: str 9 | 10 | 11 | class Metric(BaseModel): 12 | """Represents an evaluation metric.""" 13 | 14 | name: str 15 | criteria: str 16 | rubric: list[RubricItem] 17 | required_inputs: list[str] | None = None 18 | required_output: str 19 | 20 | def print_required_keys(self): 21 | """Prints the required input and output keys.""" 22 | print(f"Metric: {self.name}") 23 | print("Required inputs:", ", ".join(self.required_inputs or [])) 24 | print("Required output:", self.required_output) 25 | 26 | 27 | class CustomMetric(Metric): 28 | """Represents a custom evaluation metric.""" 29 | 30 | def __init__( 31 | self, 32 | name: str, 33 | criteria: str, 34 | rubric: list[RubricItem], 35 | required_inputs: list[str], 36 | required_output: str, 37 | ): 38 | """Initialize a custom metric.""" 39 | super().__init__( 40 | name=name, 41 | criteria=criteria, 42 | rubric=rubric, 43 | required_inputs=required_inputs, 44 | required_output=required_output, 45 | ) 46 | -------------------------------------------------------------------------------- /flow_judge/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import AsyncBaseFlowJudgeModel, BaseFlowJudgeModel, ModelConfig, ModelType 2 | from .huggingface import Hf, HfError 3 | from .llamafile import Llamafile, LlamafileError 4 | from .vllm import Vllm, VllmError 5 | 6 | __all__ = [ 7 | "AsyncBaseFlowJudgeModel", 8 | "BaseFlowJudgeModel", 9 | "ModelType", 10 | "ModelConfig", 11 | "Hf", 12 | "HfError", 13 | "Vllm", 14 | "VllmError", 15 | "Llamafile", 16 | "LlamafileError", 17 | ] 18 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | 5 | class BaseAPIAdapter(ABC): 6 | """Base adapter layer for making remote requests to hosted models.""" 7 | 8 | def __init__(self, base_url: str): 9 | """Initialize the BaseAPIAdapter. 
10 | 11 | :param base_url: The Baseten-deployed model URL 12 | """ 13 | self.base_url = base_url 14 | 15 | @abstractmethod 16 | def _fetch_response(self, request_body: dict[str, Any]) -> str: 17 | """Generate a response based on the given request.""" 18 | pass 19 | 20 | @abstractmethod 21 | def _fetch_batched_response(self, request_bodies: list[dict[str, Any]]) -> list[str]: 22 | """Generate responses for multiple requests.""" 23 | pass 24 | 25 | 26 | class AsyncBaseAPIAdapter(ABC): 27 | """Base adapter layer for making remote requests to hosted models.""" 28 | 29 | def __init__(self, base_url: str): 30 | """Initialize the AsyncBaseAPIAdapter. 31 | 32 | :param base_url: The Baseten-deployed model URL 33 | """ 34 | self.base_url = base_url 35 | 36 | @abstractmethod 37 | async def _async_fetch_response(self, request_body: dict[str, Any]) -> str: 38 | """Generate a response based on the given request.""" 39 | pass 40 | 41 | @abstractmethod 42 | async def _async_fetch_batched_response( 43 | self, request_bodies: list[dict[str, Any]] 44 | ) -> list[str]: 45 | """Generate responses for multiple requests.""" 46 | pass 47 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/README.md: -------------------------------------------------------------------------------- 1 | # Baseten remote execution 2 | 3 | ### Running Remotely on Baseten 4 | 5 | Running Flow Judge on Baseten lets you offload processing from your local machine. 6 | Remote execution runs generations in parallel, multiple requests at a time. 7 | This can significantly improve throughput and reduce overall wait times, which is useful for larger workloads. 8 | 9 | ### Execution modes 10 | 11 | - **Sync Mode**: 12 | - For smaller jobs, sync mode provides immediate feedback and is perfect for quick iterations and exploratory tasks. 13 | 14 | - **Async (Batch) Mode**: 15 | - Lets you submit multiple requests simultaneously and processes them in parallel, making it suitable for large-scale 16 | evaluations. 17 | 18 | ## Setup 19 | 20 | ### Installation 21 | 22 | To use the Baseten integration, install the optional `baseten` dependency: 23 | 24 | ```bash 25 | pip install -e .[dev,baseten] 26 | ``` 27 | 28 | ### Creating a Baseten account 29 | 30 | Follow the instructions on the signup page: https://app.baseten.co/signup 31 | 32 | ### Baseten API key 33 | 34 | To use Baseten, you need to create an API key in your account. 35 | Navigate to your account settings to generate and manage your API keys: 36 | https://app.baseten.co/settings/account/api_keys 37 | 38 | ### Baseten webhook secret (optional) 39 | 40 | For async (batch) execution, you must create a webhook secret in your Baseten account. 41 | Follow the official Baseten instructions: 42 | https://docs.baseten.co/invoke/async-secure#creating-webhook-secrets 43 | 44 | ### Baseten GPU 45 | 46 | The Flow Judge model can be deployed on an A10G or an H100 40GB GPU on Baseten's infrastructure. 47 | In a notebook environment, you can set the `BASETEN_GPU` environment variable to either `A10G` or `H100`. 48 | In an interactive environment (CLI), you will be asked whether you would like to switch to H100.
49 | The Flow Judge model variant is then selected based on the GPU architecture you choose: 50 | 51 | A10G -> Flow-Judge-v0.1-AWQ 52 | 53 | H100 -> Flow-Judge-v0.1-FP8 54 | 55 | ## Sync execution 56 | 57 | When running a Flow Judge eval, use the `Baseten` class as the model 58 | (see [Quickstart](https://github.com/flowaicom/flow-judge?tab=readme-ov-file#quick-start)): 59 | 60 | ```python 61 | model=Baseten() 62 | ``` 63 | 64 | That's it! :) 65 | 66 | During the first run you will be asked to provide the Baseten API key and to select the GPU. 67 | The key will be stored on your computer in `~/.trussrc` (see [truss](https://docs.baseten.co/truss-reference/overview)). 68 | It is used to validate whether the model is already deployed and to deploy it if needed. 69 | 70 | As part of the execution process we deploy the [Flow Judge model](https://huggingface.co/flowaicom/Flow-Judge-v0.1-AWQ) to 71 | your Baseten account and promote it to a published 72 | deployment. 73 | 74 | ## Async Execution 75 | 76 | ### Description 77 | 78 | In async mode, Baseten sends model outputs to a specified webhook address that needs to be publicly accessible on the 79 | internet. 80 | 81 | We offer a free proxy as an effortless solution that can forward responses directly to the Flow Judge, 82 | without exposing an endpoint from your computer to the internet. 83 | This is the default behavior when running in batch mode. 84 | 85 | Alternatively, you can use tools like `ngrok` or `localtunnel` to expose the [same proxy](https://github.com/flowaicom/webhook-proxy) running locally on your device 86 | next to the Flow Judge. 87 | 88 | ### Flow AI proxy (hosted) 89 | 90 | For the hosted proxy there's no additional setup needed. You only need to configure the model instance to work in 91 | async mode: 92 | 93 | ```python 94 | model=Baseten(exec_async=True, webhook_proxy_url="https://proxy.flow-ai.dev") 95 | ``` 96 | 97 | As with synchronous execution, during the first run you will be asked for your Baseten API key and the GPU, 98 | unless provided earlier. The key is used to validate and deploy the model to your Baseten account. 99 | 100 | Additionally, when using asynchronous execution, we verify the signature of the received webhook payloads, as 101 | recommended by Baseten (see [official documentation](https://docs.baseten.co/invoke/async-secure)). 102 | Because of that, during the first run you will be asked to provide the webhook secret. It will be stored 103 | in `~/.config/flow-judge/baseten_webhook_secret` and never leaves your device. 104 | 105 | ### Using the proxy locally 106 | 107 | Currently Flow Judge does not provide a standalone endpoint to expose to the internet. Instead, you can run an instance 108 | of our proxy on your machine and expose its endpoint using e.g. `ngrok`. 109 | 110 | 1. Download a pre-built binary from the proxy releases page: https://github.com/flowaicom/webhook-proxy/releases 111 | or build it according to the instructions provided in the repository. 112 | 2. Run the proxy: 113 | ```shell 114 | ./proxy -addr=0.0.0.0:8000 115 | ``` 116 | For more options and detailed instructions see the documentation provided in the proxy repository. 117 | 3. Expose the running proxy to the internet with e.g. `ngrok`: 118 | ```shell 119 | ngrok http 8000 120 | ``` 121 | The output of this command will provide you with the public URL. 122 | 4. 
Use the public URL when setting up the model instance in your Flow Judge implementation: 123 | ```python 124 | model=Baseten(exec_async=True, webhook_proxy_url="https://«ngrok url»") 125 | ``` 126 | 127 | #### Using Docker 128 | 129 | In addition to the pre-built binaries, we also provide proxy Docker images. 130 | 131 | Run the proxy with Docker: 132 | 133 | ```shell 134 | docker pull ghcr.io/flowaicom/webhook-proxy:latest 135 | docker run --name=flowai-proxy -d -p 8000:8000 ghcr.io/flowaicom/webhook-proxy:latest 136 | ``` 137 | 138 | Then continue with the process from point 3. from the list above. 139 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/api_auth.py: -------------------------------------------------------------------------------- 1 | import getpass 2 | import http 3 | import logging 4 | import os 5 | 6 | import requests 7 | import truss 8 | from truss.remote import remote_factory 9 | 10 | from .util import is_interactive 11 | 12 | logger: logging.Logger = logging.getLogger(__name__) 13 | 14 | 15 | def get_baseten_api_key() -> str | None: 16 | """Retrieve the Baseten API key from environment or config file. 17 | 18 | :return: The Baseten API key if found, None otherwise. 19 | :rtype: str | None 20 | """ 21 | logger.debug("Attempting to retrieve Baseten API key") 22 | api_key: str | None = os.environ.get("BASETEN_API_KEY") 23 | if api_key: 24 | logger.debug("API key found in environment variables") 25 | if len(api_key) != 41: 26 | logger.warning( 27 | "Warning: Baseten API key might be incorrect. " 28 | "The length should be exactly 41 characters." 29 | ) 30 | return api_key 31 | 32 | logger.debug("API key not found in environment, checking remote config") 33 | try: 34 | c: remote_factory.RemoteConfig | None = remote_factory.RemoteFactory.load_remote_config( 35 | "baseten" 36 | ) 37 | if c is not None: 38 | api_key = c.configs.get("api_key") 39 | if api_key: 40 | logger.debug("API key found in remote config") 41 | if len(api_key) != 41: 42 | logger.warning( 43 | "Warning: Baseten API key might be incorrect. " 44 | "The length should be exactly 41 characters." 45 | ) 46 | os.environ["BASETEN_API_KEY"] = api_key 47 | return api_key 48 | except Exception as e: 49 | logger.error(f"Error loading remote config: {e}") 50 | 51 | logger.debug("API key not found") 52 | return None 53 | 54 | 55 | def _validate_auth_status(api_key: str | None = None) -> bool: 56 | """Validate the authentication status with Baseten. 57 | 58 | :param api_key: Optional API key. If None, it will try to obtain API key from env or .trussrc. 59 | :return: True if authentication is successful, False otherwise. 
60 | :rtype: bool 61 | """ 62 | logger.debug("Validating authentication status") 63 | 64 | api_key: str | None = api_key or get_baseten_api_key() 65 | if not api_key: 66 | logger.warning("No API key available for authentication") 67 | return False 68 | 69 | try: 70 | response: requests.Response = requests.get( 71 | "https://api.baseten.co/v1/models", 72 | headers={"Authorization": f"Api-Key {api_key}"}, 73 | timeout=10, 74 | ) 75 | is_valid: bool = response.status_code == http.HTTPStatus.OK 76 | logger.debug(f"Authentication status: {'valid' if is_valid else 'invalid'}") 77 | return is_valid 78 | except requests.RequestException as e: 79 | logger.error(f"Error validating auth status: {e}") 80 | return False 81 | 82 | 83 | def _attempt_truss_auth_with_env_api_key() -> None: 84 | """Authenticate Truss with the API key from environment variable. 85 | 86 | :return: None 87 | """ 88 | logger.debug("Authenticating Truss with environment API key") 89 | env_api_key: str | None = os.environ.get("BASETEN_API_KEY") 90 | if not env_api_key: 91 | logger.debug("Baseten API key not available in env var") 92 | return 93 | 94 | logger.debug("Logging in to Truss with environment API key") 95 | truss.login(env_api_key) 96 | 97 | 98 | def _print_noninteractive_prompt() -> None: 99 | """Print a API key prompt for non-interactive environments.""" 100 | print("Set the Baseten API key in `BASETEN_API_KEY` environment variable " "and run again:") 101 | print("```") 102 | print('os.environ["BASETEN_API_KEY"] = "«your API key»"') 103 | print("```") 104 | 105 | 106 | def _print_general_prompt() -> None: 107 | """Print a API key prompt information.""" 108 | print("To run Flow Judge remotely with Baseten, signup and generate API key") 109 | print(" ➡️ Signup: https://app.baseten.co/signup") 110 | print(" ➡️ API keys: https://app.baseten.co/settings/api_keys") 111 | print(" ➡️ Docs: https://docs.baseten.co/quickstart#setup\n") 112 | 113 | 114 | def _validate_entered_key(key: str) -> bool: 115 | """Checks if the key entered by the user in interactive environments is valid. 116 | 117 | :param key: The key entered by the user 118 | :return: True if the key entered by the user is valid, False otherwise. 119 | :rtype: bool 120 | """ 121 | try: 122 | logger.debug("Attempting to log in with provided API key") 123 | truss.login(key) 124 | except Exception as e: 125 | logger.error(f"Error during login: {e}") 126 | print("An error occurred during login. Please try again.") 127 | return False 128 | 129 | if not _validate_auth_status(key): 130 | logger.warning("Invalid Baseten API key") 131 | print("Invalid Baseten API key, try again.") 132 | return False 133 | 134 | logger.info("Baseten authentication successful") 135 | return True 136 | 137 | 138 | def ensure_baseten_authentication() -> bool: 139 | """Attempts to obtain and validate the Baseten API key from the user. 140 | 141 | :return: True if authentication is successful, False otherwise. 142 | """ 143 | # Checks if truss is already authenticated or if the key is provided in the env to authenticate 144 | _attempt_truss_auth_with_env_api_key() 145 | 146 | # Checks if authentication above succeeded 147 | if _validate_auth_status(): 148 | logger.info("Baseten authenticated") 149 | return True 150 | 151 | logger.warning("Baseten authentication failed or not initialized") 152 | _print_general_prompt() 153 | 154 | # In non-interactive environments we return early because their inputs (env vars) 155 | # are processed at the very beginning of this function. 
156 | if not is_interactive(): 157 | logger.info("Non-interactive environment detected") 158 | _print_noninteractive_prompt() 159 | return False 160 | 161 | logger.info("Prompting for API key in interactive environment") 162 | while True: 163 | key: str = getpass.getpass("Baseten API key (hidden): ") 164 | if not key: 165 | logger.warning("Empty API key entered") 166 | print("Input is empty, try again.") 167 | continue 168 | 169 | # Do not echo the key; it was entered via hidden input. 170 | 171 | if len(key) != 41: 172 | logger.warning( 173 | "Warning: Baseten API key might be incorrect. " 174 | "The length should be exactly 41 characters." 175 | ) 176 | 177 | if _validate_entered_key(key): 178 | logger.info("Login successful") 179 | return True 180 | 181 | return False 182 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/data_io.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field, field_validator 2 | from typing_extensions import TypedDict 3 | 4 | from flow_judge.models.adapters.baseten.errors import FlowJudgeError 5 | 6 | 7 | class Message(TypedDict): 8 | """Represents a single request message for the Baseten API. 9 | 10 | Note: 11 | This class uses TypedDict for strict type checking. Ensure all fields 12 | are provided when instantiating. The 'id' field is crucial for tracking 13 | and error reporting throughout the evaluation process. 14 | 15 | Warning: 16 | Do not include sensitive information in the 'prompt' or 'response' fields, as they may 17 | be logged or stored for debugging purposes. 18 | """ 19 | 20 | id: str 21 | index: int 22 | prompt: str 23 | response: str 24 | 25 | 26 | class BatchResult(BaseModel): 27 | """Represents the result of a batch evaluation process. 28 | 29 | This class contains both successful outputs and errors encountered during the 30 | evaluation process, as well as metadata about the batch operation. 31 | 32 | Attributes: 33 | successful_outputs (List[Message]): List of successful evaluation outputs. 34 | errors (List[FlowJudgeError]): List of errors encountered during evaluation. 35 | total_requests (int): Total number of requests processed in the batch. 36 | success_rate (float): Rate of successful evaluations (0.0 to 1.0). 37 | 38 | Note: 39 | The success_rate is calculated as (len(successful_outputs) / total_requests). 40 | Be cautious when interpreting results with a low success rate, as it may 41 | indicate systemic issues with the evaluation process or input data.
42 | """ 43 | 44 | successful_outputs: list[Message] = Field( 45 | default_factory=list, description="List of successful evaluation outputs" 46 | ) 47 | errors: list[FlowJudgeError] = Field( 48 | default_factory=list, description="List of errors encountered during evaluation" 49 | ) 50 | total_requests: int = Field(..., description="Total number of requests processed") 51 | success_rate: float = Field(..., description="Rate of successful evaluations") 52 | 53 | @field_validator("total_requests") 54 | @classmethod 55 | def check_positive_total_requests(cls, v): 56 | """Placeholder.""" 57 | if v < 0: 58 | raise ValueError("total_requests must be positive") 59 | return v 60 | 61 | @field_validator("success_rate") 62 | @classmethod 63 | def check_success_rate_range(cls, v): 64 | """Placeholder.""" 65 | if not 0 <= v <= 1: 66 | raise ValueError("success_rate must be between 0 and 1") 67 | return v 68 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/deploy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import TypedDict 4 | 5 | import requests 6 | import truss 7 | from truss.api.definitions import ModelDeployment 8 | from truss.remote.baseten.error import ApiError 9 | 10 | from flow_judge.models.adapters.baseten.management import sync_set_scale_down 11 | 12 | from .api_auth import ensure_baseten_authentication, get_baseten_api_key 13 | from .gpu import ensure_gpu 14 | from .webhook import ensure_baseten_webhook_secret 15 | 16 | logger: logging.Logger = logging.getLogger(__name__) 17 | 18 | 19 | class ModelInfo(TypedDict): 20 | """Type definition for model information returned by Baseten API. 21 | 22 | :ivar id: The unique identifier of the model. 23 | :ivar name: The name of the model. 24 | """ 25 | 26 | id: str 27 | name: str 28 | 29 | 30 | def _initialize_model() -> bool: 31 | """Initialize the Flow Judge model. 32 | 33 | :return: True if initialization is successful, False otherwise. 34 | :rtype: bool 35 | """ 36 | logger.info("Initializing Flow Judge model") 37 | if _is_flowjudge_deployed(): 38 | logger.info("Flow Judge already deployed") 39 | return True 40 | 41 | if not ensure_gpu(): 42 | logger.error("BASETEN_GPU environment variable is required." " Set one of: H100 or A10G") 43 | return False 44 | 45 | truss_path: str = f"{os.path.dirname(os.path.realpath(os.path.abspath(__file__)))}/deployment" 46 | logger.debug(f"Truss path: {truss_path}") 47 | try: 48 | deployment: ModelDeployment = truss.push( 49 | truss_path, promote=True, trusted=True, environment=None 50 | ) 51 | logger.debug("Waiting for deployment to become active") 52 | deployment.wait_for_active() 53 | logger.info("Flow Judge Baseten deployment successful") 54 | 55 | model_id = get_deployed_model_id() 56 | api_key = get_baseten_api_key() 57 | if model_id and api_key: 58 | has_updated_scale_down = sync_set_scale_down( 59 | scale_down_delay=120, api_key=api_key, model_id=model_id 60 | ) 61 | if has_updated_scale_down: 62 | logger.info( 63 | "Successfully updated Baseten deployed model scale down delay to 2 mins." 64 | ) 65 | else: 66 | logger.info( 67 | "Unable to update Baseten deployed model scale down delay period." 68 | " Continuing with default" 69 | ) 70 | 71 | return True 72 | except ApiError as e: 73 | logger.error( 74 | "Flow Judge Baseten deployment failed. " 75 | f"Ensure that provided API key is correct and try again. 
{e.message}" 76 | ) 77 | except ValueError as e: 78 | logger.error( 79 | "Flow Judge Baseten deployment failed. " 80 | f"Ensure that provided API key is correct and try again. {str(e)}" 81 | ) 82 | return False 83 | 84 | 85 | def _get_models() -> list[ModelInfo] | None: 86 | """Fetch the list of models from Baseten API. 87 | 88 | :return: List of ModelInfo if successful, None otherwise. 89 | :rtype: Optional[List[ModelInfo]] 90 | """ 91 | logger.debug("Fetching models from Baseten API") 92 | api_key: str | None = get_baseten_api_key() 93 | if not api_key: 94 | logger.warning("No API key available to fetch models") 95 | return None 96 | 97 | try: 98 | response: requests.Response = requests.get( 99 | "https://api.baseten.co/v1/models", 100 | headers={"Authorization": f"Api-Key {api_key}"}, 101 | timeout=10, 102 | ) 103 | response.raise_for_status() 104 | resp: dict[str, list[ModelInfo]] = response.json() 105 | models: list[ModelInfo] | None = resp.get("models") 106 | logger.debug(f"Fetched {len(models) if models else 0} models") 107 | return models 108 | except requests.RequestException as e: 109 | logger.error(f"Error fetching models: {e}") 110 | return None 111 | 112 | 113 | def _is_flowjudge_deployed() -> bool: 114 | """Check if Flow Judge is already deployed. 115 | 116 | :return: True if Flow Judge is deployed, False otherwise. 117 | :rtype: bool 118 | """ 119 | logger.debug("Checking if Flow Judge is deployed") 120 | models: list[ModelInfo] | None = _get_models() 121 | 122 | if models is None: 123 | logger.warning("Unable to determine if Flow Judge is deployed") 124 | return False 125 | 126 | is_deployed: bool = any("Flow-Judge" in model["name"] for model in models) 127 | logger.debug(f"Flow Judge deployed: {is_deployed}") 128 | return is_deployed 129 | 130 | 131 | def get_deployed_model_id() -> str | None: 132 | """Get the ID of the deployed Flow Judge model. 133 | 134 | :return: The model ID if found, None otherwise. 135 | :rtype: Optional[str] 136 | """ 137 | logger.debug("Getting deployed Flow Judge model ID") 138 | models: list[ModelInfo] | None = _get_models() 139 | if not models: 140 | logger.warning("No models found") 141 | return None 142 | 143 | for model in models: 144 | if "Flow-Judge" in model["name"]: 145 | logger.debug(f"Found Flow Judge model with ID: {model['id']}") 146 | return model["id"] 147 | 148 | logger.warning("Flow Judge model not found") 149 | return None 150 | 151 | 152 | def ensure_model_deployment() -> bool: 153 | """Ensure Flow Judge model deployment to Baseten. 154 | 155 | :return: True if deployment is successful, False otherwise. 
156 | :rtype: bool 157 | """ 158 | logger.info("Ensuring Flow Judge model deployment") 159 | if not ensure_baseten_authentication(): 160 | logger.error("Baseten not authenticated, interrupting model deployment") 161 | return False 162 | 163 | ensure_baseten_webhook_secret(optional=True) 164 | 165 | return _initialize_model() 166 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/deployment/config.yaml: -------------------------------------------------------------------------------- 1 | base_image: 2 | image: baseten/truss-server-base:3.11-gpu-v0.9.0 3 | environment_variables: 4 | CUDA_HOME: /usr/local/cuda 5 | LD_LIBRARY_PATH: /usr/local/cuda/extras/CUPTI/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 6 | PATH: /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin 7 | VLLM_ATTENTION_BACKEND: FLASH_ATTN 8 | external_package_dirs: [] 9 | model_metadata: 10 | example_model_input: 11 | messages: 12 | - content: "# GOAL\nYour job is to evaluate a task carried out by an AI system\ 13 | \ powered by a large language model.\n\nYou will be provided with the inputs\ 14 | \ and output of the task, as well as the evaluation criteria and scoring rubric.\ 15 | \ Your task is to evaluate the output of the AI system based on the evaluation\ 16 | \ criteria and scoring rubric provided.\n\n# INPUT\nBelow are the inputs required\ 17 | \ for performing the task:\n\n\nPlease read the technical issue\ 18 | \ that the user is facing and help me create a detailed solution based on\ 19 | \ the context provided.\n\n\n# Customer Issue: I'm having\ 20 | \ trouble pushing large files to my Git repository. It says 'File size exceeds\ 21 | \ GitHub's file size limit of 100 MB'.\n\n# Context: Git Large File Storage\ 22 | \ (LFS) replaces large files such as audio samples, videos, datasets, and\ 23 | \ graphics with text pointers inside Git, while storing the file contents\ 24 | \ on a remote server like GitHub.com or GitHub Enterprise.\n\nTo use Git LFS,\ 25 | \ you need to download and install the Git command line extension. For more\ 26 | \ information, see 'Installing Git Large File Storage'.\n\nOnce Git LFS is\ 27 | \ installed, you need to set up Git LFS for your user account by running:\n\ 28 | \n$ git lfs install\n\n\n\n# OUTPUT\nBelow is the output\ 29 | \ of the task:\n\n\nTo resolve your issue with pushing large\ 30 | \ files to your Git repository, you should use Git Large File Storage (LFS).\ 31 | \ Here's a step-by-step solution:\n\n1. Install Git LFS by following the instructions\ 32 | \ in the 'Installing Git Large File Storage' guide.\n\n2. Set up Git LFS for\ 33 | \ your user account by running:\n ```\n git lfs install\n ```\n\n3. In your\ 34 | \ repository, track the large files you want to push using:\n ```\n git lfs\ 35 | \ track \"*.large-file-extension\"\n ```\n Replace 'large-file-extension'\ 36 | \ with the actual extension of your large files.\n\n4. Add the .gitattributes\ 37 | \ file to your repository:\n ```\n git add .gitattributes\n ```\n\n5. Now,\ 38 | \ add and commit your large files as usual:\n ```\n git add large-file.ext\n\ 39 | \ git commit -m \"Add large file\"\n ```\n\n6. 
Push your changes:\n ```\n\ 40 | \ git push origin main\n ```\n\nBy using Git LFS, you'll be able to push files\ 41 | \ larger than GitHub's 100 MB limit, as LFS will handle them appropriately.\n\ 42 | \n\n\n# EVALUATION CRITERIA AND SCORING RUBRIC\nHere are\ 43 | \ the evaluation criteria and the rubric that you need to use for evaluating\ 44 | \ the task:\n\nBased on the given context, evaluate how\ 45 | \ consistent and faithful the generated response is to the context. The response\ 46 | \ should not contain any hallucinated or fabricated information that is not\ 47 | \ supported by the context.\n\n\n\n\ 48 | - Score 1: The response is completely inconsistent with the provided context.\ 49 | \ It contains significant amount of hallucinated or fabricated information\ 50 | \ that directly contradicts or is not supported at all by the context.\n-\ 51 | \ Score 2: The response is mostly inconsistent with the provided context.\ 52 | \ While it may contain some information from the context, it introduces a\ 53 | \ substantial amount of hallucinated or fabricated details that deviate from\ 54 | \ the context.\n- Score 3: The response is somewhat consistent with the provided\ 55 | \ context. It includes a mix of information from the context and some hallucinated\ 56 | \ or fabricated details. The fabrications are minor and do not significantly\ 57 | \ contradict the context.\n- Score 4: The response is mostly consistent with\ 58 | \ the provided context. The vast majority of the content is supported by the\ 59 | \ context, with only minor and inconsequential inconsistencies or fabrications,\ 60 | \ if any.\n- Score 5: The response is completely consistent with and faithful\ 61 | \ to the provided context. All details in the response are directly supported\ 62 | \ by the context, without any hallucinated or fabricated information.\n\n\ 63 | \n# INSTRUCTIONS FOR THE EVALUATION\n1. Understand the task and criteria:\ 64 | \ Familiarize yourself with the task to be evaluated. Review the evaluation\ 65 | \ criteria and scoring rubric to understand the different levels of performance\ 66 | \ and the descriptions for each score.\n2. Review the inputs and output: Look\ 67 | \ at the inputs provided for the task. Examine the output generated from completing\ 68 | \ the task.\n3. Compare output to score descriptions: Compare the output against\ 69 | \ the criteria and score descriptions in the scoring rubric. For each criterion,decide\ 70 | \ which description best matches the output.\n4. After comparing the output\ 71 | \ to the score descriptions, pay attention to the small details that might\ 72 | \ impact the final score that you assign. Sometimes a small difference can\ 73 | \ dictate the final score.\n5. Write verbal feedback justifying your evaluation\ 74 | \ that includes a detailed rationale, referring to specific aspects of the\ 75 | \ output and comparing them to the rubric.\n6. Assign a final score based\ 76 | \ on the scoring rubric.\n\n## FORMAT FOR THE EVALUATION\n- Write the verbal\ 77 | \ feedback inside tags without any additional surrounding text.\n\ 78 | - Write the numeric score inside tags, without any additional surrounding\ 79 | \ text and always after the feedback.\n\nPlease accurately evaluate the task.\ 80 | \ Strictly adhere to the evaluation criteria and rubric." 
81 | role: user 82 | openai_compatible: true 83 | repo_id: flowaicom/Flow-Judge-v0.1-AWQ 84 | vllm_config: 85 | max_model_len: 8192 86 | tensor_parallel_size: 1 87 | model_name: Flow-Judge-v0.1 88 | python_version: py311 89 | repo_id: flowaicom/Flow-Judge-v0.1-AWQ 90 | requirements: 91 | - vllm>=0.6.2 92 | - vllm-flash-attn 93 | resources: 94 | accelerator: A10G 95 | use_gpu: true 96 | runtime: 97 | predict_concurrency: 128 98 | secrets: 99 | hf_access_token: hf_xyz 100 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/deployment/model/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | # Configure logging 5 | logging.basicConfig( 6 | level=logging.INFO, 7 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 8 | handlers=[logging.StreamHandler(sys.stdout)], 9 | ) 10 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/deployment/model/helper.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import os 4 | import threading 5 | 6 | import httpx 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | DEFAULT_HEALTH_CHECK_INTERVAL = 5 # seconds 11 | 12 | 13 | def log_subprocess_output(process): 14 | """Process logs for the vLLM subprocess.""" 15 | while True: 16 | output = process.stdout.readline() 17 | if output == "" and process.poll() is not None: 18 | break 19 | if output: 20 | logger.info(f"vLLM subprocess stdout: {output.strip()}") 21 | rc = process.poll() 22 | if rc != 0: 23 | for error_output in process.stderr.readlines(): 24 | logger.error(f"vLLM subprocess stderr: {error_output.strip()}") 25 | 26 | 27 | async def monitor_vllm_server_health(vllm_server_url, health_check_interval): 28 | """Health check for the vLLM server.""" 29 | assert vllm_server_url is not None, "vllm_server_url must not be None" 30 | try: 31 | async with httpx.AsyncClient() as client: 32 | while True: 33 | response = await client.get(f"{vllm_server_url}/health") 34 | if response.status_code != 200: 35 | raise RuntimeError("vLLM is unhealthy") 36 | await asyncio.sleep(health_check_interval) 37 | except Exception as e: 38 | logging.error( 39 | f"vLLM has gone into an unhealthy state due to error: {e}, restarting service now..." 40 | ) 41 | os._exit(1) 42 | 43 | 44 | async def monitor_vllm_engine_health(vllm_engine, health_check_interval): 45 | """Health check for the vLLM engine.""" 46 | assert vllm_engine is not None, "vllm_engine must not be None" 47 | try: 48 | while True: 49 | await vllm_engine.check_health() 50 | await asyncio.sleep(health_check_interval) 51 | except Exception as e: 52 | logging.error( 53 | f"vLLM has gone into an unhealthy state due to error: {e}, restarting service now..." 
54 | ) 55 | os._exit(1) 56 | 57 | 58 | def run_background_vllm_health_check( 59 | use_openai_compatible_server=False, 60 | health_check_interval=DEFAULT_HEALTH_CHECK_INTERVAL, 61 | vllm_engine=None, 62 | vllm_server_url=None, 63 | ): 64 | """Background process for vLLM health checks.""" 65 | logger.info("Starting background health check loop") 66 | loop = asyncio.new_event_loop() 67 | if use_openai_compatible_server: 68 | loop.create_task(monitor_vllm_server_health(vllm_server_url, health_check_interval)) 69 | else: 70 | loop.create_task(monitor_vllm_engine_health(vllm_engine, health_check_interval)) 71 | thread = threading.Thread(target=loop.run_forever, daemon=True) 72 | thread.start() 73 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/deployment/model/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import subprocess 4 | import threading 5 | import time 6 | import uuid 7 | 8 | import httpx 9 | from transformers import AutoTokenizer 10 | from vllm import SamplingParams 11 | from vllm.engine.arg_utils import AsyncEngineArgs 12 | from vllm.engine.async_llm_engine import AsyncLLMEngine 13 | 14 | # isort requires relative imports that are not accepted by Baseten 15 | from model.helper import log_subprocess_output, run_background_vllm_health_check # isort: skip 16 | 17 | MAX_LENGTH = 1024 18 | TEMPERATURE = 0.1 19 | TOP_P = 0.95 20 | TOP_K = 40 21 | DO_SAMPLE = True 22 | DEFAULT_STREAM = False 23 | 24 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class Model: 30 | """Baseten model class for deployment.""" 31 | 32 | # 25 minutes; the reason this would take this long is mostly if we download a large model 33 | MAX_FAILED_SECONDS = 1500 34 | HEALTH_CHECK_INTERVAL = 5 # seconds 35 | 36 | def __init__(self, **kwargs): 37 | """Initialize Baseten model deployment class.""" 38 | self._config = kwargs["config"] 39 | self.model_id = None 40 | self.llm_engine = None 41 | self.model_args = None 42 | self.hf_secret_token = kwargs["secrets"].get("hf_access_token", None) 43 | self.openai_compatible = self._config["model_metadata"].get("openai_compatible", False) 44 | self.vllm_base_url = None 45 | os.environ["HF_TOKEN"] = self.hf_secret_token 46 | 47 | def _load_openai_compatible_model(self): 48 | """Load OpenAI compatible model.""" 49 | self._client = httpx.AsyncClient(timeout=None) 50 | command = ["vllm", "serve", self._model_repo_id] 51 | for key, value in self._vllm_config.items(): 52 | if value is True: 53 | command.append(f"--{key.replace('_', '-')}") 54 | elif value is False: 55 | continue 56 | else: 57 | command.append(f"--{key.replace('_', '-')}") 58 | command.append(str(value)) 59 | 60 | logger.info(f"Starting openai compatible vLLM server with command: {command}") 61 | 62 | self._vllm_process = subprocess.Popen( 63 | command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True 64 | ) 65 | 66 | output_thread = threading.Thread(target=log_subprocess_output, args=(self._vllm_process,)) 67 | output_thread.daemon = True 68 | output_thread.start() 69 | 70 | # Wait for 10 seconds and check if command fails 71 | time.sleep(10) 72 | 73 | if self._vllm_process.poll() is None: 74 | logger.info("Command to start vLLM server ran successfully") 75 | else: 76 | stdout, stderr = self._vllm_process.communicate() 77 | if self._vllm_process.returncode != 0: 78 | logger.error(f"Command failed with 
error: {stderr}") 79 | raise RuntimeError( 80 | f"Command failed with code {self._vllm_process.returncode}: {stderr}" 81 | ) 82 | 83 | if self._vllm_config and "port" in self._vllm_config: 84 | self._vllm_port = self._vllm_config["port"] 85 | else: 86 | self._vllm_port = 8000 87 | 88 | self.vllm_base_url = f"http://localhost:{self._vllm_port}" 89 | 90 | # Polling to check if the server is up 91 | server_up = self._check_server_health() 92 | 93 | if not server_up: 94 | raise RuntimeError("Server failed to start within the maximum allowed time.") 95 | 96 | def _check_server_health(self): 97 | """Check if the server is up and running.""" 98 | start_time = time.time() 99 | while time.time() - start_time < self.MAX_FAILED_SECONDS: 100 | try: 101 | response = httpx.get(f"{self.vllm_base_url}/health") 102 | logger.info(f"Checking server health: {response.status_code}") 103 | if response.status_code == 200: 104 | return True 105 | except httpx.RequestError as e: 106 | seconds_passed = int(time.time() - start_time) 107 | if seconds_passed % 10 == 0: 108 | logger.info(f"Server is starting for {seconds_passed} seconds: {e}") 109 | time.sleep(1) # Wait for 1 second before retrying 110 | return False 111 | 112 | def _load_non_openai_compatible_model(self): 113 | """Load non-OpenAI compatible model.""" 114 | try: 115 | result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True) 116 | logger.info(result.stdout) 117 | except subprocess.CalledProcessError as e: 118 | logger.error(f"Command failed with code {e.returncode}: {e.stderr}") 119 | 120 | self.model_args = AsyncEngineArgs(model=self._model_repo_id, **self._vllm_config) 121 | self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args=self.model_args) 122 | self.tokenizer = AutoTokenizer.from_pretrained(self._model_repo_id) 123 | 124 | try: 125 | result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True) 126 | logger.info(result.stdout) 127 | except subprocess.CalledProcessError as e: 128 | logger.error(f"Command failed with code {e.returncode}: {e.stderr}") 129 | 130 | def load(self): 131 | """Load the model.""" 132 | self._model_metadata = self._config["model_metadata"] 133 | self._model_repo_id = self._model_metadata["repo_id"] 134 | self._vllm_config = self._model_metadata["vllm_config"] 135 | if self._vllm_config is None: 136 | self._vllm_config = {} 137 | logger.info(f"main model: {self._model_repo_id}") 138 | logger.info(f"vllm config: {self._vllm_config}") 139 | 140 | if self.openai_compatible: 141 | self._load_openai_compatible_model() 142 | else: 143 | self._load_non_openai_compatible_model() 144 | 145 | try: 146 | run_background_vllm_health_check( 147 | self.openai_compatible, 148 | self.HEALTH_CHECK_INTERVAL, 149 | self.llm_engine, 150 | self.vllm_base_url, 151 | ) 152 | except Exception as e: 153 | raise RuntimeError(f"Failed to start background health check: {e}") from e 154 | 155 | async def predict(self, model_input): 156 | """Generate output based on the input.""" 157 | if "messages" not in model_input and "prompt" not in model_input: 158 | raise ValueError("Prompt or messages must be provided") 159 | 160 | stream = model_input.get("stream", False) 161 | 162 | if self.openai_compatible: 163 | return await self._predict_openai_compatible(model_input, stream) 164 | else: 165 | return await self._predict_non_openai_compatible(model_input, stream) 166 | 167 | async def _predict_openai_compatible(self, model_input, stream): 168 | """Generate output for OpenAI compatible model.""" 169 | # 
if the key metrics: true is present, let's return the vLLM /metrics endpoint 170 | if model_input.get("metrics", False): 171 | response = await self._client.get(f"{self.vllm_base_url}/metrics") 172 | return response.text 173 | 174 | # convenience for Baseten bridge 175 | if "model" not in model_input and self._model_repo_id: 176 | logger.info( 177 | f"model_input missing model due to Baseten bridge, using {self._model_repo_id}" 178 | ) 179 | model_input["model"] = self._model_repo_id 180 | 181 | if stream: 182 | 183 | async def generator(): 184 | async with self._client.stream( 185 | "POST", 186 | f"{self.vllm_base_url}/v1/chat/completions", 187 | json=model_input, 188 | ) as response: 189 | async for chunk in response.aiter_bytes(): 190 | if chunk: 191 | yield chunk 192 | 193 | return generator() 194 | else: 195 | response = await self._client.post( 196 | f"{self.vllm_base_url}/v1/chat/completions", 197 | json=model_input, 198 | ) 199 | 200 | return response.json() 201 | 202 | async def _predict_non_openai_compatible(self, model_input, stream): 203 | """Generate output for non-OpenAI compatible model.""" 204 | # SamplingParams does not take/use argument 'model' 205 | if "model" in model_input: 206 | model_input.pop("model") 207 | if "prompt" in model_input: 208 | prompt = model_input.pop("prompt") 209 | sampling_params = SamplingParams(**model_input) 210 | idx = str(uuid.uuid4().hex) 211 | messages = [ 212 | {"role": "user", "content": prompt}, 213 | ] 214 | # templatize the input to the model 215 | input = self.tokenizer.apply_chat_template( 216 | messages, tokenize=False, add_generation_prompt=True 217 | ) 218 | elif "messages" in model_input: 219 | messages = model_input.pop("messages") 220 | sampling_params = SamplingParams(**model_input) 221 | idx = str(uuid.uuid4().hex) 222 | # templatize the input to the model 223 | input = self.tokenizer.apply_chat_template( 224 | messages, 225 | tokenize=False, 226 | ) 227 | logger.info(f"Using SamplingParams: {sampling_params}") 228 | # since we accept any valid vllm sampling parameters, we can just pass it through 229 | vllm_generator = self.llm_engine.generate(input, sampling_params, idx) 230 | 231 | async def generator(): 232 | full_text = "" 233 | async for output in vllm_generator: 234 | text = output.outputs[0].text 235 | delta = text[len(full_text) :] 236 | full_text = text 237 | yield delta 238 | 239 | if stream: 240 | return generator() 241 | else: 242 | full_text = "" 243 | async for delta in generator(): 244 | full_text += delta 245 | return {"text": full_text} 246 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/errors.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from pydantic import BaseModel, Field, field_validator 4 | 5 | 6 | class FlowJudgeError(BaseModel): 7 | """Represents an error encountered during the Flow Judge evaluation process. 8 | 9 | This class encapsulates detailed error information, including the type of error, 10 | the specific message, the request ID that caused the error, and other metadata. 11 | 12 | Attributes: 13 | error_type (str): The type of error encountered (e.g., "TimeoutError"). 14 | error_message (str): A detailed description of the error. 15 | request_id (str): The ID of the request that caused the error. 16 | timestamp (datetime): The time when the error occurred. 17 | retry_count (int): The number of retry attempts made before the error was raised. 
18 | raw_response (Optional[str]): The raw response from Baseten or proxy, if available. 19 | 20 | Note: 21 | This class is used for both logging and error handling. Ensure that sensitive 22 | information is not included in the error_message or raw_response fields. 23 | """ 24 | 25 | error_type: str = Field(..., description="Type of the error encountered") 26 | error_message: str = Field(..., description="Detailed error message") 27 | request_id: str | None = Field( 28 | default=None, description="ID of the request that caused the error" 29 | ) 30 | timestamp: datetime = Field( 31 | default_factory=datetime.now, description="Time when the error occurred" 32 | ) 33 | retry_count: int = Field(default=0, description="Number of retry attempts made") 34 | raw_response: str | None = Field( 35 | None, description="Raw response from Baseten or proxy, if available" 36 | ) 37 | 38 | @field_validator("error_type", "error_message") 39 | @classmethod 40 | def check_non_empty_string(cls, v): 41 | """Validate that the field is not empty or just whitespace.""" 42 | if not v.strip(): 43 | raise ValueError("Field must not be empty or just whitespace") 44 | return v 45 | 46 | 47 | class BasetenAPIError(Exception): 48 | """Base exception for Baseten API errors.""" 49 | 50 | pass 51 | 52 | 53 | class BasetenRequestError(BasetenAPIError): 54 | """Exception for request-related errors.""" 55 | 56 | pass 57 | 58 | 59 | class BasetenResponseError(BasetenAPIError): 60 | """Exception for response-related errors.""" 61 | 62 | pass 63 | 64 | 65 | class BasetenRateLimitError(BasetenAPIError): 66 | """Exception for rate limit errors.""" 67 | 68 | pass 69 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/gpu.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from enum import Enum 4 | from typing import Any 5 | 6 | import yaml 7 | 8 | from .util import is_interactive 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class ModelByGPU(Enum): 14 | """Select the appropriate model based on GPU.""" 15 | 16 | H100_40GB = "flowaicom/Flow-Judge-v0.1-FP8" 17 | A10G = "flowaicom/Flow-Judge-v0.1-AWQ" 18 | 19 | 20 | def _get_gpu_key() -> str | None: 21 | """Fetches the BASETEN_GPU environment variable. 22 | 23 | :returns: The value of the variable: one of H100_40GB, A10G 24 | :rtype: str | None 25 | """ 26 | gpu: str | None = os.environ.get("BASETEN_GPU") 27 | 28 | if gpu: 29 | if gpu.lower() not in ["h100", "a10g"]: 30 | raise ValueError("BASETEN_GPU option is incorrect. " "Possible options: H100, A10G") 31 | 32 | gpu = "H100_40GB" if gpu.lower() == "h100" else gpu 33 | 34 | return gpu 35 | 36 | 37 | def _has_gpu_key() -> bool: 38 | """Verifies that the BASETEN_GPU environment variable is present. 39 | 40 | :returns: True if found, False otherwise. 41 | :rtype: bool 42 | """ 43 | gpu: str | None = _get_gpu_key() 44 | 45 | return True if gpu else False 46 | 47 | 48 | def _update_config() -> bool: 49 | """Update the config.yaml file with the GPU & flow-judge model id. 50 | 51 | :returns: True if successfully updated, False if not. 
52 | :rtype: bool 53 | """ 54 | gpu: str = _get_gpu_key() 55 | 56 | if not gpu: 57 | return False 58 | 59 | config_path: str = os.path.join( 60 | os.path.dirname(os.path.realpath(os.path.abspath(__file__))), "deployment", "config.yaml" 61 | ) 62 | 63 | try: 64 | with open(config_path) as file: 65 | data: dict[str, Any] = yaml.safe_load(file) 66 | 67 | data["resources"]["accelerator"] = ModelByGPU[gpu.upper()].name 68 | data["repo_id"] = ModelByGPU[gpu.upper()].value 69 | data["model_metadata"]["repo_id"] = ModelByGPU[gpu.upper()].value 70 | 71 | with open(config_path, "w") as file: 72 | yaml.safe_dump(data, file, default_flow_style=False) 73 | 74 | return True 75 | 76 | except FileNotFoundError: 77 | logger.error(f"Baseten config.yaml file not found on path {config_path}") 78 | return False 79 | except yaml.YAMLError as e: 80 | logger.error(f"Error: Failed to parse the Baseten config file. {e}") 81 | return False 82 | except Exception as e: 83 | logger.error(f"An unexpected error occurred with Baseten config file update: {e}") 84 | return False 85 | 86 | 87 | def ensure_gpu() -> bool: 88 | """Enable GPU selection for FlowJudge model deployment. 89 | 90 | :return: True if successfully updated, False otherwise 91 | :rtype: bool 92 | """ 93 | if _has_gpu_key(): 94 | return _update_config() 95 | 96 | if is_interactive(): 97 | print("What GPU on Baseten should we deploy the FlowJudge model to?") 98 | print(" ➡️ H100") 99 | print(" ➡️ A10G: default") 100 | print("Would you like to switch your deployment to H100?") 101 | print("y/n?\n") 102 | 103 | else: 104 | logger.info("Non-interactive environment detected") 105 | print("What GPU on Baseten should we deploy the FlowJudge model to?") 106 | print(" ➡️ H100") 107 | print(" ➡️ A10G") 108 | print("Please set the environment variable to the appropriate value:") 109 | print("```") 110 | print('os.environ["BASETEN_GPU"] = "<>"') 111 | print("```") 112 | return False 113 | 114 | logger.info("Prompting for GPU in interactive environment") 115 | while True: 116 | upgrade: str = input() 117 | if not upgrade: 118 | logger.warning("Empty option entered") 119 | print("Input is empty, try again.") 120 | continue 121 | 122 | if upgrade.lower() not in ["yes", "y", "n", "no"]: 123 | logger.warning("Incorrect option selected") 124 | print("Incorrect option, select one from y/n") 125 | continue 126 | 127 | if upgrade.lower() in ["yes", "y"]: 128 | os.environ["BASETEN_GPU"] = ModelByGPU.H100_40GB.name 129 | return _update_config() 130 | 131 | if upgrade.lower() in ["n", "no"]: 132 | os.environ["BASETEN_GPU"] = ModelByGPU.A10G.name 133 | return _update_config() 134 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/management.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import aiohttp 4 | import requests 5 | import structlog 6 | 7 | logger = structlog.get_logger(__name__) 8 | 9 | 10 | def _get_management_base_url(model_id: str) -> str: 11 | """Get the base URL for the Management API. 12 | 13 | :param model_id: The ID of the deployed model. 14 | :returns: The URL for the management endpoints 15 | """ 16 | return f"https://api.baseten.co/v1/models/{model_id}/deployments/production" 17 | 18 | 19 | def sync_set_scale_down(scale_down_delay: int, api_key: str, model_id: str) -> bool: 20 | """Syncronous update to the cooldown period for the deployed model. 21 | 22 | :param scale_down_delay: The cooldown period in seconds. 
23 | :param api_key: The Baseten API key. 24 | :param model_id: The ID of the deployed model on Baseten. 25 | :returns bool: True if successful, False otherwise. 26 | """ 27 | try: 28 | resp = requests.patch( 29 | f"{_get_management_base_url(model_id)}/autoscaling_settings", 30 | headers={"Authorization": f"Api-Key {api_key}"}, 31 | json={"scale_down_delay": scale_down_delay}, 32 | ) 33 | 34 | re = resp.json() 35 | if "status" in re: 36 | return True if re["status"] in ["ACCEPTED", "UNCHANGED", "QUEUED"] else False 37 | except Exception as e: 38 | logger.error("Unexpected error occurred with Baseten scale down delay request" f" {e}") 39 | return False 40 | 41 | 42 | async def set_scale_down_delay(scale_down_delay: int, api_key: str, model_id: str) -> bool: 43 | """Dynamically updates the cooldown period for the deployed model. 44 | 45 | :param scale_down_delay: The cooldown period in seconds. 46 | :param api_key: The Baseten API key. 47 | :param model_id: The ID of the deployed model on Baseten. 48 | :returns bool: True if successful, False otherwise. 49 | """ 50 | url = f"{_get_management_base_url(model_id)}/autoscaling_settings" 51 | try: 52 | async with aiohttp.ClientSession() as session: 53 | async with session.patch( 54 | url=url, 55 | headers={"Authorization": f"Api-Key {api_key}"}, 56 | json={"scale_down_delay": scale_down_delay}, 57 | ) as response: 58 | if response.status != 200: 59 | logger.warning( 60 | "Unable to update Baseten scale down delay attribute." 61 | f"Request failed with status code {response.status}" 62 | ) 63 | return False 64 | 65 | resp = await response.json() 66 | if "status" in resp: 67 | return True if resp["status"] in ["ACCEPTED", "UNCHANGED", "QUEUED"] else False 68 | except aiohttp.ClientError as e: 69 | logger.warning("Network error with Baseten scale_down_delay" f" {e}") 70 | return False 71 | except Exception as e: 72 | logger.warning("Unexpected error occurred with Baseten scale down delay request" f" {e}") 73 | return False 74 | 75 | 76 | async def wake_deployment(model_id: str, api_key: str) -> bool: 77 | """Activates the Baseten model. 78 | 79 | :param model_id: The ID of the deployed model. 80 | :returns: True if success, False if failed. 81 | :rtype: bool 82 | """ 83 | url = f"https://model-{model_id}.api.baseten.co/production/wake" 84 | try: 85 | async with aiohttp.ClientSession() as session: 86 | async with session.post( 87 | url=url, headers={"Authorization": f"Api-Key {api_key}"}, json={} 88 | ) as response: 89 | if response.status != 202: 90 | logger.warning( 91 | "Unable to activate Baseten model." 92 | f"Request failed with status code {response.status}" 93 | ) 94 | return False 95 | 96 | return True 97 | except aiohttp.ClientError as e: 98 | logger.warning("Network error with Baseten model activation." f" {e}") 99 | return False 100 | except Exception as e: 101 | logger.error("Unexpected error occurred with Baseten model activation." f" {e}") 102 | return False 103 | 104 | 105 | async def get_production_deployment_status(model_id: str, api_key: str) -> str | None: 106 | """Get model production deployment_id by it's model_id. 107 | 108 | :param model_id: The ID of the deployed model. 109 | :returns: The deployment_id of the production model. 
110 | :rtype: str 111 | """ 112 | try: 113 | async with aiohttp.ClientSession() as session: 114 | async with session.get( 115 | _get_management_base_url(model_id), 116 | headers={"Authorization": f"Api-Key {api_key}"}, 117 | ) as response: 118 | if response.status != 200: 119 | logger.warning( 120 | "Unable to get model deployment details" 121 | f"Request failed with status {response.status}" 122 | ) 123 | return None 124 | 125 | re = await response.json() 126 | return re["status"] 127 | except (json.JSONDecodeError, KeyError, IndexError) as e: 128 | logger.warning("Unable to parse response for Model deployment info request." f" {e}") 129 | return None 130 | except aiohttp.ClientError as e: 131 | logger.warning("Network error with Baseten model deployment information." f" {e}") 132 | return None 133 | except Exception as e: 134 | logger.error( 135 | "Unexpected error occurred with Baseten model deployment info request." f" {e}" 136 | ) 137 | return None 138 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/token_bucket.py: -------------------------------------------------------------------------------- 1 | import time 2 | from dataclasses import dataclass, field 3 | 4 | 5 | @dataclass 6 | class TokenBucket: 7 | """Implements a token bucket algorithm for rate limiting. 8 | 9 | This class manages a token bucket with a specified capacity and fill rate, 10 | allowing for controlled consumption of tokens over time. 11 | 12 | Attributes: 13 | tokens (float): Current number of tokens in the bucket. 14 | fill_rate (float): Rate at which tokens are added to the bucket (tokens per second). 15 | capacity (float): Maximum number of tokens the bucket can hold. 16 | last_update (float): Timestamp of the last token update. 17 | 18 | Note: 19 | This implementation is not thread-safe. If used in a multi-threaded environment, 20 | external synchronization mechanisms should be applied. 21 | """ 22 | 23 | tokens: float 24 | fill_rate: float 25 | capacity: float 26 | last_update: float = field(default_factory=time.time) 27 | 28 | def consume(self, tokens: int = 1) -> bool: 29 | """Attempt to consume tokens from the bucket. 30 | 31 | Args: 32 | tokens (int): Number of tokens to consume. Defaults to 1. 33 | 34 | Returns: 35 | bool: True if tokens were successfully consumed, False otherwise. 36 | 37 | Note: 38 | This method updates the token count based on the time elapsed since 39 | the last update, then attempts to consume the requested number of tokens. 40 | """ 41 | now = time.time() 42 | self.tokens = min(self.capacity, self.tokens + self.fill_rate * (now - self.last_update)) 43 | self.last_update = now 44 | if self.tokens >= tokens: 45 | self.tokens -= tokens 46 | return True 47 | return False 48 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/util.py: -------------------------------------------------------------------------------- 1 | def is_interactive() -> bool: 2 | """Check if the current environment is interactive. 3 | 4 | :return: True if the environment is interactive, False otherwise. 
5 | :rtype: bool 6 | """ 7 | import sys 8 | 9 | return sys.__stdin__.isatty() 10 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/validation.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import hmac 3 | import logging 4 | import os 5 | from datetime import datetime, timezone 6 | 7 | from pydantic import BaseModel, ConfigDict, JsonValue 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | TIMESTAMP_TOLERANCE_SECONDS = 300 12 | 13 | 14 | class AsyncPredictResult(BaseModel): 15 | """Baseten completion response format.""" 16 | 17 | model_config = ConfigDict(protected_namespaces=(), extra="allow") 18 | 19 | request_id: str 20 | model_id: str 21 | deployment_id: str 22 | type: str 23 | time: datetime 24 | data: JsonValue 25 | errors: list[dict] 26 | 27 | 28 | def validate_baseten_signature(result, actual_signatures) -> bool: 29 | """Webhook signature validation from baseten.""" 30 | try: 31 | webhook_secret = os.environ["BASETEN_WEBHOOK_SECRET"] 32 | except Exception as e: 33 | logger.error( 34 | "Baseten webhook secret is not in the environment." 35 | "Unable to validate baseten signature for batched requests." 36 | "Set the BASETEN_WEBHOOK_SECRET env variable to proceed." 37 | f"{e}" 38 | ) 39 | return False 40 | 41 | async_predict_result = AsyncPredictResult(**result) 42 | 43 | if ( 44 | datetime.now(timezone.utc) - async_predict_result.time 45 | ).total_seconds() > TIMESTAMP_TOLERANCE_SECONDS: 46 | logger.error( 47 | f"Async predict result was received after {TIMESTAMP_TOLERANCE_SECONDS} seconds" 48 | "and is considered stale, Baseten signature was not validated." 49 | ) 50 | return False 51 | 52 | for actual_signature in actual_signatures.replace("v1=", "").split(","): 53 | expected_signature = hmac.digest( 54 | webhook_secret.encode("utf-8"), 55 | async_predict_result.model_dump_json().encode("utf-8"), 56 | hashlib.sha256, 57 | ).hex() 58 | 59 | if hmac.compare_digest(expected_signature, actual_signature): 60 | logger.info("Baseten signature is valid!") 61 | return True 62 | 63 | logger.error( 64 | "Baseten signature is not valid. Ensure your webhook secrets are properly configured." 65 | ) 66 | return False 67 | -------------------------------------------------------------------------------- /flow_judge/models/adapters/baseten/webhook.py: -------------------------------------------------------------------------------- 1 | import getpass 2 | import logging 3 | import os 4 | import re 5 | 6 | from .util import is_interactive 7 | 8 | logger: logging.Logger = logging.getLogger(__name__) 9 | 10 | _stored_secret_path: str = "~/.config/flow-judge/baseten_webhook_secret" 11 | _stored_skip_file: str = "~/.config/flow-judge/baseten_whsec_skipped" 12 | 13 | 14 | def _is_valid_secret(secret: str) -> bool: 15 | """Validate the Baseten webhook secret. 16 | 17 | :param secret: The secret to validate 18 | :return: True if the secret is valid, False otherwise 19 | """ 20 | logger.debug("Validating webhook secret") 21 | return bool(re.match(r"^whsec_[a-zA-Z0-9]{40}$", secret)) 22 | 23 | 24 | def _save_webhook_secret(secret: str) -> None: 25 | """Save the webhook secret to a file. 
26 | 27 | :param secret: The secret to save 28 | """ 29 | logger.debug(f"Saving webhook secret to {_stored_secret_path}") 30 | try: 31 | p: str = os.path.expanduser(_stored_secret_path) 32 | os.makedirs(os.path.dirname(p), exist_ok=True) 33 | with open(p, "w") as f: 34 | f.write(secret) 35 | logger.info("Webhook secret saved successfully") 36 | except OSError as e: 37 | logger.error(f"Failed to save webhook secret: {e}") 38 | 39 | 40 | def _get_stored_secret() -> str | None: 41 | """Retrieve the stored webhook secret. 42 | 43 | :return: The stored secret if available, None otherwise 44 | """ 45 | logger.debug(f"Attempting to retrieve stored webhook secret from {_stored_secret_path}") 46 | try: 47 | p: str = os.path.expanduser(_stored_secret_path) 48 | if not os.path.exists(p): 49 | logger.info("No stored webhook secret found") 50 | return None 51 | with open(p) as f: 52 | secret: str = f.read().strip() 53 | logger.info("Stored webhook secret retrieved successfully") 54 | return secret 55 | except OSError as e: 56 | logger.error(f"Failed to retrieve stored webhook secret: {e}") 57 | return None 58 | 59 | 60 | def _save_skip_file() -> None: 61 | """Creates a flag file indicating that user requested to skip secret webhook input.""" 62 | os.makedirs(os.path.expanduser(os.path.dirname(_stored_skip_file)), exist_ok=True) 63 | open(os.path.expanduser(_stored_skip_file), "a").close() 64 | 65 | 66 | def _handle_skip() -> bool: 67 | """Handles a possible user request to skip question about webhook secret. 68 | 69 | Checks for existence of BASETEN_SKIP_WEBHOOK_SECRET env variable, indicating that 70 | skip file should be saved. Skip file existence is a flag telling us to skip the 71 | question about webhook secret. 72 | 73 | :return: True if the webhook secret prompt should be skipped, False otherwise 74 | :rtype: bool 75 | """ 76 | if os.environ.get("BASETEN_SKIP_WEBHOOK_SECRET"): 77 | logger.info( 78 | "User requested to skip the webhook secret prompt, skipping and saving skip file. " 79 | "Remove ~/.config/flow-judge/baseten_whsec_skipped to restore the prompt." 80 | ) 81 | _save_skip_file() 82 | return True 83 | 84 | if os.path.exists(os.path.expanduser(_stored_skip_file)): 85 | logger.info( 86 | "Webhook secret not required and user skipped it before, skipping. " 87 | "Remove ~/.config/flow-judge/baseten_whsec_skipped to restore the prompt." 88 | ) 89 | return True 90 | 91 | return False 92 | 93 | 94 | def _handle_env_variable_input() -> bool: 95 | """Checks if webhook secret was provided in environment variable. 96 | 97 | :return: True if the webhook secret was provided, False otherwise 98 | :rtype: bool 99 | """ 100 | env_secret: str | None = os.getenv("BASETEN_WEBHOOK_SECRET") 101 | if not env_secret: 102 | return False 103 | 104 | logger.info("Found BASETEN_WEBHOOK_SECRET in environment variables") 105 | if _is_valid_secret(env_secret): 106 | logger.info("Environment variable contains a valid webhook secret") 107 | _save_webhook_secret(env_secret) 108 | else: 109 | logger.warning("Probably invalid BASETEN_WEBHOOK_SECRET in environment variable") 110 | return True 111 | 112 | 113 | def _handle_stored_secret() -> bool: 114 | """Checks if webhook secret was previously provided and stored in file. 
115 | 116 | :return: True if the webhook secret was provided, False otherwise 117 | :rtype: bool 118 | """ 119 | stored_secret: str | None = _get_stored_secret() 120 | if not stored_secret: 121 | return False 122 | 123 | logger.info("Found stored webhook secret") 124 | if _is_valid_secret(stored_secret): 125 | logger.info("Stored webhook secret is valid") 126 | os.environ["BASETEN_WEBHOOK_SECRET"] = stored_secret 127 | else: 128 | logger.warning("Stored webhook secret is probably invalid") 129 | return True 130 | 131 | 132 | def _prompt_interactively(optional: bool) -> str | None: 133 | """Asks user to enter webhook secret or skip if optional. 134 | 135 | :param optional: True if the input can be skipped by user 136 | :return: User-provided webhook secret or None if skipped 137 | :rtype: str | None 138 | """ 139 | while True: 140 | secret: str = getpass.getpass( 141 | "Baseten webhook secret (hidden): " 142 | if not optional 143 | else "Baseten webhook secret (hidden; leave empty to skip): " 144 | ) 145 | 146 | if optional and not secret: 147 | logger.info("User skipped optional webhook secret input") 148 | return None 149 | 150 | if not secret: 151 | logger.warning("Empty input received") 152 | print("Input is empty, please try again.") 153 | continue 154 | 155 | return secret 156 | 157 | 158 | def _print_general_prompt(optional: bool) -> None: 159 | """Prints general information/prompt about webhook secret input requirement. 160 | 161 | :param optional: Whether to print the information that the input is optional 162 | """ 163 | print( 164 | "To run Flow Judge remotely with Baseten and enable async execution, " 165 | "you need to create and configure a webhook secret.\n" 166 | "➡️ Creating a webhook secret: https://docs.baseten.co/invoke/async-secure\n\n" 167 | "The webhook secret is used to validate that the webhook responses originated " 168 | "from Baseten. " 169 | f"It will be stored in {_stored_secret_path} for later use.\n" 170 | "For your convenience, Baseten responses are forwarded to you using Flow AI proxy.\n" 171 | "Explore what this means and the alternatives here: " 172 | "https://github.com/flowaicom/flow-judge/\n" 173 | ) 174 | if optional: 175 | print("\033[1mOptional. This is only required if you plan async execution.\033[0m") 176 | 177 | 178 | def _print_noninteractive_prompt(optional: bool) -> None: 179 | """Prints a prompt for non interactive environments. 180 | 181 | :param optional: Whether to print the information that the input is optional 182 | """ 183 | logger.info("Non-interactive environment detected") 184 | print( 185 | "Set the Baseten webhook secret in the BASETEN_WEBHOOK_SECRET " 186 | "environment variable and run again:\n" 187 | 'os.environ["BASETEN_WEBHOOK_SECRET"] = "«your webhook secret»"\n' 188 | "The secret should start with 'whsec_' followed by 40 alphanumeric characters.\n" 189 | ) 190 | if optional: 191 | print( 192 | "If you don't want to see this message in the future set the " 193 | "BASETEN_SKIP_WEBHOOK_SECRET environment variable to any non-empty value:\n" 194 | 'os.environ["BASETEN_SKIP_WEBHOOK_SECRET"]="true"\n' 195 | ) 196 | 197 | 198 | def ensure_baseten_webhook_secret(optional: bool = False) -> bool: 199 | """Ensure a Baseten webhook secret is available. 200 | 201 | This function checks for a webhook secret in the following order: 202 | 1. Environment variable 203 | 2. Stored secret file 204 | 3. 
User input (in interactive mode) 205 | 206 | :param optional: Whether allow the user to omit the webhook secret input 207 | :return: True if a secret is available (valid or not), False if no secret could be obtained 208 | """ 209 | # If input optional, check if user requested to skip the prompt 210 | if optional and _handle_skip(): 211 | return False 212 | 213 | logger.info("Ensuring Baseten webhook secret") 214 | 215 | # Check environment variable 216 | if _handle_env_variable_input(): 217 | return True 218 | 219 | # Check stored secret 220 | if _handle_stored_secret(): 221 | return True 222 | 223 | logger.info("No existing webhook secret found, prompting user") 224 | _print_general_prompt(optional) 225 | 226 | # In non-interactive environments we return early because their inputs (env vars) 227 | # are processed at the very beginning of this function. 228 | if not is_interactive(): 229 | _print_noninteractive_prompt(optional) 230 | return False 231 | 232 | logger.info("Prompting user for webhook secret interactively") 233 | secret = _prompt_interactively(optional) 234 | 235 | # Secret not provided which means user asked to skip the prompt 236 | if secret is None: 237 | print("(skipped)") 238 | _save_skip_file() 239 | return False 240 | 241 | if not _is_valid_secret(secret): 242 | logger.warning("Invalid webhook secret provided") 243 | print( 244 | "Warning: The provided webhook secret is probably invalid. " 245 | "It should start with 'whsec_' followed by 40 alphanumeric characters. " 246 | "Proceeding anyway." 247 | ) 248 | 249 | _save_webhook_secret(secret) 250 | os.environ["BASETEN_WEBHOOK_SECRET"] = secret 251 | 252 | return True 253 | -------------------------------------------------------------------------------- /flow_judge/models/baseten.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections.abc import Coroutine 3 | from typing import Any 4 | 5 | from .adapters.baseten.adapter import AsyncBasetenAPIAdapter, BaseAPIAdapter, BasetenAPIAdapter 6 | from .adapters.baseten.data_io import BatchResult 7 | from .adapters.baseten.deploy import ensure_model_deployment, get_deployed_model_id 8 | from .adapters.baseten.webhook import ensure_baseten_webhook_secret 9 | from .common import ( 10 | AsyncBaseFlowJudgeModel, 11 | BaseFlowJudgeModel, 12 | ModelConfig, 13 | ModelType, 14 | VllmGenerationParams, 15 | ) 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class BasetenModelConfig(ModelConfig): 21 | """Model config for the Baseten model class.""" 22 | 23 | def __init__( 24 | self, 25 | generation_params: VllmGenerationParams, 26 | exec_async: bool = False, 27 | webhook_proxy_url: str | None = None, 28 | async_batch_size: int = 128, 29 | **kwargs: Any, 30 | ): 31 | """Initialize the Baseten model config. 32 | 33 | :param generation_params: VllmGenerationParams for text generation. 34 | :param exec_async: Whether to use async execution. 35 | :param webhook_proxy_url: Webhook URL for Baseten async execution. 36 | :param async_batch_size: Batch size for concurrent requests in async mode. 37 | :raises ValueError: If any input parameters are invalid. 
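        Example (sketch; assumes a Flow Judge model is already deployed on Baseten
        and that its id is passed explicitly via ``_model_id`` — the id and proxy
        URL shown here are hypothetical)::

            from flow_judge.models.common import VllmGenerationParams

            params = VllmGenerationParams(temperature=0.1, top_p=0.95, max_new_tokens=1000)
            config = BasetenModelConfig(
                generation_params=params,
                exec_async=True,
                webhook_proxy_url="https://proxy.example.com/webhook",  # hypothetical proxy URL
                async_batch_size=64,
                _model_id="abc12345",  # hypothetical Baseten model id
            )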
38 | """ 39 | model_id = kwargs.pop("_model_id", None) 40 | if model_id is None: 41 | model_id = get_deployed_model_id() 42 | if model_id is None: 43 | raise ValueError("Unable to retrieve Baseten's deployed model id.") 44 | 45 | model_type = ModelType.BASETEN_VLLM_ASYNC if exec_async else ModelType.BASETEN_VLLM 46 | 47 | if not isinstance(generation_params, VllmGenerationParams): 48 | raise ValueError("generation_params must be an instance of VllmGenerationParams") 49 | if async_batch_size <= 0: 50 | raise ValueError(f"async_batch_size must be > 0, got {async_batch_size}") 51 | if exec_async and webhook_proxy_url is None: 52 | raise ValueError("webhook_proxy_url is required for async execution") 53 | 54 | super().__init__(model_id, model_type, generation_params, **kwargs) 55 | self.webhook_proxy_url = webhook_proxy_url 56 | self.exec_async = exec_async 57 | self.async_batch_size = async_batch_size 58 | 59 | 60 | class Baseten(BaseFlowJudgeModel, AsyncBaseFlowJudgeModel): 61 | """Combined FlowJudge Model class for Baseten sync and webhook async operations.""" 62 | 63 | def __init__( 64 | self, 65 | api_adapter: BaseAPIAdapter | None = None, 66 | webhook_proxy_url: str | None = None, 67 | exec_async: bool = False, 68 | async_batch_size: int = 128, 69 | generation_params: dict[str, Any] | None = None, 70 | **kwargs: Any, 71 | ): 72 | """Initialize the Baseten Model class. 73 | 74 | :param api_adapter: API handling class for Baseten requests. 75 | :param webhook_proxy_url: The webhook url for the proxy when exec_async is True. 76 | :param exec_async: Whether to use async webhook execution. 77 | :param async_batch_size: Batch size for concurrent requests to Baseten in async. 78 | :param generation_params: Dictionary of parameters for text generation. 79 | :raises BasetenError: If Baseten deployment or model ID retrieval fails. 80 | :raises ValueError: If input parameters are invalid. 81 | """ 82 | if exec_async and not webhook_proxy_url: 83 | raise ValueError("webhook_proxy_url is required for async Baseten execution.") 84 | 85 | if async_batch_size < 1: 86 | raise ValueError("async_batch_size must be greater than 0.") 87 | 88 | # Check for a custom Baseten model ID provided by the user 89 | # This allows for using a specific model deployment instead of the default 90 | model_id = kwargs.pop("_model_id", None) 91 | if model_id is None: 92 | # If no custom ID, attempt to retrieve the default deployed Flow Judge model ID 93 | # This covers the case where the model is already deployed but not specified 94 | model_id = get_deployed_model_id() 95 | if model_id is None: 96 | # No deployed model found, so try to deploy the default Flow Judge model 97 | # This step handles first-time usage or cases where the model was removed 98 | if not ensure_model_deployment(): 99 | # Deployment failed, which could be due to API key issues, 100 | # network problems, or Baseten service unavailability 101 | raise BasetenError( 102 | status_code=1, 103 | message=( 104 | "Baseten deployment is not available." 105 | " This could be due to API key issues, " 106 | "network problems, or Baseten service " 107 | "unavailability. Please check your " 108 | "API key, network connection, and Baseten service status." 
109 | ), 110 | ) 111 | # Deployment succeeded, so attempt to retrieve the model ID again 112 | # This should now succeed unless there's an unexpected issue 113 | model_id = get_deployed_model_id() 114 | if model_id is None: 115 | # If we still can't get the model ID, it indicates a deeper problem 116 | # This could be due to a bug in the deployment process or Baseten API changes 117 | raise BasetenError( 118 | status_code=2, 119 | message=( 120 | "Unable to retrieve Baseten's deployed model id. " 121 | "Please ensure the model is deployed or provide a custom '_model_id'." 122 | ), 123 | ) 124 | 125 | if exec_async and not ensure_baseten_webhook_secret(): 126 | raise BasetenError( 127 | status_code=4, 128 | message=( 129 | "Unable to retrieve Baseten's webhook secret. " 130 | "Please ensure the webhook secret is provided in " 131 | "BASETEN_WEBHOOK_SECRET environment variable." 132 | ), 133 | ) 134 | 135 | if api_adapter is not None and not isinstance( 136 | api_adapter, (BasetenAPIAdapter, AsyncBasetenAPIAdapter) 137 | ): 138 | raise BasetenError( 139 | status_code=3, 140 | message="Incompatible API adapter. Use BasetenAPIAdapter or AsyncBasetenAPIAdapter", 141 | ) 142 | 143 | self.api_adapter = api_adapter or ( 144 | AsyncBasetenAPIAdapter(model_id, webhook_proxy_url, async_batch_size) 145 | if exec_async 146 | else BasetenAPIAdapter(model_id) 147 | ) 148 | 149 | generation_params = VllmGenerationParams(**(generation_params or {})) 150 | config = BasetenModelConfig( 151 | generation_params=generation_params, 152 | exec_async=exec_async, 153 | webhook_proxy_url=webhook_proxy_url, 154 | async_batch_size=async_batch_size, 155 | _model_id=model_id, 156 | ) 157 | self.config = config 158 | 159 | super().__init__(model_id, config.model_type, config.generation_params, **kwargs) 160 | 161 | logger.info("Successfully initialized Baseten!") 162 | 163 | def _format_conversation(self, prompt: str) -> list[dict[str, Any]]: 164 | return [{"role": "user", "content": prompt.strip()}] 165 | 166 | def _generate(self, prompt: str) -> str: 167 | logger.info("Initiating single Baseten request") 168 | 169 | conversation = self._format_conversation(prompt) 170 | return self.api_adapter._fetch_response(conversation) 171 | 172 | def _batch_generate( 173 | self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any 174 | ) -> list[str]: 175 | logger.info("Initiating batched Baseten requests") 176 | 177 | conversations = [self._format_conversation(prompt) for prompt in prompts] 178 | return self.api_adapter._fetch_batched_response(conversations) 179 | 180 | async def _async_generate(self, prompt: str) -> str: 181 | if self.config.exec_async: 182 | return await self.api_adapter._async_fetch_response(prompt.strip()) 183 | else: 184 | logger.error("Attempting to run an async request with a synchronous API adapter") 185 | 186 | async def _async_batch_generate( 187 | self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any 188 | ) -> Coroutine[Any, Any, BatchResult]: 189 | if self.config.exec_async: 190 | cleaned_prompts = [prompt.strip() for prompt in prompts] 191 | return await self.api_adapter._async_fetch_batched_response(cleaned_prompts) 192 | else: 193 | logger.error("Attempting to run an async request with a synchronous API adapter") 194 | 195 | 196 | class BasetenError(Exception): 197 | """Custom exception for Baseten-related errors.""" 198 | 199 | def __init__(self, status_code: int, message: str): 200 | """Initialize with a status code and message.""" 201 | self.status_code = status_code 202 | 
self.message = message 203 | super().__init__(self.message) 204 | -------------------------------------------------------------------------------- /flow_judge/models/common.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from enum import Enum 3 | from typing import Any 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | from .adapters.base import BaseAPIAdapter 8 | 9 | 10 | class BaseFlowJudgeModel(ABC): 11 | """Base class for all FlowJudge models.""" 12 | 13 | def __init__( 14 | self, model_id: str, model_type: str, generation_params: dict[str, Any], **kwargs: Any 15 | ) -> None: 16 | """Initialize the base FlowJudge model.""" 17 | self.metadata: dict[str, Any] = { 18 | "model_id": model_id, 19 | "model_type": model_type, 20 | "generation_params": generation_params, 21 | "kwargs": kwargs, 22 | } 23 | 24 | @abstractmethod 25 | def _generate(self, prompt: str) -> str: 26 | """Generate a response based on the given prompt.""" 27 | pass 28 | 29 | @abstractmethod 30 | def _batch_generate( 31 | self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any 32 | ) -> list[str]: 33 | """Generate responses for multiple prompts.""" 34 | pass 35 | 36 | 37 | class AsyncBaseFlowJudgeModel(ABC): 38 | """Base class for asynchronous FlowJudge models.""" 39 | 40 | def __init__( 41 | self, model_id: str, model_type: str, generation_params: dict[str, Any], **kwargs: Any 42 | ) -> None: 43 | """Initialize the base asynchronous FlowJudge model.""" 44 | self.metadata: dict[str, Any] = { 45 | "model_id": model_id, 46 | "model_type": model_type, 47 | "generation_params": generation_params, 48 | "kwargs": kwargs, 49 | } 50 | 51 | @abstractmethod 52 | async def _async_generate(self, prompt: str) -> str: 53 | """Generate a response based on the given prompt asynchronously.""" 54 | pass 55 | 56 | @abstractmethod 57 | async def _async_batch_generate( 58 | self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any 59 | ) -> list[str]: 60 | """Generate responses for multiple prompts asynchronously.""" 61 | pass 62 | 63 | 64 | class FlowJudgeRemoteModel(BaseFlowJudgeModel): 65 | """Flow judge model class for remote hosting.""" 66 | 67 | def __init__( 68 | self, 69 | model_id: str, 70 | model_type: str, 71 | generation_params: dict[str, Any], 72 | api_adapter: BaseAPIAdapter, 73 | **remote_kwargs: Any, 74 | ): 75 | """Initialize the FlowJudge remote model class. 76 | 77 | :param model_id: The ID of the model. 78 | :param model_type: Type of the model based on ModelType. 79 | :param generation_params: Relevant generation params for the model type. 80 | :param remote_kwargs: Keyword arguments to initialize the parameters. 81 | """ 82 | super().__init__(model_id, model_type, generation_params, **remote_kwargs) 83 | 84 | if not isinstance(api_adapter, BaseAPIAdapter): 85 | raise ValueError("Invalid Adapter type. 
Use BaseAPIAdapter.") 86 | 87 | self.api_adapter = api_adapter 88 | 89 | def generate(self, prompt: str) -> str: 90 | """Single generation request.""" 91 | conversation = [{"role": "user", "content": prompt.strip()}] 92 | return self.api_adapter.fetch_response(conversation) 93 | 94 | def batch_generate(self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any) -> list[str]: 95 | """Batched generation request.""" 96 | conversations = [[{"role": "user", "content": prompt.strip()}] for prompt in prompts] 97 | return self.api_adapter.fetch_batched_response(conversations) 98 | 99 | 100 | class GenerationParams(BaseModel): 101 | """Configuration parameters for text generation.""" 102 | 103 | temperature: float = Field(default=0.1, description="Sampling temperature") 104 | top_p: float = Field(default=0.95, description="Top-p sampling parameter") 105 | max_new_tokens: int = Field( 106 | default=1000, description="Maximum number of new tokens to generate" 107 | ) 108 | do_sample: bool = Field(default=True, description="Whether to use sampling for generation") 109 | 110 | 111 | class VllmGenerationParams(GenerationParams): 112 | """Configuration parameters specific to VLLM text generation.""" 113 | 114 | max_tokens: int | None = None 115 | stop_token_ids: list[int] = [32007, 32001, 32000] 116 | 117 | def __init__(self, **data): 118 | """Initialize VllmGenerationParams with given data. 119 | 120 | :param data: Keyword arguments to initialize the parameters. 121 | """ 122 | super().__init__(**data) 123 | self.max_tokens = self.max_new_tokens 124 | del self.max_new_tokens 125 | del self.do_sample 126 | 127 | 128 | class ModelType(Enum): 129 | """Enum for the type of model.""" 130 | 131 | TRANSFORMERS = "transformers" 132 | VLLM = "vllm" 133 | VLLM_ASYNC = "vllm_async" 134 | LLAMAFILE = "llamafile" 135 | BASETEN_VLLM = "baseten_vllm" 136 | BASETEN_VLLM_ASYNC = "baseten_vllm_async" 137 | 138 | 139 | class Engine(Enum): 140 | """Enum for the type of engine used for text generation.""" 141 | 142 | VLLM: str = "vllm" 143 | VLLM_ASYNC: str = "vllm_async" 144 | HF: str = "hf" # HF stands for Hugging Face (Transformers) 145 | LLAMAFILE: str = "llamafile" 146 | 147 | 148 | class ModelConfig: 149 | """Base configuration for a model.""" 150 | 151 | def __init__( 152 | self, 153 | model_id: str, 154 | model_type: ModelType, 155 | generation_params: dict[str, Any], 156 | **kwargs: Any, 157 | ) -> None: 158 | """Initialize ModelConfig with model details and generation parameters. 159 | 160 | :param model_id: Identifier for the model. 161 | :param model_type: Type of the model. 162 | :param generation_params: Parameters for text generation. 163 | :param kwargs: Additional keyword arguments. 
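        Example (sketch; the model id is the default Flow Judge checkpoint used
        elsewhere in this package, and the generation parameters are illustrative)::

            from flow_judge.models.common import ModelConfig, ModelType

            config = ModelConfig(
                model_id="flowaicom/Flow-Judge-v0.1",
                model_type=ModelType.VLLM,
                generation_params={"temperature": 0.1, "top_p": 0.95, "max_new_tokens": 1000},
            )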
164 | """ 165 | self.model_id: str = model_id 166 | self.model_type: ModelType = model_type 167 | self.generation_params: dict[str, Any] = generation_params 168 | self.kwargs: dict[str, Any] = kwargs 169 | -------------------------------------------------------------------------------- /flow_judge/models/huggingface.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import warnings 4 | from typing import Any 5 | 6 | from flow_judge.models.common import BaseFlowJudgeModel, GenerationParams, ModelConfig, ModelType 7 | 8 | try: 9 | import torch 10 | from huggingface_hub import snapshot_download 11 | from transformers import AutoModelForCausalLM, AutoTokenizer 12 | 13 | HF_AVAILABLE = True 14 | except ImportError: 15 | HF_AVAILABLE = False 16 | 17 | from tqdm import tqdm 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class HfConfig(ModelConfig): 23 | """Configuration class for Hugging Face models.""" 24 | 25 | _DEFAULT_MODEL_ID = "flowaicom/Flow-Judge-v0.1" 26 | 27 | def __init__( 28 | self, 29 | generation_params: GenerationParams, 30 | device_map: str = "auto", 31 | torch_dtype: str = "bfloat16", 32 | flash_attn: bool = True, 33 | **kwargs: Any, 34 | ): 35 | """Initialize HfConfig with model details and Hugging Face specific parameters. 36 | 37 | :param generation_params: Parameters for text generation. 38 | :param device_map: Device mapping strategy. 39 | :param torch_dtype: PyTorch data type for the model. 40 | :param flash_attn: Whether to use flash attention. 41 | :param kwargs: Additional keyword arguments. 42 | """ 43 | model_id = kwargs.pop("_model_id", self._DEFAULT_MODEL_ID) 44 | super().__init__(model_id, ModelType.TRANSFORMERS, generation_params.model_dump(), **kwargs) 45 | self.device_map = device_map 46 | self.torch_dtype = torch_dtype 47 | self.flash_attn = flash_attn 48 | self.kwargs = kwargs 49 | 50 | 51 | class Hf(BaseFlowJudgeModel): 52 | """FlowJudge model class for Hugging Face Transformers.""" 53 | 54 | _DEFAULT_MODEL_ID = "flowaicom/Flow-Judge-v0.1" 55 | 56 | def __init__( 57 | self, 58 | generation_params: dict[str, Any] | None = None, 59 | flash_attn: bool = True, 60 | **kwargs: Any, 61 | ): 62 | """Initialize the FlowJudge Hugging Face Transformers model. 63 | 64 | :param generation_params: Dictionary of parameters for text generation. 65 | :param flash_attn: Whether to use flash attention. 66 | :param kwargs: Additional keyword arguments, including: 67 | - _model_id: Identifier for the model. If None, uses the default model. 68 | // ... other kwargs ... 69 | """ 70 | if not HF_AVAILABLE: 71 | raise HfError( 72 | status_code=1, 73 | message="The required Hugging Face packages are not installed. " 74 | "Please install them by adding 'hf' to your extras:\n" 75 | "pip install flow-judge[hf]", 76 | ) 77 | 78 | model_id = kwargs.pop("_model_id", self._DEFAULT_MODEL_ID) 79 | 80 | if model_id != self._DEFAULT_MODEL_ID: 81 | warnings.warn( 82 | f"The model '{model_id}' is not officially supported. " 83 | f"This library is designed for the '{self._DEFAULT_MODEL_ID}' model. " 84 | "Using other models may lead to unexpected behavior, and we do not handle " 85 | "GitHub issues for unsupported models. 
Proceed with caution.", 86 | UserWarning, 87 | stacklevel=2, 88 | ) 89 | 90 | generation_params = GenerationParams(**(generation_params or {})) 91 | 92 | config = HfConfig( 93 | generation_params=generation_params, flash_attn=flash_attn, _model_id=model_id, **kwargs 94 | ) 95 | 96 | super().__init__(model_id, "transformers", config.generation_params, **kwargs) 97 | 98 | try: 99 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" 100 | 101 | logger.info( 102 | "Downloading the model from Hugging Face Hub using hf-transfer" 103 | "for faster downloads...", 104 | ) 105 | snapshot_download(repo_id=model_id) 106 | 107 | model_kwargs = { 108 | "device_map": config.device_map, 109 | "torch_dtype": getattr(torch, config.torch_dtype), 110 | } 111 | if config.flash_attn: 112 | model_kwargs["attn_implementation"] = "flash_attention_2" 113 | 114 | # Include any additional kwargs that might be relevant for model initialization 115 | for key, value in config.kwargs.items(): 116 | if ( 117 | key not in model_kwargs 118 | and key in AutoModelForCausalLM.from_pretrained.__code__.co_varnames 119 | ): 120 | model_kwargs[key] = value 121 | 122 | self.model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) 123 | self.tokenizer = AutoTokenizer.from_pretrained(model_id) 124 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 125 | self.generation_params = generation_params.model_dump() 126 | self.config = config 127 | 128 | if self.device == "cpu": 129 | logger.warning("Running Hf on CPU may result in longer inference times.") 130 | 131 | self.batch_size = 1 # Default to 1, will be updated in batch_generate 132 | 133 | except Exception as e: 134 | raise HfError( 135 | status_code=2, 136 | message=f"An error occurred while initializing the Hugging Face model: {str(e)}\n" 137 | "Please make sure you have installed all required dependencies by adding 'hf' " 138 | "to your extras:\npip install flow-judge[hf]", 139 | ) from e 140 | 141 | def _determine_batch_size(self, prompts: list[str]) -> int: 142 | """Determine an appropriate batch size based on available GPU memory and eval_inputs. 143 | 144 | This method attempts to find the largest batch size that can be processed without 145 | running out of GPU memory. 146 | 147 | :param prompts: List of input prompts to be processed. 148 | :return: The determined optimal batch size. 149 | """ 150 | if self.device == "cpu": 151 | return 1 # Default to 1 for CPU 152 | 153 | batch_size = 1 154 | max_length = self.generation_params.get("max_model_len", 8192) 155 | max_new_tokens = self.generation_params.get("max_new_tokens", 1024) 156 | 157 | while True: 158 | try: 159 | # Prepare a batch of inputs using the longest input 160 | longest_input = max(prompts, key=lambda x: len(self.tokenizer.encode(x))) 161 | # Check if the longest input exceeds max_length 162 | input_length = len(self.tokenizer.encode(longest_input)) 163 | if input_length > max_length: 164 | logger.warning( 165 | f"Input length {input_length} exceeds max_length {max_length}. " 166 | f"Truncated inputs can result in suboptimal performance." 
167 | ) 168 | 169 | inputs = [longest_input] * batch_size 170 | encoded_inputs = self.tokenizer( 171 | inputs, 172 | return_tensors="pt", 173 | padding=True, 174 | truncation=True, 175 | max_length=max_length, 176 | ).to(self.device) 177 | 178 | # Simulate generation 179 | with torch.no_grad(): 180 | _ = self.model.generate( 181 | **encoded_inputs, max_new_tokens=max_new_tokens, do_sample=False 182 | ) 183 | 184 | # If successful, double the batch size and try again 185 | batch_size *= 2 186 | torch.cuda.empty_cache() 187 | del encoded_inputs 188 | except RuntimeError as e: 189 | if "out of memory" in str(e).lower(): 190 | optimal_batch_size = max(1, batch_size // 2) 191 | logger.info(f"Automatically determined batch size: {optimal_batch_size}") 192 | torch.cuda.empty_cache() 193 | return optimal_batch_size 194 | else: 195 | raise 196 | 197 | def _prepare_generation_kwargs(self, **kwargs: Any) -> dict[str, Any]: 198 | """Combines generation params, passed kwargs, and relevant config kwargs. 199 | 200 | :param kwargs: Additional keyword arguments for generation. 201 | :return: A dictionary of prepared generation kwargs. 202 | """ 203 | generation_kwargs = {**self.generation_params, **kwargs} 204 | for key, value in self.config.kwargs.items(): 205 | if key not in generation_kwargs and key in self.model.generate.__code__.co_varnames: 206 | generation_kwargs[key] = value 207 | return generation_kwargs 208 | 209 | def _generate(self, prompt: str) -> str: 210 | """Generate a response using the FlowJudge Hugging Face Transformers model.""" 211 | chat_prompt = self.tokenizer.apply_chat_template( 212 | [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True 213 | ) 214 | inputs = self.tokenizer(chat_prompt, return_tensors="pt").to(self.device) 215 | input_length = inputs.input_ids.shape[1] 216 | 217 | generation_kwargs = self._prepare_generation_kwargs() 218 | outputs = self.model.generate(**inputs, **generation_kwargs) 219 | 220 | generated_text = self.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True) 221 | return generated_text.strip() 222 | 223 | def _batch_generate( 224 | self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any 225 | ) -> list[str]: 226 | """Generate responses for multiple prompts using batching.""" 227 | all_results = [] 228 | 229 | self.batch_size = self._determine_batch_size(prompts) 230 | 231 | batches = [ 232 | prompts[i : i + self.batch_size] for i in range(0, len(prompts), self.batch_size) 233 | ] 234 | 235 | generation_kwargs = self._prepare_generation_kwargs(**kwargs) 236 | 237 | for batch in tqdm(batches, disable=not use_tqdm, desc="Processing batches"): 238 | chat_prompts = [ 239 | self.tokenizer.apply_chat_template( 240 | [{"role": "user", "content": prompt}], 241 | tokenize=False, 242 | add_generation_prompt=True, 243 | ) 244 | for prompt in batch 245 | ] 246 | 247 | inputs = self.tokenizer( 248 | chat_prompts, return_tensors="pt", padding=True, truncation=True 249 | ).to(self.device) 250 | input_tok_lens = [len(input) for input in inputs["input_ids"]] 251 | 252 | outputs = self.model.generate(**inputs, **generation_kwargs) 253 | 254 | batch_results = [] 255 | for output, input_tok_len in zip(outputs, input_tok_lens, strict=True): 256 | result = self.tokenizer.decode(output[input_tok_len:], skip_special_tokens=False) 257 | batch_results.append(result.strip()) 258 | 259 | all_results.extend(batch_results) 260 | 261 | return all_results 262 | 263 | 264 | class HfError(Exception): 265 | """Custom exception for 
Hugging Face-related errors.""" 266 | 267 | def __init__(self, status_code: int, message: str): 268 | """Initialize an HfError with a status code and message.""" 269 | self.status_code = status_code 270 | self.message = message 271 | super().__init__(self.message) 272 | -------------------------------------------------------------------------------- /flow_judge/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flowaicom/flow-judge/e56d199db4d79f184dac1e9ab2da83992acda14d/flow_judge/utils/__init__.py -------------------------------------------------------------------------------- /flow_judge/utils/prompt_formatter.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from flow_judge.metrics.metric import RubricItem 4 | 5 | USER_PROMPT_TEMPLATE = """# GOAL 6 | Your job is to evaluate a task carried out by an AI system powered by a large \ 7 | language model. 8 | 9 | You will be provided with the inputs and output of the task, as well as the evaluation criteria \ 10 | and scoring rubric. Your task is to evaluate the output of the AI system based on the evaluation \ 11 | criteria and scoring rubric provided. 12 | 13 | # INPUT 14 | Below are the inputs required for performing the task: 15 | 16 | {INPUTS} 17 | 18 | 19 | # OUTPUT 20 | Below is the output of the task: 21 | 22 | {OUTPUT} 23 | 24 | 25 | # EVALUATION CRITERIA AND SCORING RUBRIC 26 | Here are the evaluation criteria and the rubric that you need to use for evaluating the task: 27 | 28 | {EVALUATION_CRITERIA} 29 | 30 | 31 | 32 | {RUBRIC} 33 | 34 | 35 | # INSTRUCTIONS FOR THE EVALUATION 36 | 1. Understand the task and criteria: Familiarize yourself with the task to be evaluated. \ 37 | Review the evaluation criteria and scoring rubric to understand the different levels of \ 38 | performance and the descriptions for each score. 39 | 2. Review the inputs and output: Look at the inputs provided for the task. Examine the output \ 40 | generated from completing the task. 41 | 3. Compare output to score descriptions: Compare the output against the criteria and score \ 42 | descriptions in the scoring rubric. For each criterion,decide which description best matches the \ 43 | output. 44 | 4. After comparing the output to the score descriptions, pay attention to the small details that \ 45 | might impact the final score that you assign. Sometimes a small difference can dictate the final \ 46 | score. 47 | 5. Write verbal feedback justifying your evaluation that includes a detailed rationale, referring \ 48 | to specific aspects of the output and comparing them to the rubric. 49 | 6. Assign a final score based on the scoring rubric. 50 | 51 | ## FORMAT FOR THE EVALUATION 52 | - Write the verbal feedback inside tags without any additional surrounding text. 53 | - Write the numeric score inside tags, without any additional surrounding text and always \ 54 | after the feedback. 55 | 56 | Please accurately evaluate the task. Strictly adhere to the evaluation criteria and rubric.""" 57 | 58 | 59 | USER_PROMPT_NO_INPUTS_TEMPLATE = """# GOAL 60 | Your job is to evaluate a task carried out by an AI system powered by a large language model. 61 | 62 | You will be provided the output of the task, as well as the evaluation criteria \ 63 | and scoring rubric. Your task is to evaluate the output of the AI system based on the evaluation \ 64 | criteria and scoring rubric provided. 
65 | 66 | # OUTPUT 67 | Below is the output of the task: 68 | 69 | {OUTPUT} 70 | 71 | 72 | # EVALUATION CRITERIA AND SCORING RUBRIC 73 | Here are the evaluation criteria and the rubric that you need to use for evaluating the task: 74 | 75 | {EVALUATION_CRITERIA} 76 | 77 | 78 | 79 | {RUBRIC} 80 | 81 | 82 | # INSTRUCTIONS FOR THE EVALUATION 83 | 1. Understand the task and criteria: Familiarize yourself with the task to be evaluated. \ 84 | Review the evaluation criteria and scoring rubric to understand the different levels of \ 85 | performance and the descriptions for each score. 86 | 2. Review the output: Examine the output generated from completing the task. 87 | 3. Compare output to score descriptions: Compare the output against the criteria and score \ 88 | descriptions in the scoring rubric. For each criterion,decide which description best matches the \ 89 | output. 90 | 4. After comparing the output to the score descriptions, pay attention to the small details that \ 91 | might impact the final score that you assign. Sometimes a small difference can dictate the final \ 92 | score. 93 | 5. Write verbal feedback justifying your evaluation that includes a detailed rationale, referring \ 94 | to specific aspects of the output and comparing them to the rubric. 95 | 6. Assign a final score based on the scoring rubric. 96 | 97 | ## FORMAT FOR THE EVALUATION 98 | - Write the verbal feedback inside tags without any additional surrounding text. 99 | - Write the numeric score inside tags, without any additional surrounding text and always \ 100 | after the feedback. 101 | 102 | Please accurately evaluate the task. Strictly adhere to the evaluation criteria and rubric.""" 103 | 104 | 105 | def format_vars(variables: list[dict[str, str]]) -> str: 106 | """Format variables for the prompt.""" 107 | var_strs = [] 108 | for var in variables: 109 | for key, value in var.items(): 110 | var_tag = key.lower().replace(" ", "_") 111 | var_strs.append(f"<{var_tag}>\n{value}\n") 112 | return "\n".join(var_strs) 113 | 114 | 115 | def format_rubric(rubric: list[RubricItem]) -> str: 116 | """Format the rubric for the prompt.""" 117 | rubric_strs = [] 118 | 119 | # Sort rubric items by score, lowest to highest 120 | sorted_rubric = sorted(rubric, key=lambda x: x.score) 121 | 122 | for item in sorted_rubric: 123 | rubric_strs.append(f"- Score {item.score}: {item.description}") 124 | return "\n".join(rubric_strs) 125 | 126 | 127 | def format_user_prompt(prompt_variables: dict[str, Any]) -> str: 128 | """Format the user prompt based on provided variables.""" 129 | if prompt_variables["INPUTS"]: 130 | return USER_PROMPT_TEMPLATE.format(**prompt_variables) 131 | else: 132 | return USER_PROMPT_NO_INPUTS_TEMPLATE.format(**prompt_variables) 133 | -------------------------------------------------------------------------------- /flow_judge/utils/result_writer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | from datetime import datetime, timezone 5 | from enum import Enum 6 | from pathlib import Path 7 | from typing import Any 8 | 9 | from pydantic import BaseModel 10 | 11 | import flow_judge 12 | from flow_judge.eval_data_types import EvalInput, EvalOutput 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def write_results_to_disk( 18 | eval_inputs: list[EvalInput], 19 | eval_outputs: list[EvalOutput], 20 | model_metadata: dict[str, Any], 21 | metric_name: str, 22 | output_dir: str | Path, 23 | append: bool = False, 24 | ) 
-> None: 25 | """Write evaluation results, inputs, and metadata to separate JSONL files. 26 | 27 | This function processes evaluation data and writes it to disk in a structured format. 28 | It creates separate files for metadata and results, organizing them in directories 29 | based on the metric name and model ID. 30 | 31 | Args: 32 | eval_inputs: List of evaluation inputs. 33 | eval_outputs: List of evaluation outputs. 34 | model_metadata: Dictionary containing model metadata. 35 | metric_name: Name of the metric being evaluated. 36 | output_dir: Directory to write output files. 37 | append: If True, append results to existing file. If False, overwrite. Default is False. 38 | 39 | Raises: 40 | ValueError: If inputs are invalid, empty, or lists have different lengths. 41 | KeyError: If required keys are missing from model_metadata. 42 | OSError: If there are file system related errors during writing. 43 | 44 | Note: 45 | - Ensures eval_inputs and eval_outputs have the same length. 46 | - Creates necessary directories if they don't exist. 47 | - Handles special characters in metric_name and model_id for file naming. 48 | - Overwrites existing files with the same name without warning. 49 | """ 50 | _validate_inputs(eval_inputs, eval_outputs, model_metadata, metric_name) 51 | 52 | fmt_metric_name = _format_name(metric_name) 53 | fmt_model_id = _format_name(model_metadata["model_id"]) 54 | timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S.%f")[:-3] 55 | base_filename = f"{fmt_metric_name}_{fmt_model_id}_{model_metadata['model_type']}_{timestamp}" 56 | paths = _prepare_file_paths(output_dir, fmt_metric_name, fmt_model_id, base_filename) 57 | metadata = _prepare_metadata(model_metadata, timestamp) 58 | 59 | try: 60 | _write_json_file(paths["metadata"], metadata) 61 | 62 | mode = "a" if append else "w" 63 | with paths["results"].open(mode, encoding="utf-8") as f: 64 | for eval_input, eval_output in zip(eval_inputs, eval_outputs, strict=True): 65 | result = { 66 | "sample": eval_input.model_dump(), 67 | "feedback": eval_output.feedback, 68 | "score": eval_output.score, 69 | } 70 | f.write(json.dumps(result, ensure_ascii=False) + "\n") 71 | 72 | logger.info(f"Results {'appended to' if append else 'saved to'} {paths['results']}") 73 | except OSError as e: 74 | logger.error(f"Error writing files: {e}") 75 | raise 76 | 77 | 78 | def _validate_inputs( 79 | eval_inputs: list[EvalInput], 80 | eval_outputs: list[EvalOutput], 81 | model_metadata: dict[str, Any], 82 | metric_name: str, 83 | ) -> None: 84 | """Validate input parameters for the write_results_to_disk function. 85 | 86 | Args: 87 | eval_inputs: List of evaluation inputs. 88 | eval_outputs: List of evaluation outputs. 89 | model_metadata: Dictionary containing model metadata. 90 | metric_name: Name of the metric being evaluated. 91 | 92 | Raises: 93 | ValueError: If eval_inputs or eval_outputs are empty, have different lengths, 94 | or if metric_name is empty or only whitespace. 95 | KeyError: If required keys ('model_id', 'model_type') are missing from 96 | model_metadata. 97 | 98 | Note: 99 | This function does not validate the content of eval_inputs or eval_outputs, 100 | only their presence and length. 
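    Example:
        A minimal valid call, assuming ``eval_inputs`` and ``eval_outputs`` are
        non-empty, equal-length lists built elsewhere (sketch, not a doctest)::

            _validate_inputs(
                eval_inputs,
                eval_outputs,
                {"model_id": "flowaicom/Flow-Judge-v0.1", "model_type": "vllm"},
                "correctness",
            )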
101 | """ 102 | if not eval_inputs or not eval_outputs: 103 | raise ValueError("eval_inputs and eval_outputs cannot be empty") 104 | if len(eval_inputs) != len(eval_outputs): 105 | raise ValueError("eval_inputs and eval_outputs must have the same length") 106 | if not metric_name or not metric_name.strip(): 107 | raise ValueError("metric_name cannot be empty or only whitespace") 108 | required_keys = {"model_id", "model_type"} 109 | missing_keys = required_keys - set(model_metadata.keys()) 110 | if missing_keys: 111 | raise KeyError(f"model_metadata missing required keys: {missing_keys}") 112 | 113 | 114 | def _format_name(name: str) -> str: 115 | """Format a name for use in file paths by removing special characters. 116 | 117 | Args: 118 | name: The name to format. 119 | 120 | Returns: 121 | A formatted string safe for use in file paths. 122 | 123 | Note: 124 | This function replaces spaces with underscores, removes non-alphanumeric 125 | characters (except underscore and hyphen), and replaces non-ASCII 126 | characters with underscores. 127 | """ 128 | # Replace spaces with underscores 129 | name = name.replace(" ", "_") 130 | # Remove any character that is not alphanumeric, underscore, or hyphen 131 | name = re.sub(r"[^\w\-]", "", name) 132 | # Replace any non-ASCII character with underscore 133 | name = re.sub(r"[^\x00-\x7F]", "_", name) 134 | return name 135 | 136 | 137 | def _prepare_file_paths( 138 | output_dir: str | Path, 139 | fmt_metric_name: str, 140 | fmt_model_id: str, 141 | base_filename: str, 142 | ) -> dict[str, Path]: 143 | """Prepare file paths for metadata and results files. 144 | 145 | Args: 146 | output_dir: Base output directory. 147 | fmt_metric_name: Formatted metric name. 148 | fmt_model_id: Formatted model ID. 149 | base_filename: Base filename for output files. 150 | 151 | Returns: 152 | A dictionary containing paths for metadata and results files. 153 | 154 | Note: 155 | This function creates the necessary directories if they don't exist. 156 | It does not check if the resulting file paths already exist. 157 | """ 158 | output_dir = Path(output_dir) 159 | metric_folder = output_dir / fmt_metric_name 160 | metadata_folder = metric_folder / f"metadata_{fmt_metric_name}_{fmt_model_id}" 161 | metadata_folder.mkdir(parents=True, exist_ok=True) 162 | 163 | return { 164 | "metadata": metadata_folder / f"metadata_{base_filename}.json", 165 | "results": metric_folder / f"results_{base_filename}.jsonl", 166 | } 167 | 168 | 169 | def _prepare_metadata(model_metadata: dict[str, Any], timestamp: str) -> dict[str, Any]: 170 | """Prepare metadata dictionary for writing. 171 | 172 | Args: 173 | model_metadata: Dictionary containing model metadata. 174 | timestamp: Timestamp string. 175 | 176 | Returns: 177 | A dictionary containing prepared metadata. 178 | 179 | Note: 180 | - Adds 'library_version' and 'timestamp' to the metadata. 181 | - Converts Pydantic BaseModel instances to dictionaries. 182 | - Converts Enum instances to their values. 183 | - Does not deep copy the input model_metadata. 184 | """ 185 | metadata = { 186 | "library_version": f"{flow_judge.__version__}", 187 | "timestamp": timestamp, 188 | **model_metadata, 189 | } 190 | for key, item in metadata.items(): 191 | if isinstance(item, BaseModel): 192 | metadata[key] = item.model_dump() 193 | elif isinstance(item, Enum): 194 | metadata[key] = item.value 195 | return metadata 196 | 197 | 198 | def _write_json_file(path: Path, data: dict[str, Any]) -> None: 199 | """Write data to a JSON file. 
200 | 201 | Args: 202 | path: Path to the output file. 203 | data: Data to write to the file. 204 | 205 | Raises: 206 | OSError: If there's an error writing to the file. 207 | 208 | Note: 209 | - Uses UTF-8 encoding. 210 | - Overwrites the file if it already exists. 211 | - Ensures non-ASCII characters are preserved in the output. 212 | """ 213 | with path.open("w", encoding="utf-8") as f: 214 | json.dump(data, f, ensure_ascii=False, indent=2) 215 | 216 | 217 | def _write_results_file( 218 | path: Path, eval_inputs: list[EvalInput], eval_outputs: list[EvalOutput], append: bool = False 219 | ) -> None: 220 | """Write results to a JSONL file. 221 | 222 | Args: 223 | path: Path to the output file. 224 | eval_inputs: List of evaluation inputs. 225 | eval_outputs: List of evaluation outputs. 226 | append: If True, append to the file. If False, overwrite. Default is False. 227 | 228 | Raises: 229 | OSError: If there's an error writing to the file. 230 | ValueError: If eval_inputs and eval_outputs have different lengths. 231 | 232 | Note: 233 | - Uses UTF-8 encoding. 234 | - Appends to the file if append is True, otherwise overwrites. 235 | - Each line in the file is a JSON object representing one result. 236 | - Ensures non-ASCII characters are preserved in the output. 237 | """ 238 | if len(eval_inputs) != len(eval_outputs): 239 | raise ValueError("eval_inputs and eval_outputs must have the same length") 240 | 241 | mode = "a" if append else "w" 242 | with path.open(mode, encoding="utf-8") as f: 243 | for input_data, eval_output in zip(eval_inputs, eval_outputs, strict=True): 244 | result = { 245 | "sample": input_data.model_dump(), 246 | "feedback": eval_output.feedback, 247 | "score": eval_output.score, 248 | } 249 | f.write(json.dumps(result, ensure_ascii=False) + "\n") 250 | -------------------------------------------------------------------------------- /flow_judge/utils/validators.py: -------------------------------------------------------------------------------- 1 | from flow_judge.eval_data_types import EvalInput 2 | from flow_judge.metrics.metric import CustomMetric, Metric 3 | 4 | 5 | def validate_eval_input(eval_input: EvalInput, metric: Metric | CustomMetric): 6 | """Validate that the EvalInput matches the required inputs and output in the metric.""" 7 | input_keys = {list(input_dict.keys())[0] for input_dict in eval_input.inputs} 8 | output_key = list(eval_input.output.keys())[0] 9 | required_inputs = set(metric.required_inputs) 10 | 11 | if input_keys != required_inputs: 12 | raise ValueError(f"Input keys {input_keys} do not match required inputs {required_inputs}") 13 | 14 | if metric.required_output: 15 | if not hasattr(eval_input, "output"): 16 | raise ValueError(f"Required output '{metric.required_output}' is missing") 17 | elif metric.required_output != output_key: 18 | raise ValueError( 19 | f"""Output key '{output_key}' does not match \ 20 | required output '{metric.required_output}'""" 21 | ) 22 | -------------------------------------------------------------------------------- /img/flow_judge_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flowaicom/flow-judge/e56d199db4d79f184dac1e9ab2da83992acda14d/img/flow_judge_banner.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2", "wheel"] 3 | 
build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "flow-judge" 7 | version = "0.1.2" 8 | description = "A small yet powerful LM Judge" 9 | readme = "README.md" 10 | authors = [ 11 | {name = "Bernardo Garcia", email = "bernardo@flow-ai.com"}, 12 | {name = "Karolus Sariola", email = "karolus@flow-ai.com"}, 13 | {name = "Minaam Shahid", email = "minaam@flow-ai.com"}, 14 | {name = "Tiina Vaahtio", email = "tiina@flow-ai.com"}, 15 | {name = "Alex Wegrzyn", email = "alex.wegrzyn@flow-ai.com"}, 16 | ] 17 | license = {file = "LICENSE"} 18 | classifiers = [ 19 | "Programming Language :: Python :: 3", 20 | "License :: OSI Approved :: Apache Software License", 21 | "Operating System :: OS Independent", 22 | ] 23 | keywords = ["LM-judge", "evaluation", "LLMs", "AI", "benchmarking"] 24 | requires-python = ">=3.10" 25 | dependencies = [ 26 | "pydantic>=2.9.1", 27 | "requests>=2.32.3", 28 | "hf-transfer>=0.1.1", 29 | "ipykernel>=6.29.0", 30 | "ipywidgets>=8.1.0", 31 | "tqdm>=4.66.1", 32 | "structlog", 33 | ] 34 | 35 | [project.optional-dependencies] 36 | dev = [ 37 | "pytest", 38 | "pre-commit", 39 | "ruff", 40 | "black", 41 | "isort", 42 | "pytest-cov", 43 | "codecov", 44 | "mypy>=1.11.2", 45 | "types-requests", 46 | "types-tqdm", 47 | "memray>=1.14.0", 48 | "pytest-memray>=1.7.0", 49 | "pytest-asyncio>=0.23.6, <0.24.0", 50 | "hypothesis" 51 | ] 52 | integrations-test = [ 53 | "llama-index", 54 | "llama-index-embeddings-huggingface" 55 | ] 56 | hf = [ 57 | "transformers>=4.45.0", 58 | "torch>=2.3.0", 59 | "bitsandbytes>=0.41.0,<=0.42.0", 60 | "accelerate>=0.34.2", 61 | ] 62 | vllm = ["vllm==0.6.2"] 63 | llamafile = [ 64 | "torch>=2.3.0", 65 | "openai>=1.51.0", 66 | ] 67 | baseten = [ 68 | "truss>=0.9.44", 69 | "openai>=1.51.0", 70 | "aiohttp>=3.10.5" 71 | ] 72 | 73 | [project.urls] 74 | Homepage = "https://github.com/flowaicom/flow-judge" 75 | 76 | [tool.setuptools] 77 | packages = ["flow_judge", "flow_judge.integrations", "flow_judge.metrics", "flow_judge.models", "flow_judge.utils"] 78 | 79 | [tool.setuptools.package-data] 80 | "flow_judge.models" = ["adapters/baseten/**/*.yaml"] 81 | 82 | [tool.setuptools_scm] 83 | version_scheme = "python-simplified-semver" 84 | 85 | [tool.ruff] 86 | line-length = 100 87 | include = ["flow_judge/**/*.py", "tests/**/*.py", "setup.py"] 88 | 89 | [tool.ruff.lint] 90 | select = ["E", "F", "I", "N", "W", "B", "C", "D"] 91 | ignore = ["D100", "D104"] 92 | 93 | [tool.ruff.lint.per-file-ignores] 94 | "__init__.py" = ["F401"] 95 | 96 | [tool.ruff.lint.pydocstyle] 97 | convention = "google" 98 | 99 | [tool.black] 100 | line-length = 100 101 | target-version = ['py311'] 102 | include = '(flow_judge/.*\.py$|tests/.*\.py$|setup\.py)' 103 | 104 | [tool.isort] 105 | profile = "black" 106 | line_length = 100 107 | src_paths = ["flow_judge", "tests"] 108 | 109 | [tool.mypy] 110 | warn_unused_configs = true 111 | warn_redundant_casts = true 112 | warn_unused_ignores = true 113 | strict_equality = true 114 | check_untyped_defs = true 115 | disallow_any_generics = true 116 | disallow_untyped_defs = false 117 | disallow_incomplete_defs = false 118 | 119 | [tool.bdist_wheel] 120 | universal = true 121 | 122 | [tool.pytest.ini_options] 123 | asyncio_mode = "auto" 124 | markers = [ 125 | "asyncio: mark test as an asyncio coroutine", 126 | "memray: marks tests to be run with memray profiling", 127 | "e2e: marks end-to-end tests", 128 | ] 129 | -------------------------------------------------------------------------------- /tests/README.md: 
-------------------------------------------------------------------------------- 1 | 2 | # Tests for Flow Judge 3 | 4 | This directory contains the test suite for the Flow Judge project. 5 | 6 | ## Test Coverage 7 | 8 | Below is the current test coverage visualization for the Flow Judge project: 9 | 10 |

11 | 12 | Codecov Sunburst Graph (test coverage visualization image) 13 | 14 |
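The sunburst graph is generated from coverage data uploaded by CI. To produce a comparable coverage report locally, a typical `pytest-cov` invocation looks like this (a sketch; the exact flags used in CI may differ):

```sh
pytest --cov=flow_judge --cov-report=xml
```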

15 | 16 | ## Running Tests 17 | 18 | To run the entire test suite: 19 | ```sh 20 | pytest 21 | ``` 22 | To run a specific test file: 23 | ```sh 24 | pytest tests/unit/test_flow_judge.py 25 | ``` 26 | To run tests with coverage report: 27 | ```sh 28 | pytest --cov=flow_judge --cov-report=term-missing 29 | ``` 30 | 31 | ## Contributing 32 | 33 | When adding new features or modifying existing ones, please make sure to add or update the corresponding tests. This helps maintain the project's reliability and makes it easier to catch potential issues early. 34 | 35 | ## Continuous Integration 36 | 37 | Our CI pipeline automatically runs these tests on every pull request and push to the main branch. You can check the status of the latest runs in the GitHub Actions tab of the repository. 38 | -------------------------------------------------------------------------------- /tests/e2e-cloud-gpu/models/adapters/test_baseten_e2e.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import statistics 4 | from collections import Counter 5 | from pathlib import Path 6 | from typing import Any 7 | 8 | import pytest 9 | from llama_index.core import VectorStoreIndex 10 | from llama_index.core.evaluation import BatchEvalRunner 11 | from llama_index.core.llama_dataset import download_llama_dataset 12 | from pydantic import BaseModel 13 | 14 | from flow_judge import Baseten 15 | from flow_judge.integrations.llama_index import LlamaIndexFlowJudge 16 | from flow_judge.metrics import CustomMetric, RubricItem 17 | 18 | pytest_plugins = ("pytest_asyncio",) 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class TestConfig(BaseModel): 24 | """Configuration for Baseten e2e tests.""" 25 | 26 | api_key: str = os.getenv("BASETEN_API_KEY") 27 | model_id: str = os.getenv("BASETEN_MODEL_ID") 28 | webhook_url: str = os.getenv("BASETEN_WEBHOOK_URL") 29 | webhook_secret: str = os.getenv("BASETEN_WEBHOOK_SECRET") 30 | 31 | 32 | @pytest.fixture(scope="module") 33 | def test_config() -> TestConfig: 34 | """Fixture to load test configuration from environment variables. 35 | 36 | Returns: 37 | TestConfig: Configuration object for Baseten e2e tests. 38 | 39 | Raises: 40 | ValueError: If any required environment variable is missing. 41 | """ 42 | try: 43 | return TestConfig() 44 | except ValueError as e: 45 | pytest.fail(f"Missing required environment variable: {str(e)}") 46 | 47 | 48 | @pytest.fixture(scope="module") 49 | def test_cache_dir() -> Path: 50 | """Create a temporary directory for test cache. 51 | 52 | Returns: 53 | Path: Path object pointing to the temporary directory. 54 | """ 55 | import tempfile 56 | 57 | with tempfile.TemporaryDirectory(prefix="flow-judge-baseten-test-") as tmpdir: 58 | temp_path = Path(tmpdir) 59 | logger.info(f"Created temporary test cache directory: {temp_path}") 60 | yield temp_path 61 | logger.info(f"Cleaned up temporary test cache directory: {temp_path}") 62 | 63 | 64 | @pytest.fixture 65 | def correctness_metric() -> CustomMetric: 66 | """Creates a CustomMetric for evaluating the correctness of generated answers. 67 | 68 | Returns: 69 | CustomMetric: A metric object with evaluation criteria and rubric for 70 | assessing answer correctness. 71 | """ 72 | evaluation_criteria = "Is the generated answer relevant to the user query and reference answer?" 
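    # The rubric below is a list of RubricItem(score, description) entries, ordered from the
    # weakest to the strongest answer. CustomMetric (imported from flow_judge.metrics) combines
    # this rubric with the criteria string above and the declared required inputs/output when
    # the judge's evaluation prompt is built.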
73 | rubric = [ 74 | RubricItem( 75 | score=1, 76 | description="The generated answer is not relevant to the user query " 77 | "and reference answer.", 78 | ), 79 | RubricItem( 80 | score=2, 81 | description="The generated answer is according to reference answer but" 82 | " not relevant to user query.", 83 | ), 84 | RubricItem( 85 | score=3, 86 | description="The generated answer is relevant to the user query and " 87 | "reference answer but contains mistakes.", 88 | ), 89 | RubricItem( 90 | score=4, 91 | description="The generated answer is relevant to the user query and " 92 | "has the exact same metrics as the reference answer, but it is not as concise.", 93 | ), 94 | RubricItem( 95 | score=5, 96 | description="The generated answer is relevant to the user query and " 97 | "fully correct according to the reference answer.", 98 | ), 99 | ] 100 | return CustomMetric( 101 | name="correctness", 102 | criteria=evaluation_criteria, 103 | rubric=rubric, 104 | required_inputs=["query", "reference"], 105 | required_output="response", 106 | ) 107 | 108 | 109 | def get_scores_distribution(scores: list[float]) -> dict[float, str]: 110 | """Calculates the distribution of scores as percentages. 111 | 112 | Args: 113 | scores (List[float]): A list of numerical scores. 114 | 115 | Returns: 116 | Dict[float, str]: A dictionary mapping scores to their percentage occurrence. 117 | """ 118 | score_counts = Counter(scores) 119 | total_scores = len(scores) 120 | return {score: f"{(count / total_scores) * 100:.1f}%" for score, count in score_counts.items()} 121 | 122 | 123 | def compare_distributions( 124 | actual: dict[float, str], 125 | expected: dict[float, str], 126 | tolerance: float = 20.0, 127 | ) -> bool: 128 | """Compares two score distributions within a given tolerance. 129 | 130 | Args: 131 | actual (Dict[float, str]): The actual score distribution. 132 | expected (Dict[float, str]): The expected score distribution. 133 | tolerance (float, optional): The maximum allowed difference between 134 | percentages. Defaults to 20.0. 135 | 136 | Returns: 137 | bool: True if the distributions are within the tolerance, False otherwise. 138 | """ 139 | for score in set(actual.keys()) | set(expected.keys()): 140 | actual_pct = float(actual.get(score, "0%").rstrip("%")) 141 | expected_pct = float(expected.get(score, "0%").rstrip("%")) 142 | if abs(actual_pct - expected_pct) > tolerance: 143 | return False 144 | return True 145 | 146 | 147 | async def batch_eval_runner( 148 | evaluators: dict[str, LlamaIndexFlowJudge], 149 | query_engine: Any, 150 | questions: list[str], 151 | reference: list[str] | None = None, 152 | num_workers: int = 2, 153 | ) -> dict[str, list[Any]]: 154 | """Runs batch evaluation using the provided evaluators and query engine. 155 | 156 | Args: 157 | evaluators (Dict[str, LlamaIndexFlowJudge]): Dictionary of evaluators. 158 | query_engine (Any): The query engine to use for generating responses. 159 | questions (List[str]): List of questions to evaluate. 160 | reference (Optional[List[str]], optional): List of reference answers. 161 | Defaults to None. 162 | num_workers (int, optional): Number of workers for parallel processing. 163 | Defaults to 2. 164 | 165 | Returns: 166 | Dict[str, List[Any]]: Evaluation results for each evaluator. 
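    Example:
        A typical call from inside an async test (sketch; assumes ``judge`` is a
        configured ``LlamaIndexFlowJudge`` and ``query_engine`` was built from a
        LlamaIndex index elsewhere)::

            results = await batch_eval_runner(
                evaluators={"correctness": judge},
                query_engine=query_engine,
                questions=["What is the boiling point of water at sea level?"],
                reference=["100 degrees Celsius at standard atmospheric pressure."],
            )
            scores = [r.score for r in results["correctness"]]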
167 | """ 168 | batch_runner = BatchEvalRunner(evaluators, workers=num_workers, show_progress=True) 169 | return await batch_runner.aevaluate_queries( 170 | query_engine, queries=questions, reference=reference 171 | ) 172 | 173 | 174 | @pytest.mark.asyncio 175 | async def test_baseten_correctness_evaluation( 176 | test_config: TestConfig, 177 | correctness_metric: CustomMetric, 178 | test_cache_dir: Path, 179 | ) -> None: 180 | """Tests the correctness evaluation of Baseten model using LlamaIndexFlowJudge. 181 | 182 | Args: 183 | test_config (TestConfig): Test configuration object. 184 | correctness_metric (CustomMetric): The metric used for evaluation. 185 | test_cache_dir (Path): Temporary directory for test cache. 186 | 187 | Raises: 188 | AssertionError: If the evaluation score is outside the expected range or 189 | feedback is missing. 190 | """ 191 | os.environ["HF_HOME"] = str(test_cache_dir) 192 | model = Baseten( 193 | _model_id=test_config.model_id, 194 | exec_async=True, 195 | webhook_proxy_url=test_config.webhook_url, 196 | ) 197 | flow_judge_evaluator = LlamaIndexFlowJudge(model=model, metric=correctness_metric) 198 | 199 | # Download and prepare the dataset 200 | rag_dataset, documents = download_llama_dataset( 201 | "MiniTruthfulQADataset", str(test_cache_dir / "mini_truthful_qa") 202 | ) 203 | 204 | # Select a single example for evaluation 205 | example = rag_dataset.examples[0] 206 | query, reference = example.query, example.reference_answer 207 | 208 | # Generate response using Baseten model 209 | response = await model._async_generate(query) 210 | 211 | result = await flow_judge_evaluator.aevaluate( 212 | query=query, reference=reference, response=response 213 | ) 214 | 215 | assert result is not None, "Evaluation result is None" 216 | assert 2 <= int(result.score) <= 5, f"Score {result.score} is out of expected range" 217 | assert result.feedback is not None, "Feedback is missing" 218 | 219 | logger.info(f"Evaluation score: {result.score}") 220 | logger.info(f"Evaluation feedback: {result.feedback}") 221 | 222 | 223 | @pytest.mark.asyncio 224 | async def test_baseten_batch_evaluation( 225 | test_config: TestConfig, 226 | correctness_metric: CustomMetric, 227 | test_cache_dir: Path, 228 | ) -> None: 229 | """Performs a batch evaluation of queries using Baseten model and analyzes results. 230 | 231 | Args: 232 | test_config (TestConfig): Test configuration object. 233 | correctness_metric (CustomMetric): The metric used for evaluation. 234 | test_cache_dir (Path): Temporary directory for test cache. 235 | 236 | Raises: 237 | AssertionError: If the evaluation results do not meet expected criteria. 
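    Note:
        This test exercises a live Baseten deployment and expects the
        BASETEN_API_KEY, BASETEN_MODEL_ID, BASETEN_WEBHOOK_URL and
        BASETEN_WEBHOOK_SECRET environment variables to be set (see ``TestConfig``).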
238 | """ 239 | os.environ["HF_HOME"] = str(test_cache_dir) 240 | model = Baseten( 241 | _model_id=test_config.model_id, 242 | exec_async=True, 243 | webhook_proxy_url=test_config.webhook_url, 244 | ) 245 | logger.info("Starting test_baseten_batch_evaluation") 246 | 247 | flow_judge_correctness = LlamaIndexFlowJudge(model=model, metric=correctness_metric) 248 | 249 | # Download and prepare the dataset 250 | rag_dataset, documents = download_llama_dataset( 251 | "MiniTruthfulQADataset", str(test_cache_dir / "mini_truthful_qa") 252 | ) 253 | 254 | # Create the index and query engine 255 | index = VectorStoreIndex.from_documents(documents=documents) 256 | query_engine = index.as_query_engine() 257 | 258 | # Prepare queries and references 259 | rag_subset = rag_dataset.examples[:10] 260 | queries = [example.query for example in rag_subset] 261 | references = [example.reference_answer for example in rag_subset] 262 | 263 | logger.info(f"Evaluating {len(queries)} queries") 264 | 265 | evaluators = {"correctness": flow_judge_correctness} 266 | 267 | eval_results = await batch_eval_runner( 268 | evaluators=evaluators, 269 | query_engine=query_engine, 270 | questions=queries, 271 | reference=references, 272 | ) 273 | 274 | # Check results 275 | assert "correctness" in eval_results, "Correctness evaluator results missing" 276 | assert len(eval_results["correctness"]) == len(queries), "Incomplete evaluation results" 277 | 278 | for result in eval_results["correctness"]: 279 | assert result.score is not None, "Evaluation score is missing" 280 | assert result.feedback is not None, "Evaluation feedback is missing" 281 | 282 | # Calculate score distribution 283 | scores = [result.score for result in eval_results["correctness"]] 284 | actual_distribution = get_scores_distribution(scores) 285 | logger.info(f"Actual score distribution: {actual_distribution}") 286 | 287 | # Calculate average score 288 | average_score = statistics.mean(scores) 289 | logger.info(f"Average score: {average_score:.2f}") 290 | 291 | # Assert that the average score is within an acceptable range 292 | assert ( 293 | 3.0 <= average_score <= 4.5 294 | ), f"Average score {average_score:.2f} is outside the expected range of 3.0 to 4.5" 295 | 296 | # Check that we have a variety of scores 297 | unique_scores = set(scores) 298 | assert ( 299 | len(unique_scores) >= 3 300 | ), f"Expected at least 3 different score values, but got {len(unique_scores)}" 301 | 302 | logger.info("test_baseten_batch_evaluation completed successfully") 303 | 304 | 305 | if __name__ == "__main__": 306 | pytest.main([__file__, "-v", "-s"]) 307 | -------------------------------------------------------------------------------- /tests/e2e-local/models/test_llamafile_e2e.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import tempfile 4 | from pathlib import Path 5 | 6 | import pytest 7 | 8 | from flow_judge.models.llamafile import Llamafile 9 | 10 | # Set up logging with more detail 11 | logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s") 12 | logger = logging.getLogger(__name__) 13 | 14 | # Define MEMORY_LIMIT (10 GB in bytes) 15 | MEMORY_LIMIT = 10 * 1024 * 1024 * 1024 # 10 GB 16 | 17 | """ 18 | End-to-end test suite for the Llamafile class. 19 | 20 | This test suite performs comprehensive tests on the Llamafile class, 21 | including initialization, file operations, server management, command-line 22 | argument passing, and text generation. 
These tests interact with the actual 23 | Llamafile system and may take longer to run compared to unit tests. 24 | """ 25 | 26 | 27 | @pytest.fixture(scope="module") 28 | def test_cache_dir(): 29 | """Create a temporary directory for test cache. 30 | 31 | This fixture creates a temporary directory that is guaranteed to be 32 | writable and cleaned up after the tests. 33 | 34 | :yield: Path object pointing to the temporary directory 35 | :rtype: pathlib.Path 36 | """ 37 | with tempfile.TemporaryDirectory(prefix="flow-judge-test-") as tmpdir: 38 | temp_path = Path(tmpdir) 39 | logger.info(f"Created temporary test cache directory: {temp_path}") 40 | yield temp_path 41 | logger.info(f"Cleaned up temporary test cache directory: {temp_path}") 42 | 43 | 44 | @pytest.mark.memray(threshold=MEMORY_LIMIT) 45 | def test_llamafile_initialization(test_cache_dir): 46 | """Test Llamafile initialization with various parameters. 47 | 48 | This test verifies that a Llamafile instance is correctly initialized 49 | with custom configuration parameters, including port, context size, 50 | GPU layers, thread count, batch size, concurrent requests, and 51 | generation parameters. 52 | 53 | :param test_cache_dir: Path to the test cache directory 54 | :type test_cache_dir: pathlib.Path 55 | """ 56 | logger.info("Starting Llamafile initialization test") 57 | llamafile = Llamafile( 58 | cache_dir=str(test_cache_dir), 59 | port=9000, 60 | context_size=4096, 61 | gpu_layers=20, 62 | thread_count=4, 63 | batch_size=16, 64 | max_concurrent_requests=2, 65 | generation_params={ 66 | "temperature": 0.8, 67 | "top_p": 0.9, 68 | "max_new_tokens": 500, 69 | }, 70 | ) 71 | 72 | assert llamafile.config.model_id == "flowaicom/Flow-Judge-v0.1-Llamafile" 73 | assert llamafile.config.port == 9000 74 | assert llamafile.config.context_size == 4096 75 | assert llamafile.config.gpu_layers == 20 76 | assert llamafile.config.thread_count == 4 77 | assert llamafile.config.batch_size == 16 78 | assert llamafile.config.max_concurrent_requests == 2 79 | assert llamafile.config.generation_params.temperature == 0.8 80 | assert llamafile.config.generation_params.top_p == 0.9 81 | assert llamafile.config.generation_params.max_new_tokens == 500 82 | logger.info("Llamafile initialization test completed successfully") 83 | 84 | 85 | @pytest.mark.memray(threshold=MEMORY_LIMIT) 86 | def test_download_llamafile(test_cache_dir): 87 | """Test downloading of Llamafile. 88 | 89 | This test ensures that the Llamafile is correctly downloaded to the 90 | specified cache directory and that the downloaded file has the correct 91 | permissions (executable). 92 | 93 | :param test_cache_dir: Path to the test cache directory 94 | :type test_cache_dir: pathlib.Path 95 | """ 96 | logger.info("Starting Llamafile download test") 97 | llamafile = Llamafile(cache_dir=str(test_cache_dir)) 98 | llamafile_path = llamafile.download_llamafile() 99 | 100 | assert os.path.exists(llamafile_path), f"Llamafile not found at {llamafile_path}" 101 | assert os.access(llamafile_path, os.X_OK), f"Llamafile at {llamafile_path} is not executable" 102 | logger.info(f"Llamafile successfully downloaded to {llamafile_path}") 103 | 104 | 105 | @pytest.mark.memray(threshold=MEMORY_LIMIT) 106 | def test_build_llamafile_command(test_cache_dir): 107 | """Test building of Llamafile command. 108 | 109 | This test verifies that the Llamafile command is correctly constructed 110 | based on the provided configuration parameters. 
It checks for the 111 | presence of essential command-line arguments in the built command. 112 | 113 | :param test_cache_dir: Path to the test cache directory 114 | :type test_cache_dir: pathlib.Path 115 | """ 116 | logger.info("Starting Llamafile command building test") 117 | llamafile = Llamafile( 118 | cache_dir=str(test_cache_dir), 119 | port=9000, 120 | context_size=4096, 121 | gpu_layers=20, 122 | thread_count=4, 123 | batch_size=16, 124 | max_concurrent_requests=2, 125 | generation_params={"temperature": 0.8, "max_new_tokens": 500}, 126 | quantized_kv=True, 127 | flash_attn=True, 128 | ) 129 | 130 | llamafile_path = llamafile.download_llamafile() 131 | command = llamafile._build_llamafile_command(llamafile_path) 132 | 133 | expected_args = [ 134 | llamafile_path, 135 | "--port 9000", 136 | "-c 4096", 137 | "-ngl 20", 138 | "--threads 4", 139 | "-b 16", 140 | "--parallel 2", 141 | "--temp 0.8", 142 | "-n 500", 143 | "-ctk q4_0", 144 | "-ctv q4_0", 145 | "-fa", 146 | ] 147 | for arg in expected_args: 148 | assert arg in command, f"Expected argument '{arg}' not found in command" 149 | logger.info("Llamafile command successfully built and verified") 150 | 151 | 152 | @pytest.mark.memray(threshold=MEMORY_LIMIT) 153 | def test_start_stop_server(test_cache_dir): 154 | """Test starting and stopping of Llamafile server. 155 | 156 | This test verifies that the Llamafile server can be started and stopped 157 | correctly, and that the is_server_running method accurately reflects 158 | the server's state. 159 | 160 | :param test_cache_dir: Path to the test cache directory 161 | :type test_cache_dir: pathlib.Path 162 | """ 163 | logger.info("Starting Llamafile server start/stop test") 164 | llamafile = Llamafile(cache_dir=str(test_cache_dir)) 165 | 166 | try: 167 | llamafile.start_llamafile_server() 168 | assert llamafile.is_server_running(), "Server should be running but is not" 169 | logger.info("Llamafile server successfully started") 170 | 171 | llamafile.stop_llamafile_server() 172 | assert not llamafile.is_server_running(), "Server should not be running but is" 173 | logger.info("Llamafile server successfully stopped") 174 | except Exception as e: 175 | logger.error(f"Error during server start/stop test: {str(e)}") 176 | if llamafile.llamafile_process: 177 | stdout, stderr = llamafile.llamafile_process.communicate() 178 | logger.error(f"Process stdout: {stdout}") 179 | logger.error(f"Process stderr: {stderr}") 180 | raise 181 | 182 | 183 | @pytest.mark.memray(threshold=MEMORY_LIMIT) 184 | def test_generate(test_cache_dir): 185 | """Test text generation. 186 | 187 | This test verifies that the Llamafile can generate text responses 188 | to a given prompt. It checks that the generated response is a 189 | non-empty string. 
190 | 191 | :param test_cache_dir: Path to the test cache directory 192 | :type test_cache_dir: pathlib.Path 193 | """ 194 | logger.info("Starting text generation test") 195 | llamafile = Llamafile(cache_dir=str(test_cache_dir)) 196 | 197 | with llamafile: 198 | try: 199 | response = llamafile._generate("Hello, world!") 200 | assert isinstance(response, str), f"Response should be a string, got {type(response)}" 201 | assert len(response) > 0, "Response should not be empty" 202 | logger.info(f"Generated response: {response}") 203 | logger.info("Text generation test completed successfully") 204 | except Exception as e: 205 | logger.error(f"Error during text generation: {str(e)}") 206 | raise 207 | 208 | 209 | @pytest.mark.memray(threshold=MEMORY_LIMIT) 210 | def test_batch_generate(test_cache_dir): 211 | """Test batch text generation. 212 | 213 | This test verifies that the Llamafile can generate text responses 214 | for multiple prompts in a batch. It checks that the number of 215 | responses matches the number of prompts and that each response 216 | is a non-empty string. 217 | 218 | :param test_cache_dir: Path to the test cache directory 219 | :type test_cache_dir: pathlib.Path 220 | """ 221 | logger.info("Starting batch text generation test") 222 | llamafile = Llamafile(cache_dir=str(test_cache_dir)) 223 | 224 | with llamafile: 225 | try: 226 | prompts = ["Hello, world!", "How are you?", "What's the weather like?"] 227 | responses = llamafile._batch_generate(prompts) 228 | assert len(responses) == len( 229 | prompts 230 | ), f"Expected {len(prompts)} responses, got {len(responses)}" 231 | for i, response in enumerate(responses): 232 | assert isinstance( 233 | response, str 234 | ), f"Response {i+1} should be a string, got {type(response)}" 235 | assert len(response) > 0, f"Response {i+1} should not be empty" 236 | logger.info(f"Response {i+1}: {response}") 237 | logger.info("Batch text generation test completed successfully") 238 | except Exception as e: 239 | logger.error(f"Error during batch text generation: {str(e)}") 240 | raise 241 | -------------------------------------------------------------------------------- /tests/unit/models/adapters/gpu.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import mock_open, patch 2 | 3 | import pytest 4 | import yaml 5 | 6 | from flow_judge.models.adapters.baseten.gpu import ( 7 | _get_gpu_key, 8 | _has_gpu_key, 9 | _update_config, 10 | ensure_gpu, 11 | ) 12 | 13 | 14 | @pytest.fixture 15 | def mock_env_setup(monkeypatch): 16 | """Mock the environment variables for testing. 17 | 18 | :returns: None 19 | :rtype: None 20 | """ 21 | monkeypatch.delenv("BASETEN_GPU", raising=False) 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "env_value, expected", 26 | [("H100", "H100"), ("h100", "h100"), ("A10G", "A10G"), ("a10g", "a10g"), (None, None)], 27 | ) 28 | def test_get_gpu_key(mock_env_setup, env_value: str | None, expected: str | None, monkeypatch): 29 | """Test the _get_gpu_key function. 30 | 31 | :param mock_env_setup: Fixture to mock environment variables. 32 | :type mock_env_setup: None 33 | :param env_value: The value to set for the BASETEN_GPU environment variable. 34 | :type env_value: str | None 35 | :param expected: The expected output. 36 | :type expected: str | None 37 | :param monkeypatch: Pytest monkeypatch fixture. 
38 | :type monkeypatch: pytest.MonkeyPatch 39 | :returns: None 40 | :rtype: None 41 | """ 42 | if env_value is not None: 43 | monkeypatch.setenv("BASETEN_GPU", env_value) 44 | assert _get_gpu_key() == expected 45 | 46 | 47 | @pytest.mark.parametrize( 48 | "env_value, expected", 49 | [("H100", True), ("h100", True), ("A10G", True), ("a10g", True), (None, False)], 50 | ) 51 | def test_has_gpu_key(mock_env_setup, env_value: str | None, expected: bool, monkeypatch): 52 | """Test the _has_gpu_key function. 53 | 54 | :param mock_env_setup: Fixture to mock environment variables. 55 | :type mock_env_setup: None 56 | :param env_value: The value to set for the BASETEN_GPU environment variable. 57 | :type env_value: str | None 58 | :param expected: The expected output. 59 | :type expected: bool 60 | :param monkeypatch: Pytest monkeypatch fixture. 61 | :type monkeypatch: pytest.MonkeyPatch 62 | :returns: None 63 | :rtype: None 64 | """ 65 | if env_value is not None: 66 | monkeypatch.setenv("BASETEN_GPU", env_value) 67 | assert _has_gpu_key() == expected 68 | 69 | 70 | @pytest.mark.parametrize( 71 | "mock_file_contents, mock_env_value, expected", 72 | [ 73 | ( 74 | { 75 | "resources": {"accelerator": "test"}, 76 | "repo_id": "test", 77 | "model_metadata": {"repo_id": "test"}, 78 | }, 79 | "H100", 80 | True, 81 | ), 82 | ( 83 | { 84 | "resources": {"accelerator": "test"}, 85 | "repo_id": "test", 86 | "model_metadata": {"repo_id": "test"}, 87 | }, 88 | "A10G", 89 | True, 90 | ), 91 | ({}, "H100", False), 92 | ({}, "A10G", False), 93 | ], 94 | ) 95 | def test_update_config(monkeypatch, mock_file_contents: dict, mock_env_value: str, expected: bool): 96 | """Test the _update_config function. 97 | 98 | :param monkeypatch: Pytest monkeypatch fixture. 99 | :type monkeypatch: pytest.MonkeyPatch 100 | :param mock_file_contents: The contents to be mocked for the config.yaml file. 101 | :type mock_file_contents: dict 102 | :param mock_env_value: The value to set for the BASETEN_GPU environment variable. 103 | :type mock_env_value: str 104 | :param expected: The expected output. 105 | :type expected: bool 106 | :returns: None 107 | :rtype: None 108 | """ 109 | monkeypatch.setenv("BASETEN_GPU", mock_env_value) 110 | mock_open_obj = mock_open(read_data=yaml.dump(mock_file_contents)) 111 | with patch("builtins.open", mock_open_obj): 112 | assert _update_config() == expected 113 | 114 | 115 | @pytest.mark.parametrize( 116 | "mock_env_value, mock_input_value, expected", 117 | [ 118 | ("H100", b"y\n", True), 119 | ("H100", b"n\n", True), 120 | (None, b"y\n", False), 121 | (None, b"n\n", False), 122 | (None, b"\n", False), 123 | (None, KeyboardInterrupt, False), 124 | ], 125 | ) 126 | def test_ensure_gpu(monkeypatch, mock_env_value: str | None, mock_input_value, expected: bool): 127 | """Test the ensure_gpu function. 128 | 129 | :param monkeypatch: Pytest monkeypatch fixture. 130 | :type monkeypatch: pytest.MonkeyPatch 131 | :param mock_env_value: The value to set for the BASETEN_GPU environment variable. 132 | :type mock_env_value: str | None 133 | :param mock_input_value: The value to mock for user input. 134 | :type mock_input_value: bytes | KeyboardInterrupt 135 | :param expected: The expected output. 
136 | :type expected: bool 137 | :returns: None 138 | :rtype: None 139 | """ 140 | if mock_env_value is not None: 141 | monkeypatch.setenv("BASETEN_GPU", mock_env_value) 142 | with patch("builtins.input", side_effect=[mock_input_value]): 143 | with patch("flow_judge.models.adapters.baseten.gpu._update_config") as mock_update_config: 144 | mock_update_config.return_value = True 145 | assert ensure_gpu() == expected 146 | -------------------------------------------------------------------------------- /tests/unit/models/adapters/validation.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import hmac 3 | from datetime import datetime, timedelta, timezone 4 | from unittest.mock import patch 5 | 6 | import pytest 7 | 8 | from flow_judge.models.adapters.baseten.validation import ( 9 | TIMESTAMP_TOLERANCE_SECONDS, 10 | validate_baseten_signature, 11 | ) 12 | 13 | # Mock test data for AsyncPredictResult 14 | mock_async_predict_result_data = { 15 | "model_config": {}, 16 | "request_id": "mock_request_id", 17 | "model_id": "mock_model_id", 18 | "deployment_id": "mock_deployment_id", 19 | "type": "mock_type", 20 | "time": datetime.now(timezone.utc), 21 | "data": {"mock_data": "mock_value"}, 22 | "errors": [{"mock_error": "mock_error_value"}], 23 | } 24 | 25 | # Mock webhook secret 26 | mock_webhook_secret = "mock_webhook_secret" 27 | 28 | 29 | @pytest.fixture(autouse=True) 30 | def mock_os_environ(): 31 | """Mock os.environ for BASETEN_WEBHOOK_SECRET.""" 32 | with patch.dict("os.environ", {"BASETEN_WEBHOOK_SECRET": mock_webhook_secret}): 33 | yield 34 | 35 | 36 | class TestValidateBasetenSignature: 37 | """Test suite for validate_baseten_signature.""" 38 | 39 | def test_valid_signature(self): 40 | """Test for a valid Baseten signature.""" 41 | with patch( 42 | "flow_judge.models.adapters.baseten.validation.AsyncPredictResult" 43 | ) as mock_async_predict_result: 44 | mock_async_predict_result_instance = mock_async_predict_result.return_value 45 | mock_async_predict_result_instance.model_dump_json.return_value = "mock_model_dump_json" 46 | mock_async_predict_result_instance.__dict__.update(mock_async_predict_result_data) 47 | 48 | # Mock valid signature 49 | valid_signature = hmac.digest( 50 | mock_webhook_secret.encode("utf-8"), 51 | b"mock_model_dump_json", 52 | hashlib.sha256, 53 | ).hex() 54 | 55 | assert ( 56 | validate_baseten_signature( 57 | mock_async_predict_result_instance, f"v1={valid_signature}" 58 | ) 59 | is True 60 | ) 61 | 62 | def test_invalid_signature(self): 63 | """Test for an invalid Baseten signature.""" 64 | with patch( 65 | "flow_judge.models.adapters.baseten.validation.AsyncPredictResult" 66 | ) as mock_async_predict_result: 67 | mock_async_predict_result_instance = mock_async_predict_result.return_value 68 | mock_async_predict_result_instance.model_dump_json.return_value = "mock_model_dump_json" 69 | mock_async_predict_result_instance.__dict__.update(mock_async_predict_result_data) 70 | 71 | # Mock invalid signature 72 | invalid_signature = "invalid_signature" 73 | 74 | assert ( 75 | validate_baseten_signature( 76 | mock_async_predict_result_instance, f"v1={invalid_signature}" 77 | ) 78 | is False 79 | ) 80 | 81 | def test_stale_timestamp(self): 82 | """Test for a stale timestamp.""" 83 | stale_timestamp = datetime.now(timezone.utc) - timedelta( 84 | seconds=TIMESTAMP_TOLERANCE_SECONDS + 1 85 | ) 86 | 87 | with patch( 88 | "flow_judge.models.adapters.baseten.validation.AsyncPredictResult" 89 | ) as 
mock_async_predict_result: 90 | mock_async_predict_result_instance = mock_async_predict_result.return_value 91 | mock_async_predict_result_instance.model_dump_json.return_value = "mock_model_dump_json" 92 | mock_async_predict_result_instance.__dict__.update(mock_async_predict_result_data) 93 | mock_async_predict_result_instance.time = stale_timestamp 94 | 95 | # Mock valid signature 96 | valid_signature = hmac.digest( 97 | mock_webhook_secret.encode("utf-8"), 98 | b"mock_model_dump_json", 99 | hashlib.sha256, 100 | ).hex() 101 | 102 | assert ( 103 | validate_baseten_signature( 104 | mock_async_predict_result_instance, f"v1={valid_signature}" 105 | ) 106 | is False 107 | ) 108 | 109 | def test_missing_webhook_secret(self, monkeypatch): 110 | """Test for a missing BASETEN_WEBHOOK_SECRET environment variable.""" 111 | monkeypatch.delenv("BASETEN_WEBHOOK_SECRET", raising=False) 112 | 113 | with patch( 114 | "flow_judge.models.adapters.baseten.validation.AsyncPredictResult" 115 | ) as mock_async_predict_result: 116 | mock_async_predict_result_instance = mock_async_predict_result.return_value 117 | mock_async_predict_result_instance.model_dump_json.return_value = "mock_model_dump_json" 118 | mock_async_predict_result_instance.__dict__.update(mock_async_predict_result_data) 119 | 120 | # Mock valid signature 121 | valid_signature = hmac.digest( 122 | mock_webhook_secret.encode("utf-8"), 123 | b"mock_model_dump_json", 124 | hashlib.sha256, 125 | ).hex() 126 | 127 | assert ( 128 | validate_baseten_signature( 129 | mock_async_predict_result_instance, f"v1={valid_signature}" 130 | ) 131 | is False 132 | ) 133 | -------------------------------------------------------------------------------- /tests/unit/models/test_llamafile_unit.py: -------------------------------------------------------------------------------- 1 | import signal 2 | import subprocess 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | 7 | from flow_judge.models.llamafile import cleanup_llamafile 8 | 9 | """ 10 | Test suite for the cleanup_llamafile function. 11 | 12 | This module contains unit tests that verify the behavior of the cleanup_llamafile 13 | function under various scenarios, including normal termination, forced termination, 14 | error handling, and edge cases. 15 | """ 16 | 17 | 18 | @pytest.fixture 19 | def mock_process(): 20 | """Create a mock process for testing. 21 | 22 | This fixture creates a MagicMock object that simulates a process, 23 | primarily for use in testing the cleanup_llamafile function. 24 | 25 | :return: A MagicMock object representing a process 26 | :rtype: unittest.mock.MagicMock 27 | 28 | The mock process has the following attributes: 29 | - pid: Set to 12345 30 | 31 | Usage: 32 | def test_example(mock_process): 33 | # Use mock_process in your test 34 | 35 | Note: 36 | This fixture is session-scoped by default. If you need a different 37 | scope, you can modify the fixture decorator, e.g., 38 | @pytest.fixture(scope="function"). 39 | """ 40 | process = MagicMock() 41 | process.pid = 12345 42 | return process 43 | 44 | 45 | @patch("os.getpgid") 46 | @patch("os.killpg") 47 | @patch("subprocess.Popen") 48 | def test_normal_termination(mock_popen, mock_killpg, mock_getpgid, mock_process): 49 | """Test the normal termination scenario of cleanup_llamafile. 50 | 51 | This test verifies that when a process terminates normally: 52 | 1. The process group ID is correctly retrieved. 53 | 2. A SIGTERM signal is sent to the process group. 54 | 3. 
The function waits for the process to terminate with a timeout. 55 | 56 | It uses mocking to simulate system calls and process behavior. 57 | """ 58 | mock_getpgid.return_value = 54321 59 | mock_popen.return_value = mock_process 60 | 61 | cleanup_llamafile(lambda: mock_process) 62 | 63 | mock_getpgid.assert_called_once_with(12345) 64 | mock_killpg.assert_called_once_with(54321, signal.SIGTERM) 65 | mock_process.wait.assert_called_once_with(timeout=5) 66 | 67 | 68 | @patch("os.getpgid") 69 | @patch("os.killpg") 70 | @patch("subprocess.Popen") 71 | def test_force_kill(mock_popen, mock_killpg, mock_getpgid, mock_process): 72 | """Test the force kill scenario of cleanup_llamafile. 73 | 74 | This test ensures that when a process doesn't terminate after SIGTERM: 75 | 1. A SIGTERM signal is initially sent to the process group. 76 | 2. After a timeout, a SIGKILL signal is sent to forcefully terminate the process. 77 | 78 | It simulates a process that doesn't respond to SIGTERM, requiring SIGKILL. 79 | """ 80 | mock_getpgid.return_value = 54321 81 | mock_process.wait.side_effect = subprocess.TimeoutExpired(cmd="test", timeout=5) 82 | mock_popen.return_value = mock_process 83 | 84 | cleanup_llamafile(lambda: mock_process) 85 | 86 | mock_killpg.assert_any_call(54321, signal.SIGTERM) 87 | mock_killpg.assert_any_call(54321, signal.SIGKILL) 88 | 89 | 90 | @patch("os.getpgid") 91 | @patch("subprocess.Popen") 92 | def test_os_error_fallback(mock_popen, mock_getpgid, mock_process): 93 | """Test the OS error fallback scenario of cleanup_llamafile. 94 | 95 | This test verifies the function's behavior when it encounters an OSError: 96 | 1. It attempts to get the process group ID, which raises an OSError. 97 | 2. The function falls back to terminating the individual process. 98 | 3. It waits for the process to terminate with a timeout. 99 | 100 | This test ensures the function has a proper fallback mechanism for OS-level errors. 101 | """ 102 | mock_getpgid.side_effect = OSError() 103 | mock_popen.return_value = mock_process 104 | 105 | cleanup_llamafile(lambda: mock_process) 106 | 107 | mock_process.terminate.assert_called_once() 108 | mock_process.wait.assert_called_once_with(timeout=5) 109 | 110 | 111 | def test_already_terminated(mock_process): 112 | """Test the scenario where the process is already terminated. 113 | 114 | This test ensures that the cleanup_llamafile function handles the case 115 | where the process reference is None (indicating an already terminated process) 116 | without raising any errors. 117 | 118 | No assertions are made as the function should complete without any action or error. 
119 | """ 120 | mock_process.__bool__.return_value = False 121 | 122 | cleanup_llamafile(lambda: None) 123 | 124 | # No assertions needed, function should complete without error 125 | -------------------------------------------------------------------------------- /tests/unit/test_flow_judge.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from flow_judge.eval_data_types import EvalInput, EvalOutput 7 | from flow_judge.flow_judge import FlowJudge 8 | from flow_judge.metrics import RESPONSE_CORRECTNESS_BINARY, CustomMetric, RubricItem 9 | from flow_judge.models.common import BaseFlowJudgeModel 10 | from flow_judge.utils.prompt_formatter import USER_PROMPT_TEMPLATE, format_rubric, format_vars 11 | 12 | 13 | class MockFlowJudgeModel(BaseFlowJudgeModel): 14 | """Mock model for testing.""" 15 | 16 | def __init__(self, model_id, model_type, generation_params): 17 | """Initialize the mock model.""" 18 | super().__init__(model_id, model_type, generation_params) 19 | 20 | def _generate(self, prompt): 21 | """Generate a mock response.""" 22 | return "Test feedback\n1" 23 | 24 | def _batch_generate(self, prompts, use_tqdm=True): 25 | """Generate mock responses for a list of prompts.""" 26 | return ["Test feedback\n1" for _ in prompts] 27 | 28 | def generate(self, prompt): 29 | """Generate a mock response.""" 30 | return self._generate(prompt) 31 | 32 | def batch_generate(self, prompts, use_tqdm=True): 33 | """Generate mock responses for a list of prompts.""" 34 | return self._batch_generate(prompts, use_tqdm=use_tqdm) 35 | 36 | 37 | @pytest.fixture 38 | def mock_model(): 39 | """Fixture to create a mock model for testing.""" 40 | return MockFlowJudgeModel("test-model", "mock", {"temperature": 0.7}) 41 | 42 | 43 | def test_flow_judge_initialization(mock_model): 44 | """Test the initialization of FlowJudge.""" 45 | judge = FlowJudge(metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model) 46 | assert isinstance(judge, FlowJudge) 47 | assert judge.metric == RESPONSE_CORRECTNESS_BINARY 48 | assert judge.model == mock_model 49 | 50 | 51 | def test_flow_judge_initialization_invalid_metric(): 52 | """Test FlowJudge initialization with invalid metric.""" 53 | with pytest.raises(ValueError): 54 | FlowJudge(metric="invalid_metric", model=mock_model) 55 | 56 | 57 | def test_flow_judge_evaluate(mock_model): 58 | """Test the evaluate method of FlowJudge.""" 59 | judge = FlowJudge(metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model) 60 | eval_input = EvalInput( 61 | inputs=[{"query": "Test query"}, {"reference_answer": "Test reference"}], 62 | output={"response": "Test response"}, 63 | ) 64 | result = judge.evaluate(eval_input) 65 | assert isinstance(result, EvalOutput) 66 | assert result.feedback == "Test feedback" 67 | assert result.score == 1 68 | 69 | 70 | def test_flow_judge_batch_evaluate(mock_model): 71 | """Test the batch_evaluate method of FlowJudge.""" 72 | judge = FlowJudge(metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model) 73 | eval_inputs = [ 74 | EvalInput( 75 | inputs=[{"query": "Test query 1"}, {"reference_answer": "Test reference 1"}], 76 | output={"response": "Test response 1"}, 77 | ), 78 | EvalInput( 79 | inputs=[{"query": "Test query 2"}, {"reference_answer": "Test reference 2"}], 80 | output={"response": "Test response 2"}, 81 | ), 82 | ] 83 | results = judge.batch_evaluate(eval_inputs, save_results=False) 84 | assert len(results) == 2 85 | for result in results: 86 | 
assert isinstance(result, EvalOutput) 87 | assert result.feedback == "Test feedback" 88 | assert result.score == 1 89 | 90 | 91 | @pytest.mark.parametrize("save_results", [True, False]) 92 | def test_flow_judge_evaluate_save_results(mock_model, tmp_path, save_results): 93 | """Test saving results in the evaluate method.""" 94 | judge = FlowJudge( 95 | metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model, output_dir=str(tmp_path) 96 | ) 97 | eval_input = EvalInput( 98 | inputs=[{"query": "Test query"}, {"reference_answer": "Test reference"}], 99 | output={"response": "Test response"}, 100 | ) 101 | with patch("flow_judge.flow_judge.write_results_to_disk") as mock_write: 102 | judge.evaluate(eval_input, save_results=save_results) 103 | if save_results: 104 | mock_write.assert_called_once() 105 | else: 106 | mock_write.assert_not_called() 107 | 108 | 109 | def test_custom_metric(): 110 | """Test creating and using a custom metric.""" 111 | custom_metric = CustomMetric( 112 | name="custom_metric", 113 | criteria="Custom criteria", 114 | rubric=[RubricItem(score=0, description="Bad"), RubricItem(score=1, description="Good")], 115 | required_inputs=["custom_input"], 116 | required_output="custom_output", 117 | ) 118 | assert custom_metric.name == "custom_metric" 119 | assert custom_metric.criteria == "Custom criteria" 120 | assert len(custom_metric.rubric) == 2 121 | assert custom_metric.required_inputs == ["custom_input"] 122 | assert custom_metric.required_output == "custom_output" 123 | 124 | 125 | def test_eval_input_validation(mock_model): 126 | """Test EvalInput validation.""" 127 | judge = FlowJudge(metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model) 128 | 129 | # Valid input 130 | valid_input = EvalInput( 131 | inputs=[{"query": "Test query"}, {"reference_answer": "Test reference"}], 132 | output={"response": "Test response"}, 133 | ) 134 | assert judge.evaluate(valid_input) 135 | 136 | # Invalid input - missing required input 137 | invalid_input = EvalInput( 138 | inputs=[{"query": "Test query"}], output={"response": "Test response"} 139 | ) 140 | with pytest.raises(ValueError): 141 | judge.evaluate(invalid_input) 142 | 143 | # Invalid input - wrong output key 144 | invalid_output = EvalInput( 145 | inputs=[{"query": "Test query"}, {"reference_answer": "Test reference"}], 146 | output={"wrong_key": "Test response"}, 147 | ) 148 | with pytest.raises(ValueError): 149 | judge.evaluate(invalid_output) 150 | 151 | 152 | def test_format_vars(): 153 | """Test format_vars function.""" 154 | variables = [{"question": "What is 2+2?"}, {"context": "Math basics"}] 155 | formatted = format_vars(variables) 156 | expected = """<question> 157 | What is 2+2?
158 | </question> 159 | <context> 160 | Math basics 161 | </context>""" 162 | assert expected == formatted 163 | 164 | 165 | def test_format_rubric(): 166 | """Test format_rubric function.""" 167 | rubric = [RubricItem(score=1, description="Good"), RubricItem(score=0, description="Poor")] 168 | formatted = format_rubric(rubric) 169 | expected = """- Score 0: Poor 170 | - Score 1: Good""" 171 | assert expected == formatted 172 | 173 | 174 | def test_format_prompt(mock_model): 175 | """Test FlowJudge._format_prompt.""" 176 | eval_input = EvalInput( 177 | inputs=[{"query": "Test query"}, {"reference_answer": "Test reference"}], 178 | output={"response": "Test response"}, 179 | ) 180 | 181 | judge = FlowJudge(metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model) 182 | prompt = judge._format_prompt(eval_input) 183 | 184 | expected_prompt = USER_PROMPT_TEMPLATE.format( 185 | INPUTS=format_vars(eval_input.inputs), 186 | OUTPUT=format_vars([eval_input.output]), 187 | EVALUATION_CRITERIA=RESPONSE_CORRECTNESS_BINARY.criteria, 188 | RUBRIC=format_rubric(RESPONSE_CORRECTNESS_BINARY.rubric), 189 | ) 190 | assert prompt == expected_prompt 191 | 192 | 193 | def test_eval_output_parse_fail_on_parse_error(): 194 | """Test EvalOutput.parse with fail_on_parse_error.""" 195 | # Invalid response without proper tags 196 | invalid_response = "This is an invalid response without proper tags" 197 | 198 | # Test with fail_on_parse_error=False (default behavior) 199 | result = EvalOutput.parse(invalid_response) 200 | assert isinstance(result, EvalOutput) 201 | assert result.feedback == "Error" 202 | assert result.score == -1 203 | 204 | # Test with fail_on_parse_error=True 205 | with pytest.raises(ValueError): 206 | EvalOutput.parse(invalid_response, fail_on_parse_error=True) 207 | 208 | # Test with valid response 209 | valid_response = "<feedback>Good job!</feedback><score>5</score>" 210 | result = EvalOutput.parse(valid_response) 211 | assert isinstance(result, EvalOutput) 212 | assert result.feedback == "Good job!"
213 | assert result.score == 5 214 | 215 | 216 | @pytest.fixture(autouse=True) 217 | def cleanup(request, tmp_path): 218 | """Cleanup files and directories created during the test.""" 219 | yield 220 | shutil.rmtree(tmp_path, ignore_errors=True) 221 | -------------------------------------------------------------------------------- /tests/unit/test_metrics.py: -------------------------------------------------------------------------------- 1 | from flow_judge.metrics import RESPONSE_CORRECTNESS_BINARY, CustomMetric, RubricItem 2 | 3 | 4 | def test_response_correctness_binary(): 5 | """Test the RESPONSE_CORRECTNESS_BINARY metric.""" 6 | metric = RESPONSE_CORRECTNESS_BINARY 7 | assert metric.name == "Response Correctness (Binary)" 8 | assert len(metric.rubric) == 2 9 | assert metric.required_inputs == ["query", "reference_answer"] 10 | assert metric.required_output == "response" 11 | 12 | 13 | def test_custom_metric(): 14 | """Test the CustomMetric class.""" 15 | custom_metric = CustomMetric( 16 | name="Test Metric", 17 | criteria="Test criteria", 18 | rubric=[RubricItem(score=0, description="Bad"), RubricItem(score=1, description="Good")], 19 | required_inputs=["test_input"], 20 | required_output="test_output", 21 | ) 22 | assert custom_metric.name == "Test Metric" 23 | assert custom_metric.criteria == "Test criteria" 24 | assert len(custom_metric.rubric) == 2 25 | assert custom_metric.required_inputs == ["test_input"] 26 | assert custom_metric.required_output == "test_output" 27 | -------------------------------------------------------------------------------- /tests/unit/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from flow_judge.eval_data_types import EvalInput 4 | from flow_judge.metrics import CustomMetric, RubricItem 5 | from flow_judge.utils.prompt_formatter import ( 6 | USER_PROMPT_NO_INPUTS_TEMPLATE, 7 | USER_PROMPT_TEMPLATE, 8 | format_rubric, 9 | format_user_prompt, 10 | format_vars, 11 | ) 12 | from flow_judge.utils.validators import validate_eval_input 13 | 14 | 15 | def test_format_vars(): 16 | """Test the format_vars function.""" 17 | variables = [{"question": "What is 2+2?"}, {"context": "Math basics"}] 18 | formatted = format_vars(variables) 19 | expected = "\nWhat is 2+2?\n\n\nMath basics\n" 20 | assert formatted == expected 21 | 22 | 23 | def test_format_rubric(): 24 | """Test the format_rubric function.""" 25 | rubric = [RubricItem(score=0, description="Bad"), RubricItem(score=1, description="Good")] 26 | formatted = format_rubric(rubric) 27 | expected = "- Score 0: Bad\n- Score 1: Good" 28 | assert formatted == expected 29 | 30 | 31 | def test_format_user_prompt(): 32 | """Test the format_user_prompt function.""" 33 | # Test with inputs 34 | variables_with_inputs = { 35 | "INPUTS": "Test input", 36 | "OUTPUT": "Test output", 37 | "EVALUATION_CRITERIA": "Test criteria", 38 | "RUBRIC": "Test rubric", 39 | } 40 | formatted_with_inputs = format_user_prompt(variables_with_inputs) 41 | assert USER_PROMPT_TEMPLATE.format(**variables_with_inputs) == formatted_with_inputs 42 | 43 | # Test without inputs 44 | variables_without_inputs = { 45 | "INPUTS": "", 46 | "OUTPUT": "Test output", 47 | "EVALUATION_CRITERIA": "Test criteria", 48 | "RUBRIC": "Test rubric", 49 | } 50 | formatted_without_inputs = format_user_prompt(variables_without_inputs) 51 | assert ( 52 | USER_PROMPT_NO_INPUTS_TEMPLATE.format(**variables_without_inputs) 53 | == formatted_without_inputs 54 | ) 55 | 56 | 57 | def 
test_validate_eval_input(): 58 | """Test the validate_eval_input function.""" 59 | metric = CustomMetric( 60 | name="Test Metric", 61 | criteria="Test criteria", 62 | rubric=[RubricItem(score=0, description="Bad"), RubricItem(score=1, description="Good")], 63 | required_inputs=["test_input"], 64 | required_output="test_output", 65 | ) 66 | valid_input = EvalInput( 67 | inputs=[{"test_input": "Test value"}], output={"test_output": "Test output"} 68 | ) 69 | validate_eval_input(valid_input, metric) # Should not raise an exception 70 | 71 | invalid_input = EvalInput( 72 | inputs=[{"wrong_input": "Test value"}], output={"test_output": "Test output"} 73 | ) 74 | with pytest.raises(ValueError): 75 | validate_eval_input(invalid_input, metric) 76 | --------------------------------------------------------------------------------
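The score-distribution helpers defined in tests/e2e-cloud-gpu/models/adapters/test_baseten_e2e.py above are self-contained and easy to exercise outside the suite. The sketch below mirrors their bodies and runs them on a made-up list of judge scores; the scores and the expected distribution are illustrative values, not data from the repository.

from collections import Counter


def get_scores_distribution(scores: list[float]) -> dict[float, str]:
    """Same logic as the e2e helper: map each score to its percentage share."""
    score_counts = Counter(scores)
    total_scores = len(scores)
    return {score: f"{(count / total_scores) * 100:.1f}%" for score, count in score_counts.items()}


def compare_distributions(actual: dict[float, str], expected: dict[float, str], tolerance: float = 20.0) -> bool:
    """Same logic as the e2e helper: pass if every score's share differs by at most `tolerance` points."""
    for score in set(actual.keys()) | set(expected.keys()):
        actual_pct = float(actual.get(score, "0%").rstrip("%"))
        expected_pct = float(expected.get(score, "0%").rstrip("%"))
        if abs(actual_pct - expected_pct) > tolerance:
            return False
    return True


# Ten made-up judge scores and an equally made-up expectation.
scores = [3, 4, 4, 5, 3, 4, 5, 3, 4, 2]
actual = get_scores_distribution(scores)   # {3: '30.0%', 4: '40.0%', 5: '20.0%', 2: '10.0%'}
expected = {3: "30.0%", 4: "30.0%", 5: "30.0%", 2: "10.0%"}
assert compare_distributions(actual, expected, tolerance=20.0)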
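The validation unit tests in tests/unit/models/adapters/validation.py exercise the signature scheme the Baseten webhook adapter appears to rely on: the payload JSON is HMAC-SHA256-signed with BASETEN_WEBHOOK_SECRET, sent as a "v1=<hex>" header, and rejected when the payload timestamp is older than TIMESTAMP_TOLERANCE_SECONDS. The standalone sketch below only illustrates that scheme under those assumptions; the library's real entry point is validate_baseten_signature, and verify_webhook, payload, and secret here are hypothetical names and values.

import hashlib
import hmac
from datetime import datetime, timedelta, timezone

from flow_judge.models.adapters.baseten.validation import TIMESTAMP_TOLERANCE_SECONDS


def verify_webhook(payload_json: str, sent_at: datetime, signature_header: str, secret: str) -> bool:
    """Recompute the v1=<hex> HMAC-SHA256 signature and reject stale or mismatched payloads."""
    # Stale payloads fail regardless of the signature (mirrors test_stale_timestamp).
    if datetime.now(timezone.utc) - sent_at > timedelta(seconds=TIMESTAMP_TOLERANCE_SECONDS):
        return False
    expected = hmac.digest(secret.encode("utf-8"), payload_json.encode("utf-8"), hashlib.sha256).hex()
    provided = signature_header.removeprefix("v1=")
    # Constant-time comparison avoids leaking information through timing.
    return hmac.compare_digest(expected, provided)


# Hypothetical round trip: sign a payload, then verify it.
secret = "mock_webhook_secret"
payload = '{"request_id": "mock_request_id"}'
signature = "v1=" + hmac.digest(secret.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256).hex()
assert verify_webhook(payload, datetime.now(timezone.utc), signature, secret)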
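tests/unit/models/test_llamafile_unit.py pins down the contract of cleanup_llamafile without showing its body: send SIGTERM to the child's whole process group, escalate to SIGKILL if it has not exited after a five-second wait, and fall back to terminating the single process when the group lookup fails. A minimal sketch consistent with those four unit tests might look like the following; cleanup_process is a hypothetical stand-in, and the real implementation in flow_judge/models/llamafile.py may handle additional edge cases.

import os
import signal
import subprocess


def cleanup_process(get_process) -> None:
    """Terminate a child's whole process group: SIGTERM first, SIGKILL if it does not exit in time."""
    process = get_process()
    if not process:
        return  # Nothing to do: the process is already gone (test_already_terminated).
    try:
        pgid = os.getpgid(process.pid)
        os.killpg(pgid, signal.SIGTERM)        # Polite request (test_normal_termination).
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            os.killpg(pgid, signal.SIGKILL)    # Escalate if SIGTERM is ignored (test_force_kill).
    except OSError:
        # Group lookup failed: fall back to the single process (test_os_error_fallback).
        process.terminate()
        process.wait(timeout=5)

Signalling the process group rather than the single PID also reaps any workers the llamafile server may have forked, which matches the tests patching os.killpg rather than a plain Popen.kill.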
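Finally, the unit tests in tests/unit/test_flow_judge.py and tests/unit/test_utils.py also document, in passing, the public API for defining a metric and an evaluation input. The snippet below assembles both objects using the same constructor signatures those tests call; the criterion text and sample values are illustrative only, and the resulting objects are what FlowJudge(metric=..., model=...).evaluate(eval_input) consumes.

from flow_judge.eval_data_types import EvalInput
from flow_judge.metrics import CustomMetric, RubricItem

# A binary correctness metric, mirroring the shape used throughout the unit tests.
metric = CustomMetric(
    name="reference_correctness",
    criteria="Does the response agree with the reference answer?",
    rubric=[
        RubricItem(score=0, description="The response contradicts or ignores the reference answer."),
        RubricItem(score=1, description="The response is consistent with the reference answer."),
    ],
    required_inputs=["query", "reference_answer"],
    required_output="response",
)

# One evaluation case: every required input appears in `inputs`, the required output in `output`.
eval_input = EvalInput(
    inputs=[{"query": "What is 2+2?"}, {"reference_answer": "4"}],
    output={"response": "2+2 equals 4."},
)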