├── .coverage
├── .github
│   └── workflows
│       ├── build-tag-and-publish.yml
│       ├── e2e-cloud-gpu.yml
│       ├── e2e-local.yml
│       └── test-and-lint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── examples
│   ├── 1_quickstart.ipynb
│   ├── 2_custom_evaluation_criteria.ipynb
│   ├── 3_evaluation_strategies.ipynb
│   ├── 4_llama_index_evaluators.ipynb
│   ├── 5_evaluate_haystack_rag_pipeline.ipynb
│   ├── 6_langchain_evaluators.ipynb
│   ├── 7_baseten_async_quickstart.ipynb
│   └── sample_data
│       └── csr_assistant.json
├── flow_judge
│   ├── __init__.py
│   ├── eval_data_types.py
│   ├── flow_judge.py
│   ├── integrations
│   │   ├── __init__.py
│   │   ├── haystack.py
│   │   ├── langchain.py
│   │   └── llama_index.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── metric.py
│   │   └── presets.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── adapters
│   │   │   ├── base.py
│   │   │   └── baseten
│   │   │       ├── README.md
│   │   │       ├── adapter.py
│   │   │       ├── api_auth.py
│   │   │       ├── data_io.py
│   │   │       ├── deploy.py
│   │   │       ├── deployment
│   │   │       │   ├── config.yaml
│   │   │       │   └── model
│   │   │       │       ├── __init__.py
│   │   │       │       ├── helper.py
│   │   │       │       └── model.py
│   │   │       ├── errors.py
│   │   │       ├── gpu.py
│   │   │       ├── management.py
│   │   │       ├── token_bucket.py
│   │   │       ├── util.py
│   │   │       ├── validation.py
│   │   │       └── webhook.py
│   │   ├── baseten.py
│   │   ├── common.py
│   │   ├── huggingface.py
│   │   ├── llamafile.py
│   │   └── vllm.py
│   └── utils
│       ├── __init__.py
│       ├── prompt_formatter.py
│       ├── result_writer.py
│       └── validators.py
├── img
│   └── flow_judge_banner.png
├── pyproject.toml
└── tests
    ├── README.md
    ├── e2e-cloud-gpu
    │   └── models
    │       └── adapters
    │           └── test_baseten_e2e.py
    ├── e2e-local
    │   ├── integrations
    │   │   └── test_llama_index_e2e.py
    │   └── models
    │       └── test_llamafile_e2e.py
    └── unit
        ├── models
        │   ├── adapters
        │   │   ├── baseten.py
        │   │   ├── gpu.py
        │   │   └── validation.py
        │   ├── test_baseten.py
        │   └── test_llamafile_unit.py
        ├── test_flow_judge.py
        ├── test_metrics.py
        ├── test_utils.py
        └── utils
            └── test_result_writer.py
/.coverage:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flowaicom/flow-judge/e56d199db4d79f184dac1e9ab2da83992acda14d/.coverage
--------------------------------------------------------------------------------
/.github/workflows/build-tag-and-publish.yml:
--------------------------------------------------------------------------------
1 | name: Build, tag and publish to PyPI and TestPyPI
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | tags:
8 | - 'v[0-9]+.[0-9]+.[0-9]+'
9 | paths-ignore:
10 | - '**.md'
11 | - 'docs/**'
12 | pull_request:
13 | branches: [main]
14 | paths-ignore:
15 | - '**.md'
16 | - 'docs/**'
17 | workflow_dispatch:
18 |
19 | env:
20 | PYTHON_VERSION: '3.11'
21 |
22 | jobs:
23 | check:
24 | name: Check workflow (and have one non-dependent job)
25 | runs-on: self-hosted
26 | steps:
27 | - name: Checkout code
28 | uses: actions/checkout@v4
29 | - name: Check for necessary files
30 | run: |
31 | if [ ! -f "pyproject.toml" ]; then
32 | echo "pyproject.toml is missing"
33 | exit 1
34 | fi
35 | echo "Build file is present"
36 |
37 | build:
38 | name: Build distribution 📦
39 | runs-on: self-hosted
40 | needs: [check]
41 | if: ${{ always() && needs.check.result == 'success' }}
42 |
43 | steps:
44 | - uses: actions/checkout@v4
45 | - name: Set up Python ${{ env.PYTHON_VERSION }}
46 | uses: actions/setup-python@v5
47 | with:
48 | python-version: ${{ env.PYTHON_VERSION }}
49 | - name: Install pypa/build
50 | run: python3 -m pip install build --user
51 | - name: Build a binary wheel and a source tarball
52 | run: python3 -m build
53 | - name: Store the distribution packages
54 | uses: actions/upload-artifact@v4
55 | with:
56 | name: python-package-distributions
57 | path: dist/
58 |
59 | publish-to-testpypi:
60 | name: Publish Python 🐍 distribution 📦 to TestPyPI
61 | needs: [build]
62 | runs-on: self-hosted
63 | if: github.event_name == 'push' && github.ref == 'refs/heads/main'
64 | environment:
65 | name: testpypi
66 | url: https://test.pypi.org/p/flow-judge
67 |
68 | permissions:
69 | id-token: write
70 |
71 | steps:
72 | - name: Download all the dists
73 | uses: actions/download-artifact@v4
74 | with:
75 | name: python-package-distributions
76 | path: dist/
77 | - name: Publish distribution 📦 to TestPyPI
78 | uses: pypa/gh-action-pypi-publish@release/v1
79 | with:
80 | repository-url: https://test.pypi.org/legacy/
81 | user: __token__
82 | password: ${{ secrets.TESTPYPI_API_TOKEN }}
83 |
84 | publish-to-pypi:
85 | name: Publish Python 🐍 distribution 📦 to PyPI
86 | needs: [build, publish-to-testpypi]
87 | runs-on: self-hosted
88 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
89 | environment:
90 | name: pypi
91 | url: https://pypi.org/p/flow-judge
92 | permissions:
93 | id-token: write
94 |
95 | steps:
96 | - name: Download all the dists
97 | uses: actions/download-artifact@v4
98 | with:
99 | name: python-package-distributions
100 | path: dist/
101 | - name: Verify tag format
102 | run: |
103 | if [[ ! "${{ github.ref }}" =~ ^refs/tags/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
104 | echo "Invalid tag format. Expected format: v*.*.*"
105 | exit 1
106 | fi
107 | - name: Publish distribution 📦 to PyPI
108 | uses: pypa/gh-action-pypi-publish@release/v1
109 | with:
110 | user: __token__
111 | password: ${{ secrets.PYPI_API_TOKEN }}
112 |
113 | github-release:
114 | name: Create GitHub Release
115 | needs: [publish-to-pypi]
116 | runs-on: self-hosted
117 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
118 |
119 | permissions:
120 | contents: write
121 | id-token: write
122 |
123 | steps:
124 | - name: Download all the dists
125 | uses: actions/download-artifact@v4
126 | with:
127 | name: python-package-distributions
128 | path: dist/
129 | - name: Sign the dists with Sigstore
130 | uses: sigstore/gh-action-sigstore-python@v2.1.1
131 | with:
132 | inputs: ./dist/*.tar.gz ./dist/*.whl
133 | - name: Create GitHub Release
134 | env:
135 | GITHUB_TOKEN: ${{ github.token }}
136 | run: |
137 | gh release create "${{ github.ref_name }}" \
138 | --repo "${{ github.repository }}" \
139 | --title "Release ${{ github.ref_name }}" \
140 | --notes "Release notes for version ${{ github.ref_name }}"
141 | - name: Upload artifact signatures to GitHub Release
142 | env:
143 | GITHUB_TOKEN: ${{ github.token }}
144 | run: gh release upload "${{ github.ref_name }}" dist/** --repo "${{ github.repository }}"
145 |
--------------------------------------------------------------------------------
/.github/workflows/e2e-cloud-gpu.yml:
--------------------------------------------------------------------------------
1 | name: E2E Test (cloud engines, GPU enabled)
2 |
3 | on:
4 | schedule:
5 | - cron: '0 0 * * 0' # Runs at 00:00 UTC every Sunday
6 | pull_request:
7 | types: [ready_for_review]
8 | branches: [ main ]
9 | workflow_dispatch:
10 |
11 | jobs:
12 |
13 | lint:
14 | runs-on: self-hosted
15 | steps:
16 | - uses: actions/checkout@v4
17 | - name: Set up Python 3.11
18 | uses: actions/setup-python@v5
19 | with:
20 | python-version: '3.11'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install ruff==0.1.3 black==23.9.1 isort==5.12.0 pyupgrade==3.10.1
25 | - name: Lint with ruff
26 | run: |
27 | ruff check . --config pyproject.toml
28 | - name: Format with black
29 | run: |
30 | black --check . --config pyproject.toml
31 | - name: Sort imports with isort
32 | run: |
33 | isort --check-only . --settings-path pyproject.toml
34 | - name: Check for Python upgrades
35 | run: |
36 | files=$(git ls-files '*.py')
37 | if pyupgrade --py310-plus $files | grep -q '^---'; then
38 | echo "pyupgrade would make changes. Please run pyupgrade locally and commit the changes."
39 | exit 1
40 | fi
41 |
42 | test:
43 | needs: lint
44 | runs-on: self-hosted
45 | strategy:
46 | matrix:
47 | python-version: ['3.11']
48 |
49 | steps:
50 | - uses: actions/checkout@v4
51 | - name: Set up Python ${{ matrix.python-version }}
52 | uses: actions/setup-python@v5
53 | with:
54 | python-version: ${{ matrix.python-version }}
55 | - name: Install dependencies
56 | run: |
57 | python -m pip install --upgrade pip
58 | pip install .[dev,vllm,hf,llamafile,integrations-test,baseten]
59 | - name: Verify GPU availability
60 | run: |
61 | nvidia-smi
62 | python -c "import torch; print(torch.cuda.is_available())"
63 | - name: Test with pytest and generate coverage
64 | run: |
65 | export HF_HOME=/tmp/hf_home
66 | export TRANSFORMERS_CACHE=/tmp/hf_home
67 | export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
68 | export BASETEN_API_KEY=${{ secrets.BASETEN_API_KEY }}
69 | export BASETEN_MODEL_ID=${{ secrets.BASETEN_MODEL_ID }}
70 | export BASETEN_WEBHOOK_SECRET=${{ secrets.BASETEN_WEBHOOK_SECRET }}
71 | export BASETEN_WEBHOOK_URL=${{ secrets.BASETEN_WEBHOOK_URL }}
72 | pytest ./tests/e2e-cloud-gpu --cov=./ --junitxml=junit.xml
73 | - name: Upload coverage to Codecov
74 | uses: codecov/codecov-action@v4
75 | with:
76 | token: ${{ secrets.CODECOV_TOKEN }}
77 | fail_ci_if_error: true
78 | - name: Upload test results to Codecov
79 | if: ${{ !cancelled() }}
80 | uses: codecov/test-results-action@v1
81 | with:
82 | token: ${{ secrets.CODECOV_TOKEN }}
83 |
--------------------------------------------------------------------------------
/.github/workflows/e2e-local.yml:
--------------------------------------------------------------------------------
1 | name: E2E Test (local engines, GPU enabled)
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 |
9 | jobs:
10 |
11 | lint:
12 | runs-on: self-hosted
13 | steps:
14 | - uses: actions/checkout@v4
15 | - name: Set up Python 3.11
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: '3.11'
19 | - name: Install dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install ruff==0.1.3 black==23.9.1 isort==5.12.0 pyupgrade==3.10.1
23 | - name: Lint with ruff
24 | run: |
25 | ruff check . --config pyproject.toml
26 | - name: Format with black
27 | run: |
28 | black --check . --config pyproject.toml
29 | - name: Sort imports with isort
30 | run: |
31 | isort --check-only . --settings-path pyproject.toml
32 | - name: Check for Python upgrades
33 | run: |
34 | files=$(git ls-files '*.py')
35 | if pyupgrade --py310-plus $files | grep -q '^---'; then
36 | echo "pyupgrade would make changes. Please run pyupgrade locally and commit the changes."
37 | exit 1
38 | fi
39 |
40 | test:
41 | needs: lint
42 | runs-on: self-hosted
43 | strategy:
44 | matrix:
45 | python-version: ['3.10', '3.11', '3.12']
46 |
47 | steps:
48 | - uses: actions/checkout@v4
49 | - name: Set up Python ${{ matrix.python-version }}
50 | uses: actions/setup-python@v5
51 | with:
52 | python-version: ${{ matrix.python-version }}
53 | - name: Install dependencies
54 | run: |
55 | python -m pip install --upgrade pip
56 | pip install .[dev,vllm,hf,llamafile,integrations-test,baseten]
57 | - name: Verify GPU availability
58 | run: |
59 | nvidia-smi
60 | python -c "import torch; print(torch.cuda.is_available())"
61 | - name: Test with pytest and generate coverage
62 | run: |
63 | export HF_HOME=/tmp/hf_home
64 | export TRANSFORMERS_CACHE=/tmp/hf_home
65 | export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
66 | pytest ./tests/e2e-local --cov=./ --junitxml=junit.xml
67 | - name: Upload coverage to Codecov
68 | uses: codecov/codecov-action@v4
69 | with:
70 | token: ${{ secrets.CODECOV_TOKEN }}
71 | fail_ci_if_error: true
72 | - name: Upload test results to Codecov
73 | if: ${{ !cancelled() }}
74 | uses: codecov/test-results-action@v1
75 | with:
76 | token: ${{ secrets.CODECOV_TOKEN }}
77 |
--------------------------------------------------------------------------------
/.github/workflows/test-and-lint.yml:
--------------------------------------------------------------------------------
1 | name: Test and Lint
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 |
9 | jobs:
10 |
11 | lint:
12 | runs-on: self-hosted
13 | steps:
14 | - uses: actions/checkout@v4
15 | with:
16 | fetch-depth: 0 # Fetch all history for trufflehog
17 | - name: Set up Python 3.11
18 | uses: actions/setup-python@v5
19 | with:
20 | python-version: '3.11'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install ruff==0.1.3 black==23.9.1 isort==5.12.0 pyupgrade==3.10.1
25 | - name: Lint with ruff
26 | run: |
27 | ruff check . --config pyproject.toml
28 | - name: Format with black
29 | run: |
30 | black --check . --config pyproject.toml
31 | - name: Sort imports with isort
32 | run: |
33 | isort --check-only . --settings-path pyproject.toml
34 | - name: Check for Python upgrades
35 | run: |
36 | files=$(git ls-files '*.py')
37 | if pyupgrade --py310-plus $files | grep -q '^---'; then
38 | echo "pyupgrade would make changes. Please run pyupgrade locally and commit the changes."
39 | exit 1
40 | fi
41 | - name: Run TruffleHog
42 | uses: trufflesecurity/trufflehog@v3.82.6
43 | with:
44 | path: ./
45 | base: ${{ github.event.pull_request.base.sha }}
46 | head: ${{ github.event.pull_request.head.sha }}
47 | extra_args: --only-verified --exclude-globs=".venv/*"
48 |
49 | test:
50 | needs: lint
51 | runs-on: self-hosted
52 | strategy:
53 | matrix:
54 | python-version: ['3.10', '3.11', '3.12']
55 |
56 | steps:
57 | - uses: actions/checkout@v4
58 | - name: Set up Python ${{ matrix.python-version }}
59 | uses: actions/setup-python@v5
60 | with:
61 | python-version: ${{ matrix.python-version }}
62 | - name: Install dependencies
63 | run: |
64 | python -m pip install --upgrade pip
65 | pip install .[dev,vllm,hf,llamafile,integrations-test,baseten]
66 | - name: Verify GPU availability
67 | run: |
68 | nvidia-smi
69 | python -c "import torch; print(torch.cuda.is_available())"
70 | - name: Test with pytest and generate coverage
71 | run: |
72 | export HF_HOME=/tmp/hf_home
73 | export TRANSFORMERS_CACHE=/tmp/hf_home
74 | export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
75 | pytest ./tests/unit --cov=./ --junitxml=junit.xml
76 | - name: Upload coverage to Codecov
77 | uses: codecov/codecov-action@v4
78 | with:
79 | token: ${{ secrets.CODECOV_TOKEN }}
80 | fail_ci_if_error: true
81 | - name: Upload test results to Codecov
82 | if: ${{ !cancelled() }}
83 | uses: codecov/test-results-action@v1
84 | with:
85 | token: ${{ secrets.CODECOV_TOKEN }}
86 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python cache files
2 | __pycache__/
3 | *.py[cod]
4 |
5 | # Virtual environments
6 | venv/
7 | env/
8 | .venv/
9 |
10 | # uv
11 | uv.lock
12 |
13 | # IDE settings
14 | .vscode/
15 | .idea/
16 |
17 | # Distribution / packaging
18 | dist/
19 | build/
20 | *.egg-info/
21 |
22 | # Logs
23 | *.log
24 |
25 | # Local configuration
26 | .env
27 |
28 | # Jupyter Notebook
29 | .ipynb_checkpoints
30 |
31 | # ruff cache
32 | .ruff_cache/
33 |
34 | # mypy cache
35 | .mypy_cache/
36 |
37 | # pytype cache
38 | .pytype/
39 |
40 | # pyc
41 | *.pyc
42 |
43 | # output
44 | output/
45 |
46 | # data
47 | data/
48 |
49 | .cache
50 |
51 | flake.nix
52 | flake.lock
53 | .direnv
54 |
55 | .hypothesis
56 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.4.0
4 | hooks:
5 | - id: trailing-whitespace
6 | - id: end-of-file-fixer
7 | - id: check-yaml
8 | - id: check-ast
9 | - id: check-json
10 | - id: check-merge-conflict
11 | - id: detect-private-key
12 |
13 | - repo: https://github.com/astral-sh/ruff-pre-commit
14 | rev: v0.1.3
15 | hooks:
16 | - id: ruff
17 | args: [--fix, --exit-non-zero-on-fix, --config, pyproject.toml]
18 |
19 | - repo: https://github.com/psf/black
20 | rev: 23.9.1
21 | hooks:
22 | - id: black
23 | args: [--check, --config, pyproject.toml]
24 | language_version: python3.11
25 |
26 | - repo: https://github.com/PyCQA/isort
27 | rev: 5.12.0
28 | hooks:
29 | - id: isort
30 | args: [--check-only, --settings-path, pyproject.toml]
31 |
32 | - repo: https://github.com/asottile/pyupgrade
33 | rev: v3.10.1
34 | hooks:
35 | - id: pyupgrade
36 | args: [--py310-plus]
37 |
38 | - repo: https://github.com/trufflesecurity/trufflehog
39 | rev: v3.82.6
40 | hooks:
41 | - id: trufflehog
42 | name: TruffleHog
43 | description: Detect secrets in your data with TruffleHog.
44 | entry: trufflehog git file://. --only-verified --no-update --fail --exclude-globs=".venv/*"
45 | language: system
46 |
--------------------------------------------------------------------------------
/examples/7_baseten_async_quickstart.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Quickstart\n",
8 | "\n",
9 |     "This tutorial demonstrates how to use the `Baseten` model class in async mode to perform language model-based evaluations with the Flow-Judge-v0.1 model deployed on Baseten. For detailed instructions on how to use Baseten, visit the [Baseten readme](https://github.com/flowaicom/flow-judge/blob/main/flow_judge/models/adapters/baseten/README.md)."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Setup\n",
17 | "\n",
18 | "Let's instantiate the `Baseten` model class in async mode. The async implementation makes use of Baseten's async inference approach. See [here](https://docs.baseten.co/invoke/async).\n",
19 | "\n",
20 |     "You can imagine this as *fire-and-forget* functionality. Completion requests are made to the deployed model; once the data is processed and inference is complete, the output is sent to a predefined webhook. The webhook URL is part of the original request. The `Flow-Judge` library then connects to the webhook and *listens* for a response. The library uses this approach to support configurable concurrent execution.\n",
21 | "\n",
22 |     "Optionally, Flow AI hosts a webhook proxy that accepts this request signature and forwards it to the client. It is available at \"https://proxy.flow-ai.dev\".\n",
23 | "\n",
24 | "### Pre-requisite\n",
25 | "\n",
26 | "1. Sign-up to [Baseten](https://www.baseten.co/)\n",
27 | "2. Generate a Baseten API Key from [here](https://app.baseten.co/settings/api_keys)\n",
28 | "3. Generate a Webhook secret from [here](https://app.baseten.co/settings/secrets)"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "### Additional Requirements\n",
36 | "\n",
37 | "Set your `Baseten API key`, `Webhook secret` and `GPU` option in the environment."
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 1,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "import os\n",
47 | "\n",
48 | "os.environ[\"BASETEN_WEBHOOK_SECRET\"] = \"your_baseten_webhook_secret\"\n",
49 | "os.environ[\"BASETEN_API_KEY\"] = \"your_baseten_api_key\"\n",
50 | "\n",
51 | "# You can optionally switch the GPU to H100.\n",
52 | "# This will deploy the FlowJudge model on H100 40GB\n",
53 | "# A10G deployment is Flow-Judge-v0.1-AWQ\n",
54 | "# H100 deployment is Flow-Judge-v0.1-FP8\n",
55 | "# !! Manually changing the hardware on Baseten's UI may cause compatibility issues !!\n",
56 | "os.environ[\"BASETEN_GPU\"] = \"A10G\""
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {},
62 | "source": [
63 | "### Instantiate the Baseten model\n",
64 | "\n",
65 | "Set the following required options for async execution mode of the Baseten model class: \n",
66 | "1. `exec_async=True`\n",
67 | "2. `webhook_proxy_url=https://proxy.flow-ai.dev` (or [run the proxy locally](https://github.com/flowaicom/flow-judge/blob/main/flow_judge/models/adapters/baseten/README.md))\n",
68 | "\n",
69 |     "Optionally, you can set the `async_batch_size` option to a value > 0 (it defaults to `128`). This is the number of concurrent requests sent to the deployed model. It should reflect the concurrency goals you want to achieve and can be configured in Baseten's UI. For more information, see [here](https://docs.baseten.co/performance/concurrency). Our default deployment configuration on Baseten uses a concurrency target of `128` and a maximum of `1` replica, which means a single replica can accept `128` concurrent requests. The batch size you set for the Baseten model class should equal `concurrency_target * number_of_replicas`."
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "from flow_judge import Baseten, AsyncFlowJudge\n",
79 | "from flow_judge.metrics import RESPONSE_FAITHFULNESS_5POINT\n",
80 | "\n",
81 | "# Async model execution\n",
82 | "model = Baseten(\n",
83 | " webhook_proxy_url=\"https://proxy.flow-ai.dev\",\n",
84 | " exec_async=True,\n",
85 | ")\n",
86 | "\n",
87 | "# Instantiate the Async Judge with the model and a metric\n",
88 | "# The library includes multiple default metrics and you can implement your own.\n",
89 | "faithfulness_judge = AsyncFlowJudge(\n",
90 | " metric=RESPONSE_FAITHFULNESS_5POINT,\n",
91 | " model=model,\n",
92 | ")"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {},
98 | "source": [
99 | "## Running Evaluations\n",
100 | "\n",
101 |     "Let's run batched evaluations on our example CSR data using the 5-point Likert faithfulness metric.\n",
102 | "\n",
103 |     "We use the `async_batch_evaluate` method of the AsyncFlowJudge class. Under the hood it processes requests in batches, using the batch size set with the `async_batch_size` argument of the Baseten model class. If some requests fail, for example due to network issues, the rest of the batch still completes and the errors are reported in the logs; the output includes only the successful responses."
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "# Read the sample data\n",
113 | "import json\n",
114 | "from flow_judge import EvalInput\n",
115 | "with open(\"sample_data/csr_assistant.json\", \"r\") as f:\n",
116 | " data = json.load(f)\n",
117 | "\n",
118 | "# Create a list of inputs and outputs\n",
119 | "inputs_batch = [\n",
120 | " [\n",
121 | " {\"query\": sample[\"query\"]},\n",
122 | " {\"context\": sample[\"context\"]},\n",
123 | " ]\n",
124 | " for sample in data\n",
125 | "]\n",
126 | "outputs_batch = [{\"response\": sample[\"response\"]} for sample in data]\n",
127 | "\n",
128 | "# Create a list of EvalInput\n",
129 | "eval_inputs_batch = [EvalInput(inputs=inputs, output=output) for inputs, output in zip(inputs_batch, outputs_batch)]\n",
130 | "\n",
131 | "# Run the batch evaluation\n",
132 | "results = await faithfulness_judge.async_batch_evaluate(eval_inputs_batch, save_results=False)"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "from IPython.display import Markdown, display\n",
142 | "\n",
143 | "# Visualizing the results\n",
144 | "for i, result in enumerate(results):\n",
145 | " display(Markdown(f\"__Sample {i+1}:__\"))\n",
146 | " display(Markdown(f\"__Feedback:__\\n{result.feedback}\\n\\n__Score:__\\n{result.score}\"))\n",
147 | " display(Markdown(\"---\"))"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "metadata": {},
153 | "source": [
154 |     "Similarly, you can run a single evaluation task using the `async_evaluate` method of the `AsyncFlowJudge` class. Under the hood, this processes a single async request and attaches a listener to the webhook for the response."
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "result = await faithfulness_judge.async_evaluate(eval_inputs_batch[0], save_results=False)"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "# Display the result\n",
173 | "display(Markdown(f\"__Feedback:__\\n{result.feedback}\\n\\n__Score:__\\n{result.score}\"))"
174 | ]
175 | }
176 | ],
177 | "metadata": {
178 | "kernelspec": {
179 | "display_name": "flow-judge-test",
180 | "language": "python",
181 | "name": "python3"
182 | },
183 | "language_info": {
184 | "codemirror_mode": {
185 | "name": "ipython",
186 | "version": 3
187 | },
188 | "file_extension": ".py",
189 | "mimetype": "text/x-python",
190 | "name": "python",
191 | "nbconvert_exporter": "python",
192 | "pygments_lexer": "ipython3",
193 | "version": "3.11.9"
194 | }
195 | },
196 | "nbformat": 4,
197 | "nbformat_minor": 2
198 | }
199 |
--------------------------------------------------------------------------------
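
The notebook above covers the asynchronous path. For comparison, here is a minimal synchronous sketch using the same sample data and metric; it assumes the `baseten` extra is installed and that `Baseten()` can be constructed with its defaults (the first run prompts for the API key and GPU, as described in the Baseten README further below). The sample-data path is relative to the repository root.

```python
import json

from flow_judge import Baseten, EvalInput, FlowJudge
from flow_judge.metrics import RESPONSE_FAITHFULNESS_5POINT

# Sync Baseten model and judge; mirrors the async notebook above.
model = Baseten()
judge = FlowJudge(metric=RESPONSE_FAITHFULNESS_5POINT, model=model)

with open("examples/sample_data/csr_assistant.json") as f:
    data = json.load(f)

eval_inputs = [
    EvalInput(
        inputs=[{"query": sample["query"]}, {"context": sample["context"]}],
        output={"response": sample["response"]},
    )
    for sample in data
]

# batch_evaluate formats the prompts, calls the model, and parses the responses.
results = judge.batch_evaluate(eval_inputs, save_results=False)
for result in results:
    print(result.score)
```
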
/flow_judge/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import PackageNotFoundError, version
2 |
3 | from flow_judge.eval_data_types import EvalInput, EvalOutput
4 | from flow_judge.flow_judge import AsyncFlowJudge, FlowJudge
5 | from flow_judge.metrics import CustomMetric, Metric, RubricItem, list_all_metrics
6 | from flow_judge.models.common import BaseFlowJudgeModel
7 | from flow_judge.utils.prompt_formatter import format_rubric, format_user_prompt, format_vars
8 |
9 | try:
10 | __version__ = version("flow-judge")
11 | except PackageNotFoundError:
12 | # package is not installed
13 | __version__ = "unknown"
14 |
15 | __all__ = [
16 | "FlowJudge",
17 | "AsyncFlowJudge",
18 | "EvalInput",
19 | "format_vars",
20 | "format_rubric",
21 | "format_user_prompt",
22 | "RubricItem",
23 | "Metric",
24 | "CustomMetric",
25 | "BaseFlowJudgeModel",
26 | "EvalOutput",
27 | ]
28 |
29 | # Conditional imports for optional dependencies
30 | try:
31 | from flow_judge.models.huggingface import Hf
32 |
33 | __all__.append("Hf")
34 | except ImportError:
35 | Hf = None
36 |
37 | try:
38 | from flow_judge.models.vllm import Vllm
39 |
40 | __all__.append("Vllm")
41 | except ImportError:
42 | Vllm = None
43 |
44 | try:
45 | from flow_judge.models.llamafile import Llamafile
46 |
47 | __all__.append("Llamafile")
48 | except ImportError:
49 | Llamafile = None
50 |
51 | try:
52 | from flow_judge.models.baseten import Baseten
53 |
54 | __all__.append("Baseten")
55 | except ImportError:
56 | Baseten = None
57 |
58 |
59 | def get_available_models():
60 | """Return a list of available model classes based on installed extras."""
61 | models = [BaseFlowJudgeModel]
62 | if Hf is not None:
63 | models.append(Hf)
64 | if Vllm is not None:
65 | models.append(Vllm)
66 | if Llamafile is not None:
67 | models.append(Llamafile)
68 | if Baseten is not None:
69 | models.append(Baseten)
70 | return models
71 |
72 |
73 | __all__.append("get_available_models")
74 |
75 | # Add all metric names to __all__
76 | __all__ += list_all_metrics()
77 |
--------------------------------------------------------------------------------
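
Since the model classes above are exported conditionally, a quick way to see which optional extras are active in a given environment is a sketch like the following (only the base install is assumed):

```python
from flow_judge import get_available_models, list_all_metrics

# Model classes show up here only when their optional extras are installed,
# e.g. `pip install flow-judge[vllm]` adds Vllm.
print([cls.__name__ for cls in get_available_models()])

# Names of the preset metrics re-exported at package level.
print(list_all_metrics())
```
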
/flow_judge/eval_data_types.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 |
4 | from pydantic import BaseModel, Field
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class EvalInput(BaseModel):
10 | """Input for evaluation."""
11 |
12 | inputs: list[dict[str, str]] = Field(default_factory=list)
13 | output: dict[str, str]
14 |
15 |
16 | class EvalOutput(BaseModel):
17 | """Output model for evaluation results."""
18 |
19 | feedback: str = Field(..., description="Feedback from the evaluation")
20 | score: int | None = Field(..., description="Numeric score from the evaluation")
21 |
22 | @classmethod
23 | def parse(cls, response: str, fail_on_parse_error: bool = False) -> "EvalOutput":
24 | """Parse the evaluation response from the judge."""
25 | try:
26 | # Compile regex patterns
27 |             feedback_pattern = re.compile(r"<feedback>\s*(.*?)\s*</feedback>", re.DOTALL)
28 |             score_pattern = re.compile(r"<score>\s*(\d+)\s*</score>", re.DOTALL)
29 |
30 | feedback_match = feedback_pattern.search(response)
31 | score_match = score_pattern.search(response)
32 |
33 | if not feedback_match or not score_match:
34 | raise ValueError("Failed to parse evaluation response.")
35 |
36 | feedback = feedback_match.group(1).strip()
37 | score = int(score_match.group(1).strip())
38 |
39 | return cls(feedback=feedback, score=score)
40 | except Exception as e:
41 | if fail_on_parse_error:
42 | raise ValueError(f"Failed to parse evaluation response: {e}") from e
43 | logger.warning(f"Parsing failed for response: {response}. Error: {e}")
44 | return EvalOutput(feedback="Error", score=-1)
45 |
--------------------------------------------------------------------------------
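
A short sketch of the parsing contract implemented above: the judge model is expected to wrap its answer in `<feedback>` and `<score>` tags, and `EvalOutput.parse` falls back to a sentinel output instead of raising unless `fail_on_parse_error=True`. The example response text is illustrative.

```python
from flow_judge.eval_data_types import EvalOutput

raw = (
    "<feedback>\nThe response is fully grounded in the context.\n</feedback>\n"
    "<score>\n5\n</score>"
)
parsed = EvalOutput.parse(raw)
print(parsed.score)     # 5
print(parsed.feedback)  # "The response is fully grounded in the context."

# A malformed response does not raise by default; it yields the sentinel output.
fallback = EvalOutput.parse("no tags here")
print(fallback.feedback, fallback.score)  # "Error" -1
```
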
/flow_judge/flow_judge.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 |
4 | from flow_judge.eval_data_types import EvalInput, EvalOutput
5 | from flow_judge.metrics import CustomMetric, Metric
6 | from flow_judge.models.adapters.baseten.data_io import BatchResult
7 | from flow_judge.models.adapters.baseten.errors import FlowJudgeError
8 | from flow_judge.models.common import AsyncBaseFlowJudgeModel, BaseFlowJudgeModel
9 | from flow_judge.utils.prompt_formatter import format_rubric, format_user_prompt, format_vars
10 | from flow_judge.utils.result_writer import write_results_to_disk
11 | from flow_judge.utils.validators import validate_eval_input
12 |
13 | logging.basicConfig(level=logging.INFO)
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class BaseFlowJudge:
18 | """Base class for FlowJudge with common functionality."""
19 |
20 | def __init__(
21 | self,
22 | metric: Metric | CustomMetric,
23 | model: BaseFlowJudgeModel | AsyncBaseFlowJudgeModel,
24 | output_dir: str | None = "output/",
25 | ):
26 | """Initialize BaseFlowJudge with a metric and model."""
27 | if not isinstance(metric, (Metric, CustomMetric)):
28 | raise ValueError("Invalid metric type. Use Metric or CustomMetric.")
29 | self.metric = metric
30 | self.output_dir = output_dir
31 | self.model = model
32 |
33 | def _format_prompt(self, eval_input: EvalInput) -> str:
34 | """Format the prompt for a single evaluation input."""
35 | prompt_variables = {
36 | "INPUTS": format_vars(eval_input.inputs),
37 | "OUTPUT": format_vars([eval_input.output]),
38 | "EVALUATION_CRITERIA": self.metric.criteria,
39 | "RUBRIC": format_rubric(self.metric.rubric),
40 | }
41 | return format_user_prompt(prompt_variables)
42 |
43 | def _validate_inputs(self, eval_inputs: EvalInput | list[EvalInput]):
44 | """Validate required inputs and output against the metric."""
45 | if isinstance(eval_inputs, list):
46 | for eval_input in eval_inputs:
47 | validate_eval_input(eval_input, self.metric)
48 | else:
49 | validate_eval_input(eval_inputs, self.metric)
50 |
51 | def _save_results(
52 | self, eval_inputs: list[EvalInput], eval_outputs: list[EvalOutput], append: bool = False
53 | ):
54 | """Save results to disk."""
55 | logger.info(f"{'Appending' if append else 'Saving'} results to {self.output_dir}")
56 | write_results_to_disk(
57 | eval_inputs,
58 | eval_outputs,
59 | self.model.metadata,
60 | self.metric.name,
61 | self.output_dir,
62 | append=append,
63 | )
64 |
65 |
66 | class FlowJudge(BaseFlowJudge):
67 | """Synchronous FlowJudge class for evaluating AI outputs."""
68 |
69 | def __init__(
70 | self,
71 | metric: Metric | CustomMetric,
72 | model: BaseFlowJudgeModel,
73 | output_dir: str | None = "output/",
74 | ):
75 | """Initialize FlowJudge with a metric and model."""
76 | super().__init__(metric, model, output_dir)
77 | if not isinstance(model, BaseFlowJudgeModel):
78 | raise ValueError("Invalid model type. Use BaseFlowJudgeModel or its subclasses.")
79 |
80 | def evaluate(self, eval_input: EvalInput, save_results: bool = False) -> EvalOutput:
81 | """Evaluate a single EvalInput object."""
82 | try:
83 | self._validate_inputs(eval_input)
84 | prompt = self._format_prompt(eval_input)
85 | response = self.model._generate(prompt)
86 | eval_output = EvalOutput.parse(response)
87 | if save_results:
88 | self._save_results([eval_input], [eval_output])
89 | return eval_output
90 | except Exception as e:
91 | logger.error(f"Evaluation failed: {e}")
92 | raise
93 |
94 | def batch_evaluate(
95 | self,
96 | eval_inputs: list[EvalInput],
97 | use_tqdm: bool = True,
98 | save_results: bool = True,
99 | fail_on_parse_error: bool = False,
100 | ) -> list[EvalOutput]:
101 | """Batch evaluate a list of EvalInput objects."""
102 | self._validate_inputs(eval_inputs)
103 | prompts = [self._format_prompt(eval_input) for eval_input in eval_inputs]
104 | responses = self.model._batch_generate(prompts, use_tqdm=use_tqdm)
105 | eval_outputs = [
106 | EvalOutput.parse(response, fail_on_parse_error=fail_on_parse_error)
107 | for response in responses
108 | ]
109 | parse_failures = sum(1 for output in eval_outputs if output.score == -1)
110 | if save_results:
111 | self._save_results(eval_inputs, eval_outputs)
112 | if parse_failures > 0:
113 | logger.warning(f"Number of parsing failures: {parse_failures} out of {len(responses)}")
114 |
115 | return eval_outputs
116 |
117 |
118 | class AsyncFlowJudge(BaseFlowJudge):
119 | """Asynchronous FlowJudge class for evaluating AI outputs."""
120 |
121 | def __init__(
122 | self,
123 | metric: Metric | CustomMetric,
124 | model: AsyncBaseFlowJudgeModel,
125 | output_dir: str | None = "output/",
126 | ):
127 | """Initialize AsyncFlowJudge with a metric and model."""
128 | super().__init__(metric, model, output_dir)
129 | if not isinstance(model, AsyncBaseFlowJudgeModel):
130 | raise ValueError("Invalid model type. Use AsyncBaseFlowJudgeModel or its subclasses.")
131 |
132 | def _handle_batch_result(
133 | self, batch_result: BatchResult, batch_len: int, fail_on_parse_error: bool
134 | ) -> list[EvalOutput]:
135 | """Handle output parsing for batched results.
136 |
137 | Args:
138 | batch_result: The result of the batch from Baseten.
139 | batch_len: The initial batch size derived from the length of Eval Inputs.
140 | fail_on_parse_error: Flag to raise a parse error for the EvalOutput.
141 |
142 | Returns:
143 | list[EvalOutput]: A list of eval outputs with score and feedback.
144 |
145 | Note:
146 | There might be instances when downstream errors result in missing entries
147 | for the eval outputs. We implement retry strategies where we can, but in
148 | certain instances (such as network failures) errors are inevitable.
149 | To ascertain predictability, we 'fill-in' the errors with empty EvalOutputs.
150 |
151 | """
152 | eval_outputs = [EvalOutput(feedback="BasetenError", score=None)] * batch_len
153 | for output in batch_result.successful_outputs:
154 | index = output.get("index")
155 | eval_outputs[index - 1] = EvalOutput.parse(
156 | response=output["response"], fail_on_parse_error=fail_on_parse_error
157 | )
158 |
159 | # Log all downstream errors
160 | if len(batch_result.errors) > 0:
161 | logger.warning(
162 | f"Number of Baseten API errors: {len(batch_result.errors)}"
163 | f" of {batch_result.total_requests}."
164 | f" Success rate is {batch_result.success_rate}"
165 | " List of errors: "
166 | )
167 | for error in batch_result.errors:
168 | logger.warning(f"{error.error_type}: {error.error_message}")
169 |
170 | return eval_outputs
171 |
172 | async def async_evaluate(
173 | self, eval_input: EvalInput, save_results: bool = False, append: bool = False
174 | ) -> EvalOutput | None:
175 | """Evaluate a single EvalInput object asynchronously."""
176 | try:
177 | self._validate_inputs(eval_input)
178 | prompt = self._format_prompt(eval_input)
179 | result = await self.model._async_generate(prompt)
180 | response = result
181 |
182 | if isinstance(result, FlowJudgeError):
183 | logger.error(f" {result.error_type}: {result.error_message}")
184 | return
185 |
186 | eval_output = EvalOutput.parse(response)
187 | if save_results:
188 | logger.info(f"Saving result {'(append)' if append else '(overwrite)'}")
189 | await asyncio.to_thread(
190 | self._save_results, [eval_input], [eval_output], append=append
191 | )
192 | return eval_output
193 | except Exception as e:
194 | logger.error(f"Asynchronous evaluation failed: {e}")
195 | raise
196 |
197 | # TODO: figure if we want to have the parser be passed the fail_on_parse_error flag
198 | async def async_batch_evaluate(
199 | self,
200 | eval_inputs: list[EvalInput],
201 | use_tqdm: bool = True,
202 | save_results: bool = True,
203 | append: bool = False, # Change default to False
204 | fail_on_parse_error: bool = False,
205 | ) -> list[EvalOutput]:
206 | """Batch evaluate a list of EvalInput objects asynchronously."""
207 | self._validate_inputs(eval_inputs)
208 | prompts = [self._format_prompt(eval_input) for eval_input in eval_inputs]
209 | batch_result = await self.model._async_batch_generate(prompts, use_tqdm=use_tqdm)
210 |
211 | if isinstance(batch_result, BatchResult):
212 | eval_outputs = self._handle_batch_result(
213 | batch_result=batch_result,
214 | batch_len=len(eval_inputs),
215 | fail_on_parse_error=fail_on_parse_error,
216 | )
217 | else:
218 | eval_outputs = [
219 | EvalOutput.parse(response, fail_on_parse_error=fail_on_parse_error)
220 | for response in batch_result
221 | ]
222 | logger.warning(f"{eval_outputs}")
223 | parse_failures = sum(1 for output in eval_outputs if output.score and output.score == -1)
224 |
225 | if save_results:
226 | logger.info(f"Saving {len(eval_outputs)} results")
227 | for i, (eval_input, eval_output) in enumerate(
228 | zip(eval_inputs, eval_outputs, strict=True)
229 | ):
230 | await asyncio.to_thread(
231 | self._save_results,
232 | [eval_input],
233 | [eval_output],
234 | append=(append or i > 0), # Append for all but the first, unless append is True
235 | )
236 |
237 | if parse_failures > 0:
238 | logger.warning(
239 | f"Number of parsing failures: {parse_failures} out of {len(eval_outputs)}"
240 | )
241 |
242 | return eval_outputs
243 |
--------------------------------------------------------------------------------
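
A minimal synchronous sketch of the `FlowJudge.evaluate` path defined above. It assumes the `vllm` extra and a CUDA GPU so that `Vllm()` can be constructed with defaults; any other `BaseFlowJudgeModel` subclass would slot in the same way, and the input keys follow the faithfulness preset used in the examples.

```python
from flow_judge import EvalInput, FlowJudge, Vllm
from flow_judge.metrics import RESPONSE_FAITHFULNESS_5POINT

# Assumed local setup: requires the `vllm` extra and a CUDA GPU.
model = Vllm()
judge = FlowJudge(metric=RESPONSE_FAITHFULNESS_5POINT, model=model)

eval_input = EvalInput(
    inputs=[
        {"query": "How do I reset my password?"},
        {"context": "Passwords can be reset from Settings > Security > Reset password."},
    ],
    output={"response": "Go to Settings, open Security, and click Reset password."},
)

# evaluate() formats the prompt, calls the model, and parses the <feedback>/<score> tags.
result = judge.evaluate(eval_input, save_results=False)
print(result.score, result.feedback)
```
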
/flow_judge/integrations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flowaicom/flow-judge/e56d199db4d79f184dac1e9ab2da83992acda14d/flow_judge/integrations/__init__.py
--------------------------------------------------------------------------------
/flow_judge/integrations/haystack.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any
3 |
4 | import numpy as np
5 | from haystack import component, default_from_dict, default_to_dict
6 | from haystack.utils import deserialize_type
7 |
8 | from flow_judge.flow_judge import EvalInput, EvalOutput, FlowJudge
9 | from flow_judge.metrics.metric import CustomMetric, Metric
10 | from flow_judge.models.common import BaseFlowJudgeModel
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | # Based on https://github.com/deepset-ai/haystack/blob/d234c75168dcb49866a6714aa232f37d56f72cab/haystack/components/evaluators/llm_evaluator.py#L354
15 |
16 |
17 | @component
18 | class HaystackFlowJudge:
19 | """A component that uses FlowJudge to evaluate inputs."""
20 |
21 | def __init__(
22 | self,
23 | metric: Metric | CustomMetric,
24 | model: BaseFlowJudgeModel,
25 | output_dir: str = "output/",
26 | progress_bar: bool = True,
27 | raise_on_failure: bool = True,
28 | save_results: bool = True,
29 | fail_on_parse_error: bool = False,
30 | ):
31 | """Construct a new FlowJudge evaluator."""
32 | if isinstance(metric, (Metric, CustomMetric)):
33 | self.metric = metric
34 | else:
35 | raise ValueError("Invalid metric type. Use Metric or CustomMetric.")
36 |
37 | if not isinstance(model, BaseFlowJudgeModel):
38 | raise ValueError("Invalid model type. Use BaseFlowJudgeModel or its subclasses.")
39 |
40 | self.model = model
41 | self.output_dir = output_dir
42 |
43 | self.judge = FlowJudge(metric=self.metric, model=self.model, output_dir=self.output_dir)
44 |
45 | # extract inputs and output from the metric
46 | self.inputs, self.outputs = self._extract_vars_from_metric(self.metric)
47 | self.validate_init_parameters(self.inputs, self.outputs)
48 | self.raise_on_failure = raise_on_failure
49 | self.progress_bar = progress_bar
50 | self.save_results = save_results
51 | self.fail_on_parse_error = fail_on_parse_error
52 |
53 | component.set_input_types(self, **dict(self.inputs))
54 |
55 | @staticmethod
56 | def _extract_vars_from_metric(
57 | metric: Metric | CustomMetric,
58 | ) -> tuple[list[tuple[str, type[list]]], list[str]]:
59 | """Extract the inputs to the component and its type from the metric.
60 |
61 | It also sets the output of the component.
62 | """
63 | eval_inputs_keys: list[str] = metric.required_inputs
64 | eval_output_key: str = metric.required_output
65 |
66 | inputs = [(key, list[str]) for key in eval_inputs_keys + [eval_output_key]]
67 |
68 | outputs = ["feedback", "score"]
69 |
70 | return inputs, outputs
71 |
72 | @staticmethod
73 | def validate_init_parameters(inputs: list[tuple[str, type[list]]], outputs: list[str]):
74 | """Validate the init parameters."""
75 | # Validate inputs
76 | if (
77 | not isinstance(inputs, list)
78 | or not all(isinstance(_input, tuple) for _input in inputs)
79 | or not all(
80 | isinstance(_input[0], str) and _input[1] is not list and len(_input) == 2
81 | for _input in inputs
82 | )
83 | ):
84 | msg = (
85 | f"FlowJudge evaluator expects inputs to \
86 | be a list of tuples. Each tuple must contain an input name and "
87 | f"type of list but received {inputs}."
88 | )
89 | raise ValueError(msg)
90 |
91 | # Validate outputs
92 | if not isinstance(outputs, list) or not all(isinstance(output, str) for output in outputs):
93 | msg = f"FlowJudge evaluator expects outputs \
94 | to be a list of str but received {outputs}."
95 | raise ValueError(msg)
96 |
97 | @component.output_types(
98 | results=list[dict[str, Any]],
99 | metadata=dict[str, Any],
100 | score=float,
101 | individual_scores=list[float],
102 | )
103 | def run(self, **inputs) -> dict[str, Any]:
104 | """Run the FlowJudge evaluator on the provided inputs."""
105 | self._validate_input_parameters(dict(self.inputs), inputs)
106 | eval_inputs: list[EvalInput] = self._prepare_inputs(inputs=inputs, metric=self.metric)
107 | eval_outputs: list[EvalOutput] = self.judge.batch_evaluate(
108 | eval_inputs,
109 | save_results=self.save_results,
110 | fail_on_parse_error=self.fail_on_parse_error,
111 | )
112 |
113 | results: list[dict[str, Any] | None] = []
114 | parsing_errors = 0
115 | for eval_output in eval_outputs:
116 | if eval_output.score != -1:
117 | result = {
118 | "feedback": eval_output.feedback,
119 | "score": eval_output.score,
120 | }
121 |
122 | results.append(result)
123 | else:
124 | results.append({"feedback": eval_output.feedback, "score": eval_output.score})
125 | parsing_errors += 1
126 |
127 | if parsing_errors > 0:
128 | msg = (
129 | f"FlowJudge failed to parse {parsing_errors} results out "
130 | f"of {len(eval_outputs)}. Score and Individual Scores are "
131 | "based on the successfully parsed results."
132 | )
133 | logger.warning(msg)
134 |
135 | metadata = self.model.metadata
136 |
137 | score = np.mean([result["score"] for result in results if result["score"] != -1])
138 | individual_scores = [float(result["score"]) for result in results if result["score"] != -1]
139 |
140 | return {
141 | "results": results,
142 | "metadata": metadata,
143 | "score": score,
144 | "individual_scores": individual_scores,
145 | }
146 |
147 | @staticmethod
148 | def _validate_input_parameters(expected: dict[str, Any], received: dict[str, Any]) -> None:
149 | """Validate the input parameters."""
150 | # Validate that all expected inputs are present in the received inputs
151 | for param in expected.keys():
152 | if param not in received:
153 | msg = f"FlowJudge evaluator expected input \
154 | parameter '{param}' but received only {received.keys()}."
155 | raise ValueError(msg)
156 |
157 | # Validate that all received inputs are lists
158 | if not all(isinstance(_input, list) for _input in received.values()):
159 | msg = (
160 | "FlowJudge evaluator expects all input values to be lists but received "
161 | f"{[type(_input) for _input in received.values()]}."
162 | )
163 | raise ValueError(msg)
164 |
165 | # Validate that all received inputs are of the same length
166 | inputs = received.values()
167 | length = len(next(iter(inputs)))
168 | if not all(len(_input) == length for _input in inputs):
169 | msg = (
170 | f"FlowJudge evaluator expects all input lists\
171 | to have the same length but received {inputs} with lengths "
172 | f"{[len(_input) for _input in inputs]}."
173 | )
174 | raise ValueError(msg)
175 |
176 | @staticmethod
177 | def _prepare_inputs(inputs: dict[str, Any], metric: Metric | CustomMetric) -> list[EvalInput]:
178 | """Prepare the inputs for the flow judge."""
179 | eval_inputs = []
180 | num_samples = len(next(iter(inputs.values())))
181 |
182 | for i in range(num_samples):
183 | input_list = []
184 | output_dict = {}
185 | for key, value_list in inputs.items():
186 | temp_dict = {}
187 | if key in metric.required_inputs:
188 | temp_dict[key] = value_list[i]
189 | input_list.append(temp_dict)
190 | elif key == metric.required_output:
191 | output_dict[key] = value_list[i]
192 |
193 | if not output_dict:
194 | raise ValueError(f"Required output '{metric.required_output}' not found in inputs.")
195 |
196 | eval_input = EvalInput(inputs=input_list, output=output_dict)
197 | eval_inputs.append(eval_input)
198 |
199 | return eval_inputs
200 |
201 | def to_dict(self) -> dict[str, Any]:
202 | """Serialize this component to a dictionary."""
203 | return default_to_dict(
204 | self,
205 | metric=self.metric,
206 | model=self.model,
207 | output_dir=self.output_dir,
208 | progress_bar=self.progress_bar,
209 | raise_on_failure=self.raise_on_failure,
210 | save_results=self.save_results,
211 | fail_on_parse_error=self.fail_on_parse_error,
212 | )
213 |
214 | @classmethod
215 | def from_dict(cls, data: dict[str, Any]) -> "HaystackFlowJudge":
216 | """Deserialize this component from a dictionary."""
217 | data["init_parameters"]["inputs"] = [
218 | (name, deserialize_type(type_)) for name, type_ in data["init_parameters"]["inputs"]
219 | ]
220 |
221 | return default_from_dict(cls, data)
222 |
--------------------------------------------------------------------------------
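
For orientation, a sketch of using the component above outside a full Haystack pipeline, by calling `run` directly. It assumes the `vllm` extra; the input slots (`query`, `context`, `response`) come from the faithfulness preset's required inputs and output, as in the examples.

```python
from flow_judge import Vllm
from flow_judge.integrations.haystack import HaystackFlowJudge
from flow_judge.metrics import RESPONSE_FAITHFULNESS_5POINT

# The metric's required inputs and output become the component's input slots.
evaluator = HaystackFlowJudge(
    metric=RESPONSE_FAITHFULNESS_5POINT,
    model=Vllm(),
    save_results=False,
)

# Each input is a list; all lists must have the same length (one entry per sample).
result = evaluator.run(
    query=["How do I reset my password?"],
    context=["Passwords can be reset from Settings > Security."],
    response=["Open Settings, then Security, then Reset password."],
)
print(result["score"], result["individual_scores"])
```
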
/flow_judge/integrations/langchain.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from collections.abc import Sequence
3 | from typing import Any
4 |
5 | from langchain.evaluation import StringEvaluator
6 |
7 | from flow_judge import AsyncFlowJudge, EvalInput, FlowJudge
8 | from flow_judge.metrics import CustomMetric, Metric
9 | from flow_judge.models import AsyncBaseFlowJudgeModel, BaseFlowJudgeModel
10 |
11 |
12 | class FlowJudgeLangChainEvaluator(StringEvaluator):
13 |     """FlowJudgeLangChainEvaluator is a custom evaluator for LangChain.
14 |
15 | It uses FlowJudge to evaluate the LLM outputs.
16 | """
17 |
18 | def __init__(
19 | self, metric: Metric | CustomMetric, model: BaseFlowJudgeModel | AsyncBaseFlowJudgeModel
20 | ):
21 |         """Initialize the FlowJudgeLangChainEvaluator."""
22 | if isinstance(metric, (Metric, CustomMetric)):
23 | self.metric = metric
24 | else:
25 | raise ValueError("Invalid metric type. Use Metric or CustomMetric.")
26 |
27 | # Validate model and choose appropriate FlowJudge class
28 | if isinstance(model, (BaseFlowJudgeModel, AsyncBaseFlowJudgeModel)):
29 | self.model = model
30 | else:
31 | raise ValueError(
32 | "The model must be an instance of BaseFlowJudgeModel or AsyncBaseFlowJudgeModel."
33 | )
34 |
35 | # Determine if the model is async-capable
36 | self.is_async = hasattr(self.model, "exec_async") and self.model.exec_async
37 |
38 | # Initialize the appropriate judge based on async capability
39 | if self.is_async:
40 | self.judge = AsyncFlowJudge(metric=self.metric, model=self.model)
41 | else:
42 | self.judge = FlowJudge(metric=self.metric, model=self.model)
43 |
44 | def _prepare_eval_input(
45 | self,
46 | prediction: str,
47 | reference: str | None = None,
48 | input: str | None = None,
49 | **kwargs: Any,
50 | ) -> EvalInput:
51 | # Combine all inputs into a single dictionary
52 | all_inputs = {"prediction": prediction, "reference": reference, "input": input, **kwargs}
53 |
54 | # Prepare eval_inputs based on metric's required_inputs
55 | eval_inputs = []
56 | for req_input in self.metric.required_inputs:
57 | if req_input in all_inputs:
58 | value = all_inputs[req_input]
59 | if isinstance(value, (list, Sequence)) and not isinstance(value, str):
60 | eval_inputs.extend([{req_input: v} for v in value])
61 | else:
62 | eval_inputs.append({req_input: value})
63 |
64 | # Prepare the output
65 | output_key = self.metric.required_output
66 | output_value = all_inputs.get(
67 | output_key, prediction
68 | ) # Default to prediction if not specified
69 |
70 | return EvalInput(inputs=eval_inputs, output={output_key: output_value})
71 |
72 | def _evaluate_strings(
73 | self,
74 | prediction: str,
75 | reference: str | None = None,
76 | input: str | None = None,
77 | **kwargs: Any,
78 | ) -> dict[str, Any]:
79 | eval_input = self._prepare_eval_input(prediction, reference, input, **kwargs)
80 | result = self.judge.evaluate(eval_input, save_results=False)
81 |
82 | return {
83 | "score": result.score,
84 | "reasoning": result.feedback,
85 | }
86 |
87 | async def _aevaluate_strings(
88 | self,
89 | prediction: str,
90 | reference: str | None = None,
91 | input: str | None = None,
92 | sleep_time_in_seconds: int = 1,
93 | **kwargs: Any,
94 | ) -> dict[str, Any]:
95 | await asyncio.sleep(sleep_time_in_seconds)
96 | eval_input = self._prepare_eval_input(prediction, reference, input, **kwargs)
97 | result = await self.judge.async_evaluate(eval_input, save_results=False)
98 |
99 | return {
100 | "score": result.score,
101 | "reasoning": result.feedback,
102 | }
103 |
104 | @property
105 | def requires_input(self) -> bool:
106 | """Requires input."""
107 | return "input" in self.metric.required_inputs
108 |
109 | @property
110 | def requires_reference(self) -> bool:
111 | """Requires reference."""
112 | return "reference" in self.metric.required_inputs
113 |
114 | @property
115 | def evaluation_name(self) -> str:
116 | """Get metric name."""
117 | return f"flow_judge_{self.metric.name}"
118 |
119 | def get_required_inputs(self) -> list[str]:
120 | """Get required inputs."""
121 | return self.metric.required_inputs + [self.metric.required_output]
122 |
--------------------------------------------------------------------------------
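
A sketch of the evaluator above invoked through LangChain's `evaluate_strings` entry point (which forwards extra keyword arguments to `_evaluate_strings`); the `vllm` extra and the faithfulness preset's input keys are assumed.

```python
from flow_judge import Vllm
from flow_judge.integrations.langchain import FlowJudgeLangChainEvaluator
from flow_judge.metrics import RESPONSE_FAITHFULNESS_5POINT

# Any sync BaseFlowJudgeModel subclass works the same way here.
evaluator = FlowJudgeLangChainEvaluator(
    metric=RESPONSE_FAITHFULNESS_5POINT,
    model=Vllm(),
)

# `prediction` maps to the metric's required output ("response" here);
# the metric's required inputs are supplied as keyword arguments.
result = evaluator.evaluate_strings(
    prediction="Open Settings, then Security, then Reset password.",
    query="How do I reset my password?",
    context="Passwords can be reset from Settings > Security.",
)
print(result["score"], result["reasoning"])
```
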
/flow_judge/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from flow_judge.metrics.metric import CustomMetric, Metric, RubricItem
2 | from flow_judge.metrics.presets import (
3 | RESPONSE_CORRECTNESS_3POINT,
4 | RESPONSE_CORRECTNESS_5POINT,
5 | RESPONSE_CORRECTNESS_BINARY,
6 | RESPONSE_FAITHFULNESS_3POINT,
7 | RESPONSE_FAITHFULNESS_5POINT,
8 | RESPONSE_FAITHFULNESS_BINARY,
9 | RESPONSE_RELEVANCE_3POINT,
10 | RESPONSE_RELEVANCE_5POINT,
11 | RESPONSE_RELEVANCE_BINARY,
12 | list_all_metrics,
13 | )
14 |
--------------------------------------------------------------------------------
/flow_judge/metrics/metric.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
4 | class RubricItem(BaseModel):
5 | """Represents an item in the evaluation rubric."""
6 |
7 | score: int
8 | description: str
9 |
10 |
11 | class Metric(BaseModel):
12 | """Represents an evaluation metric."""
13 |
14 | name: str
15 | criteria: str
16 | rubric: list[RubricItem]
17 | required_inputs: list[str] | None = None
18 | required_output: str
19 |
20 | def print_required_keys(self):
21 | """Prints the required input and output keys."""
22 | print(f"Metric: {self.name}")
23 | print("Required inputs:", ", ".join(self.required_inputs or []))
24 | print("Required output:", self.required_output)
25 |
26 |
27 | class CustomMetric(Metric):
28 | """Represents a custom evaluation metric."""
29 |
30 | def __init__(
31 | self,
32 | name: str,
33 | criteria: str,
34 | rubric: list[RubricItem],
35 | required_inputs: list[str],
36 | required_output: str,
37 | ):
38 | """Initialize a custom metric."""
39 | super().__init__(
40 | name=name,
41 | criteria=criteria,
42 | rubric=rubric,
43 | required_inputs=required_inputs,
44 | required_output=required_output,
45 | )
46 |
--------------------------------------------------------------------------------
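
A brief sketch of defining a `CustomMetric` from the classes above; the metric name, criteria, rubric wording, and key names are illustrative only.

```python
from flow_judge.metrics import CustomMetric, RubricItem

# Hypothetical binary conciseness metric; all strings below are placeholders.
conciseness = CustomMetric(
    name="response_conciseness_binary",
    criteria="Is the response concise while still answering the query?",
    rubric=[
        RubricItem(score=0, description="The response is verbose or repeats itself."),
        RubricItem(score=1, description="The response is concise and complete."),
    ],
    required_inputs=["query"],
    required_output="response",
)

conciseness.print_required_keys()
# Metric: response_conciseness_binary
# Required inputs: query
# Required output: response
```
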
/flow_judge/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import AsyncBaseFlowJudgeModel, BaseFlowJudgeModel, ModelConfig, ModelType
2 | from .huggingface import Hf, HfError
3 | from .llamafile import Llamafile, LlamafileError
4 | from .vllm import Vllm, VllmError
5 |
6 | __all__ = [
7 | "AsyncBaseFlowJudgeModel",
8 | "BaseFlowJudgeModel",
9 | "ModelType",
10 | "ModelConfig",
11 | "Hf",
12 | "HfError",
13 | "Vllm",
14 | "VllmError",
15 | "Llamafile",
16 | "LlamafileError",
17 | ]
18 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any
3 |
4 |
5 | class BaseAPIAdapter(ABC):
6 | """Base adapter layer for making remote requests to hosted models."""
7 |
8 | def __init__(self, base_url: str):
9 | """Initialize the BaseAPIAdapter.
10 |
11 | : param: base_url: The baseten deployed model url
12 | """
13 | self.base_url = base_url
14 |
15 | @abstractmethod
16 | def _fetch_response(self, request_body: dict[str, Any]) -> str:
17 | """Generate a response based on the given request."""
18 | pass
19 |
20 | @abstractmethod
21 | def _fetch_batched_response(self, request_bodies: list[dict[str, Any]]) -> list[str]:
22 | """Generate responses for multiple requests."""
23 | pass
24 |
25 |
26 | class AsyncBaseAPIAdapter(ABC):
27 | """Base adapter layer for making remote requests to hosted models."""
28 |
29 | def __init__(self, base_url: str):
30 | """Initialize the AsyncBaseAPIAdapter.
31 |
32 | : param: base_url: The baseten deployed model url
33 | """
34 | self.base_url = base_url
35 |
36 | @abstractmethod
37 | async def _async_fetch_response(self, request_body: dict[str, Any]) -> str:
38 | """Generate a response based on the given request."""
39 | pass
40 |
41 | @abstractmethod
42 | async def _async_fetch_batched_response(
43 | self, request_bodies: list[dict[str, Any]]
44 | ) -> list[str]:
45 | """Generate responses for multiple requests."""
46 | pass
47 |
--------------------------------------------------------------------------------
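
For orientation, a minimal sketch of what a concrete synchronous subclass might look like. The endpoint path, payload shape, and `requests` usage are assumptions for illustration, not the actual Baseten adapter (that lives in `flow_judge/models/adapters/baseten/adapter.py`).

```python
from typing import Any

import requests  # assumed available; the real adapters handle HTTP themselves

from flow_judge.models.adapters.base import BaseAPIAdapter


class EchoJSONAdapter(BaseAPIAdapter):
    """Toy adapter that POSTs each request body to `<base_url>/generate` (hypothetical)."""

    def __init__(self, base_url: str, api_key: str):
        super().__init__(base_url)
        self.api_key = api_key

    def _fetch_response(self, request_body: dict[str, Any]) -> str:
        resp = requests.post(
            f"{self.base_url}/generate",  # hypothetical endpoint
            json=request_body,
            headers={"Authorization": f"Api-Key {self.api_key}"},
            timeout=60,
        )
        resp.raise_for_status()
        return resp.json().get("output", "")

    def _fetch_batched_response(self, request_bodies: list[dict[str, Any]]) -> list[str]:
        # Naive sequential fallback; a real adapter would batch or parallelize.
        return [self._fetch_response(body) for body in request_bodies]
```
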
/flow_judge/models/adapters/baseten/README.md:
--------------------------------------------------------------------------------
1 | # Baseten remote execution
2 |
3 | ### Running Remotely on Baseten
4 |
5 | Running Flow Judge on Baseten allows you to offload processing tasks from your local machine.
6 | Remote execution runs generations in parallel, handling multiple requests at a time.
7 | It can significantly improve throughput and reduce overall wait times, which might be useful for larger workloads.
8 |
9 | ### Execution modes
10 |
11 | - **Sync Mode**:
12 | - For smaller jobs, sync mode provides immediate feedback and is perfect for quick iterations and exploratory tasks
13 |
14 | - **Async (Batch) Mode**:
15 | - Allows you to submit multiple requests simultaneously and process them in parallel, making it suitable for large-scale
16 | evaluations.
17 |
18 | ## Setup
19 |
20 | ### Installation
21 |
22 | To use the Baseten integration, install the optional `baseten` dependency:
23 |
24 | ```bash
25 | pip install -e .[dev,baseten]
26 | ```
27 |
28 | ### Creating Baseten account
29 |
30 | Follow the instructions on the signup page: https://app.baseten.co/signup
31 |
32 | ### Baseten API key
33 |
34 | To use Baseten, you need to create an API key in your account.
35 | Navigate to your account settings to generate and manage your API keys:
36 | https://app.baseten.co/settings/account/api_keys
37 |
38 | ### Baseten webhook secret (optional)
39 |
40 | For async (batch) execution, you must create a webhook secret in your Baseten account.
41 | Follow the official Baseten instructions:
42 | https://docs.baseten.co/invoke/async-secure#creating-webhook-secrets
43 |
44 | ### Baseten GPU
45 |
46 | The Flow Judge model can be deployed on an A10G or an H100 40GB GPU on Baseten's infrastructure.
47 | In a notebook environment, you can set the `BASETEN_GPU` environment variable to either `A10G` or `H100`.
48 | In an interactive environment (CLI) you will be asked whether you would like to switch to H100.
49 | The Flow Judge model variant is then selected based on the chosen GPU:
50 |
51 | A10G -> Flow-Judge-v0.1-AWQ
52 |
53 | H100 -> Flow-Judge-v0.1-FP8
54 |
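For example, in a notebook you can pin the deployment to an H100 before the first run (a minimal sketch; `H100` and `A10G` are the only accepted values):

```python
import os

# Select the Baseten GPU before the first run; A10G is the default
os.environ["BASETEN_GPU"] = "H100"
```
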
55 | ## Sync execution
56 |
57 | When running Flow Judge eval, use `Baseten` class as a model
58 | (see [Quickstart](https://github.com/flowaicom/flow-judge?tab=readme-ov-file#quick-start)):
59 |
60 | ```python
61 | model=Baseten()
62 | ```
63 |
64 | That's it! :)
65 |
66 | During the first run you will be asked to provide the Baseten API key and the GPU.
67 | The key will be stored on your computer in `~/.trussrc` (see [truss](https://docs.baseten.co/truss-reference/overview)).
68 | It is used to check whether the model is already deployed and to deploy it if needed.
69 |
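If you prefer to skip the interactive prompt, you can provide the key up front through the `BASETEN_API_KEY` environment variable (a minimal sketch):

```python
import os

# Providing the key via the environment skips the interactive prompt
os.environ["BASETEN_API_KEY"] = "«your API key»"
```
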
70 | As part of the execution process we deploy [Flow Judge model](https://huggingface.co/flowaicom/Flow-Judge-v0.1-AWQ) to
71 | your Baseten account and promote it to a published
72 | deployment.
73 |
74 | ## Async Execution
75 |
76 | ### Description
77 |
78 | In async mode, Baseten sends model outputs to a specified webhook address that needs to be publicly accessible on the
79 | internet.
80 |
81 | We offer a free proxy as an effortless solution that can forward responses directly to the Flow Judge,
82 | without exposing an endpoint from your computer to the internet.
83 | This is the default behavior when running in batch mode.
84 |
85 | Alternatively, you can use tools like `ngrok` or `localtunnel` to expose the [same proxy](https://github.com/flowaicom/webhook-proxy) running locally on your device
86 | next to the Flow Judge.
87 |
88 | ### Flow AI proxy (hosted)
89 |
90 | For the hosted proxy there's no additional setup needed. You only need to configure the model instance to work in the
91 | async mode:
92 |
93 | ```python
94 | model=Baseten(exec_async=True, webhook_proxy_url="https://proxy.flow-ai.dev")
95 | ```
96 |
97 | As with synchronous execution, during the first run you will be asked for your Baseten API key and the GPU,
98 | unless provided earlier. The key is used to validate and deploy the model to your Baseten account.
99 |
100 | Additionally, when using asynchronous execution, we verify the signature of the received webhook payloads, as
101 | recommended by Baseten (see [official documentation](https://docs.baseten.co/invoke/async-secure)).
102 | Because of that, during the first run you will be asked to provide the webhook secret. It will be stored
103 | in `~/.config/flow-judge/baseten_webhook_secret` and never leaves your device.
104 |
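As with the API key, you can supply the webhook secret non-interactively through the `BASETEN_WEBHOOK_SECRET` environment variable (a minimal sketch; the value is a placeholder):

```python
import os

# The secret starts with "whsec_" followed by 40 alphanumeric characters
os.environ["BASETEN_WEBHOOK_SECRET"] = "«your webhook secret»"
```
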
105 | ### Using proxy locally
106 |
107 | Currently, Flow Judge does not provide a standalone endpoint to expose to the internet. Instead, you can run an instance
108 | of our proxy on your machine and expose its endpoint using e.g. `ngrok`.
109 |
110 | 1. Download pre-built binary from the proxy releases page: https://github.com/flowaicom/webhook-proxy/releases
111 | or build it according to the instructions provided in the repository.
112 | 2. Run the proxy:
113 | ```shell
114 | ./proxy -addr=0.0.0.0:8000
115 | ```
116 | For more options and detailed instructions see the documentation provided in the proxy repository.
117 | 3. Expose the running proxy to the internet with e.g. `ngrok`:
118 | ```shell
119 | ngrok http 8000
120 | ```
121 | The output of this command will provide you with the public URL.
122 | 4. Use the public URL when setting up the model instance in your Flow Judge implementation:
123 | ```python
124 | model=Baseten(exec_async=True, webhook_proxy_url="https://«ngrok url»")
125 | ```
126 |
127 | #### Using Docker
128 |
129 | In addition to the pre-built binaries, we also provide proxy Docker images.
130 |
131 | Run the proxy with Docker:
132 |
133 | ```shell
134 | docker pull ghcr.io/flowaicom/webhook-proxy:latest
135 | docker run --name=flowai-proxy -d -p 8000:8000 ghcr.io/flowaicom/webhook-proxy:latest
136 | ```
137 |
138 | Then continue from step 3 in the list above.
139 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/api_auth.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import http
3 | import logging
4 | import os
5 |
6 | import requests
7 | import truss
8 | from truss.remote import remote_factory
9 |
10 | from .util import is_interactive
11 |
12 | logger: logging.Logger = logging.getLogger(__name__)
13 |
14 |
15 | def get_baseten_api_key() -> str | None:
16 | """Retrieve the Baseten API key from environment or config file.
17 |
18 | :return: The Baseten API key if found, None otherwise.
19 | :rtype: str | None
20 | """
21 | logger.debug("Attempting to retrieve Baseten API key")
22 | api_key: str | None = os.environ.get("BASETEN_API_KEY")
23 | if api_key:
24 | logger.debug("API key found in environment variables")
25 | if len(api_key) != 41:
26 | logger.warning(
27 | "Warning: Baseten API key might be incorrect. "
28 | "The length should be exactly 41 characters."
29 | )
30 | return api_key
31 |
32 | logger.debug("API key not found in environment, checking remote config")
33 | try:
34 | c: remote_factory.RemoteConfig | None = remote_factory.RemoteFactory.load_remote_config(
35 | "baseten"
36 | )
37 | if c is not None:
38 | api_key = c.configs.get("api_key")
39 | if api_key:
40 | logger.debug("API key found in remote config")
41 | if len(api_key) != 41:
42 | logger.warning(
43 | "Warning: Baseten API key might be incorrect. "
44 | "The length should be exactly 41 characters."
45 | )
46 | os.environ["BASETEN_API_KEY"] = api_key
47 | return api_key
48 | except Exception as e:
49 | logger.error(f"Error loading remote config: {e}")
50 |
51 | logger.debug("API key not found")
52 | return None
53 |
54 |
55 | def _validate_auth_status(api_key: str | None = None) -> bool:
56 | """Validate the authentication status with Baseten.
57 |
58 | :param api_key: Optional API key. If None, it will try to obtain API key from env or .trussrc.
59 | :return: True if authentication is successful, False otherwise.
60 | :rtype: bool
61 | """
62 | logger.debug("Validating authentication status")
63 |
64 | api_key: str | None = api_key or get_baseten_api_key()
65 | if not api_key:
66 | logger.warning("No API key available for authentication")
67 | return False
68 |
69 | try:
70 | response: requests.Response = requests.get(
71 | "https://api.baseten.co/v1/models",
72 | headers={"Authorization": f"Api-Key {api_key}"},
73 | timeout=10,
74 | )
75 | is_valid: bool = response.status_code == http.HTTPStatus.OK
76 | logger.debug(f"Authentication status: {'valid' if is_valid else 'invalid'}")
77 | return is_valid
78 | except requests.RequestException as e:
79 | logger.error(f"Error validating auth status: {e}")
80 | return False
81 |
82 |
83 | def _attempt_truss_auth_with_env_api_key() -> None:
84 | """Authenticate Truss with the API key from environment variable.
85 |
86 | :return: None
87 | """
88 | logger.debug("Authenticating Truss with environment API key")
89 | env_api_key: str | None = os.environ.get("BASETEN_API_KEY")
90 | if not env_api_key:
91 | logger.debug("Baseten API key not available in env var")
92 | return
93 |
94 | logger.debug("Logging in to Truss with environment API key")
95 | truss.login(env_api_key)
96 |
97 |
98 | def _print_noninteractive_prompt() -> None:
99 | """Print a API key prompt for non-interactive environments."""
100 | print("Set the Baseten API key in `BASETEN_API_KEY` environment variable " "and run again:")
101 | print("```")
102 | print('os.environ["BASETEN_API_KEY"] = "«your API key»"')
103 | print("```")
104 |
105 |
106 | def _print_general_prompt() -> None:
107 | """Print a API key prompt information."""
108 | print("To run Flow Judge remotely with Baseten, signup and generate API key")
109 | print(" ➡️ Signup: https://app.baseten.co/signup")
110 | print(" ➡️ API keys: https://app.baseten.co/settings/api_keys")
111 | print(" ➡️ Docs: https://docs.baseten.co/quickstart#setup\n")
112 |
113 |
114 | def _validate_entered_key(key: str) -> bool:
115 | """Checks if the key entered by the user in interactive environments is valid.
116 |
117 | :param key: The key entered by the user
118 | :return: True if the key entered by the user is valid, False otherwise.
119 | :rtype: bool
120 | """
121 | try:
122 | logger.debug("Attempting to log in with provided API key")
123 | truss.login(key)
124 | except Exception as e:
125 | logger.error(f"Error during login: {e}")
126 | print("An error occurred during login. Please try again.")
127 | return False
128 |
129 | if not _validate_auth_status(key):
130 | logger.warning("Invalid Baseten API key")
131 | print("Invalid Baseten API key, try again.")
132 | return False
133 |
134 | logger.info("Baseten authentication successful")
135 | return True
136 |
137 |
138 | def ensure_baseten_authentication() -> bool:
139 | """Attempts to obtain and validate the Baseten API key from the user.
140 |
141 | :return: True if authentication is successful, False otherwise.
142 | """
143 | # Checks if truss is already authenticated or if the key is provided in the env to authenticate
144 | _attempt_truss_auth_with_env_api_key()
145 |
146 | # Checks if authentication above succeeded
147 | if _validate_auth_status():
148 | logger.info("Baseten authenticated")
149 | return True
150 |
151 | logger.warning("Baseten authentication failed or not initialized")
152 | _print_general_prompt()
153 |
154 | # In non-interactive environments we return early because their inputs (env vars)
155 | # are processed at the very beginning of this function.
156 | if not is_interactive():
157 | logger.info("Non-interactive environment detected")
158 | _print_noninteractive_prompt()
159 | return False
160 |
161 | logger.info("Prompting for API key in interactive environment")
162 | while True:
163 | key: str = getpass.getpass("Baseten API key (hidden): ")
164 | if not key:
165 | logger.warning("Empty API key entered")
166 | print("Input is empty, try again.")
167 | continue
168 |
169 |
170 |
171 | if len(key) != 41:
172 | logger.warning(
173 | "Warning: Baseten API key might be incorrect. "
174 | "The length should be exactly 41 characters."
175 | )
176 |
177 | if _validate_entered_key(key):
178 | logger.info("Login successful")
179 | return True
180 |
181 | return False
182 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/data_io.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field, field_validator
2 | from typing_extensions import TypedDict
3 |
4 | from flow_judge.models.adapters.baseten.errors import FlowJudgeError
5 |
6 |
7 | class Message(TypedDict):
8 | """Represents a single request message for the Baseten API.
9 |
10 | Note:
11 | This class uses TypedDict for strict type checking. Ensure all fields
12 | are provided when instantiating. The 'id' field is crucial for tracking
13 | and error reporting throughout the evaluation process.
14 |
15 | Warning:
16 | Do not include sensitive information in the 'content' field, as it may
17 | be logged or stored for debugging purposes.
18 | """
19 |
20 | id: str
21 | index: int
22 | prompt: str
23 | response: str
24 |
25 |
26 | class BatchResult(BaseModel):
27 | """Represents the result of a batch evaluation process.
28 |
29 | This class contains both successful outputs and errors encountered during the
30 | evaluation process, as well as metadata about the batch operation.
31 |
32 | Attributes:
33 | successful_outputs (List[Message]): List of successful evaluation outputs.
34 | errors (List[FlowJudgeError]): List of errors encountered during evaluation.
35 | total_requests (int): Total number of requests processed in the batch.
36 | success_rate (float): Rate of successful evaluations (0.0 to 1.0).
37 |
38 | Note:
39 | The success_rate is calculated as (len(successful_outputs) / total_requests).
40 | Be cautious when interpreting results with a low success rate, as it may
41 | indicate systemic issues with the evaluation process or input data.
42 | """
43 |
44 | successful_outputs: list[Message] = Field(
45 | default_factory=list, description="List of successful evaluation outputs"
46 | )
47 | errors: list[FlowJudgeError] = Field(
48 | default_factory=list, description="List of errors encountered during evaluation"
49 | )
50 | total_requests: int = Field(..., description="Total number of requests processed")
51 | success_rate: float = Field(..., description="Rate of successful evaluations")
52 |
53 | @field_validator("total_requests")
54 | @classmethod
55 | def check_positive_total_requests(cls, v):
56 | """Placeholder."""
57 | if v < 0:
58 | raise ValueError("total_requests must be positive")
59 | return v
60 |
61 | @field_validator("success_rate")
62 | @classmethod
63 | def check_success_rate_range(cls, v):
64 | """Placeholder."""
65 | if not 0 <= v <= 1:
66 | raise ValueError("success_rate must be between 0 and 1")
67 | return v
68 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/deploy.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import TypedDict
4 |
5 | import requests
6 | import truss
7 | from truss.api.definitions import ModelDeployment
8 | from truss.remote.baseten.error import ApiError
9 |
10 | from flow_judge.models.adapters.baseten.management import sync_set_scale_down
11 |
12 | from .api_auth import ensure_baseten_authentication, get_baseten_api_key
13 | from .gpu import ensure_gpu
14 | from .webhook import ensure_baseten_webhook_secret
15 |
16 | logger: logging.Logger = logging.getLogger(__name__)
17 |
18 |
19 | class ModelInfo(TypedDict):
20 | """Type definition for model information returned by Baseten API.
21 |
22 | :ivar id: The unique identifier of the model.
23 | :ivar name: The name of the model.
24 | """
25 |
26 | id: str
27 | name: str
28 |
29 |
30 | def _initialize_model() -> bool:
31 | """Initialize the Flow Judge model.
32 |
33 | :return: True if initialization is successful, False otherwise.
34 | :rtype: bool
35 | """
36 | logger.info("Initializing Flow Judge model")
37 | if _is_flowjudge_deployed():
38 | logger.info("Flow Judge already deployed")
39 | return True
40 |
41 | if not ensure_gpu():
42 | logger.error("BASETEN_GPU environment variable is required." " Set one of: H100 or A10G")
43 | return False
44 |
45 | truss_path: str = f"{os.path.dirname(os.path.realpath(os.path.abspath(__file__)))}/deployment"
46 | logger.debug(f"Truss path: {truss_path}")
47 | try:
48 | deployment: ModelDeployment = truss.push(
49 | truss_path, promote=True, trusted=True, environment=None
50 | )
51 | logger.debug("Waiting for deployment to become active")
52 | deployment.wait_for_active()
53 | logger.info("Flow Judge Baseten deployment successful")
54 |
55 | model_id = get_deployed_model_id()
56 | api_key = get_baseten_api_key()
57 | if model_id and api_key:
58 | has_updated_scale_down = sync_set_scale_down(
59 | scale_down_delay=120, api_key=api_key, model_id=model_id
60 | )
61 | if has_updated_scale_down:
62 | logger.info(
63 | "Successfully updated Baseten deployed model scale down delay to 2 mins."
64 | )
65 | else:
66 | logger.info(
67 | "Unable to update Baseten deployed model scale down delay period."
68 | " Continuing with default"
69 | )
70 |
71 | return True
72 | except ApiError as e:
73 | logger.error(
74 | "Flow Judge Baseten deployment failed. "
75 | f"Ensure that provided API key is correct and try again. {e.message}"
76 | )
77 | except ValueError as e:
78 | logger.error(
79 | "Flow Judge Baseten deployment failed. "
80 | f"Ensure that provided API key is correct and try again. {str(e)}"
81 | )
82 | return False
83 |
84 |
85 | def _get_models() -> list[ModelInfo] | None:
86 | """Fetch the list of models from Baseten API.
87 |
88 | :return: List of ModelInfo if successful, None otherwise.
89 | :rtype: Optional[List[ModelInfo]]
90 | """
91 | logger.debug("Fetching models from Baseten API")
92 | api_key: str | None = get_baseten_api_key()
93 | if not api_key:
94 | logger.warning("No API key available to fetch models")
95 | return None
96 |
97 | try:
98 | response: requests.Response = requests.get(
99 | "https://api.baseten.co/v1/models",
100 | headers={"Authorization": f"Api-Key {api_key}"},
101 | timeout=10,
102 | )
103 | response.raise_for_status()
104 | resp: dict[str, list[ModelInfo]] = response.json()
105 | models: list[ModelInfo] | None = resp.get("models")
106 | logger.debug(f"Fetched {len(models) if models else 0} models")
107 | return models
108 | except requests.RequestException as e:
109 | logger.error(f"Error fetching models: {e}")
110 | return None
111 |
112 |
113 | def _is_flowjudge_deployed() -> bool:
114 | """Check if Flow Judge is already deployed.
115 |
116 | :return: True if Flow Judge is deployed, False otherwise.
117 | :rtype: bool
118 | """
119 | logger.debug("Checking if Flow Judge is deployed")
120 | models: list[ModelInfo] | None = _get_models()
121 |
122 | if models is None:
123 | logger.warning("Unable to determine if Flow Judge is deployed")
124 | return False
125 |
126 | is_deployed: bool = any("Flow-Judge" in model["name"] for model in models)
127 | logger.debug(f"Flow Judge deployed: {is_deployed}")
128 | return is_deployed
129 |
130 |
131 | def get_deployed_model_id() -> str | None:
132 | """Get the ID of the deployed Flow Judge model.
133 |
134 | :return: The model ID if found, None otherwise.
135 | :rtype: Optional[str]
136 | """
137 | logger.debug("Getting deployed Flow Judge model ID")
138 | models: list[ModelInfo] | None = _get_models()
139 | if not models:
140 | logger.warning("No models found")
141 | return None
142 |
143 | for model in models:
144 | if "Flow-Judge" in model["name"]:
145 | logger.debug(f"Found Flow Judge model with ID: {model['id']}")
146 | return model["id"]
147 |
148 | logger.warning("Flow Judge model not found")
149 | return None
150 |
151 |
152 | def ensure_model_deployment() -> bool:
153 | """Ensure Flow Judge model deployment to Baseten.
154 |
155 | :return: True if deployment is successful, False otherwise.
156 | :rtype: bool
157 | """
158 | logger.info("Ensuring Flow Judge model deployment")
159 | if not ensure_baseten_authentication():
160 | logger.error("Baseten not authenticated, interrupting model deployment")
161 | return False
162 |
163 | ensure_baseten_webhook_secret(optional=True)
164 |
165 | return _initialize_model()
166 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/deployment/config.yaml:
--------------------------------------------------------------------------------
1 | base_image:
2 | image: baseten/truss-server-base:3.11-gpu-v0.9.0
3 | environment_variables:
4 | CUDA_HOME: /usr/local/cuda
5 | LD_LIBRARY_PATH: /usr/local/cuda/extras/CUPTI/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
6 | PATH: /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
7 | VLLM_ATTENTION_BACKEND: FLASH_ATTN
8 | external_package_dirs: []
9 | model_metadata:
10 | example_model_input:
11 | messages:
12 | - content: "# GOAL\nYour job is to evaluate a task carried out by an AI system\
13 | \ powered by a large language model.\n\nYou will be provided with the inputs\
14 | \ and output of the task, as well as the evaluation criteria and scoring rubric.\
15 | \ Your task is to evaluate the output of the AI system based on the evaluation\
16 | \ criteria and scoring rubric provided.\n\n# INPUT\nBelow are the inputs required\
17 | \ for performing the task:\n\n\nPlease read the technical issue\
18 | \ that the user is facing and help me create a detailed solution based on\
19 | \ the context provided.\n\n\n# Customer Issue: I'm having\
20 | \ trouble pushing large files to my Git repository. It says 'File size exceeds\
21 | \ GitHub's file size limit of 100 MB'.\n\n# Context: Git Large File Storage\
22 | \ (LFS) replaces large files such as audio samples, videos, datasets, and\
23 | \ graphics with text pointers inside Git, while storing the file contents\
24 | \ on a remote server like GitHub.com or GitHub Enterprise.\n\nTo use Git LFS,\
25 | \ you need to download and install the Git command line extension. For more\
26 | \ information, see 'Installing Git Large File Storage'.\n\nOnce Git LFS is\
27 | \ installed, you need to set up Git LFS for your user account by running:\n\
28 | \n$ git lfs install\n\n\n\n# OUTPUT\nBelow is the output\
29 | \ of the task:\n\n\n# EVALUATION CRITERIA AND SCORING RUBRIC\nHere are\
43 | \ the evaluation criteria and the rubric that you need to use for evaluating\
44 | \ the task:\n\nBased on the given context, evaluate how\
45 | \ consistent and faithful the generated response is to the context. The response\
46 | \ should not contain any hallucinated or fabricated information that is not\
47 | \ supported by the context.\n\n\n\n\
48 | - Score 1: The response is completely inconsistent with the provided context.\
49 | \ It contains significant amount of hallucinated or fabricated information\
50 | \ that directly contradicts or is not supported at all by the context.\n-\
51 | \ Score 2: The response is mostly inconsistent with the provided context.\
52 | \ While it may contain some information from the context, it introduces a\
53 | \ substantial amount of hallucinated or fabricated details that deviate from\
54 | \ the context.\n- Score 3: The response is somewhat consistent with the provided\
55 | \ context. It includes a mix of information from the context and some hallucinated\
56 | \ or fabricated details. The fabrications are minor and do not significantly\
57 | \ contradict the context.\n- Score 4: The response is mostly consistent with\
58 | \ the provided context. The vast majority of the content is supported by the\
59 | \ context, with only minor and inconsequential inconsistencies or fabrications,\
60 | \ if any.\n- Score 5: The response is completely consistent with and faithful\
61 | \ to the provided context. All details in the response are directly supported\
62 | \ by the context, without any hallucinated or fabricated information.\n\n\
63 | \n# INSTRUCTIONS FOR THE EVALUATION\n1. Understand the task and criteria:\
64 | \ Familiarize yourself with the task to be evaluated. Review the evaluation\
65 | \ criteria and scoring rubric to understand the different levels of performance\
66 | \ and the descriptions for each score.\n2. Review the inputs and output: Look\
67 | \ at the inputs provided for the task. Examine the output generated from completing\
68 | \ the task.\n3. Compare output to score descriptions: Compare the output against\
69 | \ the criteria and score descriptions in the scoring rubric. For each criterion,decide\
70 | \ which description best matches the output.\n4. After comparing the output\
71 | \ to the score descriptions, pay attention to the small details that might\
72 | \ impact the final score that you assign. Sometimes a small difference can\
73 | \ dictate the final score.\n5. Write verbal feedback justifying your evaluation\
74 | \ that includes a detailed rationale, referring to specific aspects of the\
75 | \ output and comparing them to the rubric.\n6. Assign a final score based\
76 | \ on the scoring rubric.\n\n## FORMAT FOR THE EVALUATION\n- Write the verbal\
77 | \ feedback inside tags without any additional surrounding text.\n\
78 | - Write the numeric score inside tags, without any additional surrounding\
79 | \ text and always after the feedback.\n\nPlease accurately evaluate the task.\
80 | \ Strictly adhere to the evaluation criteria and rubric."
81 | role: user
82 | openai_compatible: true
83 | repo_id: flowaicom/Flow-Judge-v0.1-AWQ
84 | vllm_config:
85 | max_model_len: 8192
86 | tensor_parallel_size: 1
87 | model_name: Flow-Judge-v0.1
88 | python_version: py311
89 | repo_id: flowaicom/Flow-Judge-v0.1-AWQ
90 | requirements:
91 | - vllm>=0.6.2
92 | - vllm-flash-attn
93 | resources:
94 | accelerator: A10G
95 | use_gpu: true
96 | runtime:
97 | predict_concurrency: 128
98 | secrets:
99 | hf_access_token: hf_xyz
100 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/deployment/model/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 |
4 | # Configure logging
5 | logging.basicConfig(
6 | level=logging.INFO,
7 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
8 | handlers=[logging.StreamHandler(sys.stdout)],
9 | )
10 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/deployment/model/helper.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import os
4 | import threading
5 |
6 | import httpx
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | DEFAULT_HEALTH_CHECK_INTERVAL = 5 # seconds
11 |
12 |
13 | def log_subprocess_output(process):
14 | """Process logs for the vLLM subprocess."""
15 | while True:
16 | output = process.stdout.readline()
17 | if output == "" and process.poll() is not None:
18 | break
19 | if output:
20 | logger.info(f"vLLM subprocess stdout: {output.strip()}")
21 | rc = process.poll()
22 | if rc != 0:
23 | for error_output in process.stderr.readlines():
24 | logger.error(f"vLLM subprocess stderr: {error_output.strip()}")
25 |
26 |
27 | async def monitor_vllm_server_health(vllm_server_url, health_check_interval):
28 | """Health check for the vLLM server."""
29 | assert vllm_server_url is not None, "vllm_server_url must not be None"
30 | try:
31 | async with httpx.AsyncClient() as client:
32 | while True:
33 | response = await client.get(f"{vllm_server_url}/health")
34 | if response.status_code != 200:
35 | raise RuntimeError("vLLM is unhealthy")
36 | await asyncio.sleep(health_check_interval)
37 | except Exception as e:
38 | logging.error(
39 | f"vLLM has gone into an unhealthy state due to error: {e}, restarting service now..."
40 | )
41 | os._exit(1)
42 |
43 |
44 | async def monitor_vllm_engine_health(vllm_engine, health_check_interval):
45 | """Health check for the vLLM engine."""
46 | assert vllm_engine is not None, "vllm_engine must not be None"
47 | try:
48 | while True:
49 | await vllm_engine.check_health()
50 | await asyncio.sleep(health_check_interval)
51 | except Exception as e:
52 | logging.error(
53 | f"vLLM has gone into an unhealthy state due to error: {e}, restarting service now..."
54 | )
55 | os._exit(1)
56 |
57 |
58 | def run_background_vllm_health_check(
59 | use_openai_compatible_server=False,
60 | health_check_interval=DEFAULT_HEALTH_CHECK_INTERVAL,
61 | vllm_engine=None,
62 | vllm_server_url=None,
63 | ):
64 | """Background process for vLLM health checks."""
65 | logger.info("Starting background health check loop")
66 | loop = asyncio.new_event_loop()
67 | if use_openai_compatible_server:
68 | loop.create_task(monitor_vllm_server_health(vllm_server_url, health_check_interval))
69 | else:
70 | loop.create_task(monitor_vllm_engine_health(vllm_engine, health_check_interval))
71 | thread = threading.Thread(target=loop.run_forever, daemon=True)
72 | thread.start()
73 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/deployment/model/model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import subprocess
4 | import threading
5 | import time
6 | import uuid
7 |
8 | import httpx
9 | from transformers import AutoTokenizer
10 | from vllm import SamplingParams
11 | from vllm.engine.arg_utils import AsyncEngineArgs
12 | from vllm.engine.async_llm_engine import AsyncLLMEngine
13 |
14 | # isort requires relative imports that are not accepted by Baseten
15 | from model.helper import log_subprocess_output, run_background_vllm_health_check # isort: skip
16 |
17 | MAX_LENGTH = 1024
18 | TEMPERATURE = 0.1
19 | TOP_P = 0.95
20 | TOP_K = 40
21 | DO_SAMPLE = True
22 | DEFAULT_STREAM = False
23 |
24 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | class Model:
30 | """Baseten model class for deployment."""
31 |
32 | # 25 minutes; the reason this would take this long is mostly if we download a large model
33 | MAX_FAILED_SECONDS = 1500
34 | HEALTH_CHECK_INTERVAL = 5 # seconds
35 |
36 | def __init__(self, **kwargs):
37 | """Initialize Baseten model deployment class."""
38 | self._config = kwargs["config"]
39 | self.model_id = None
40 | self.llm_engine = None
41 | self.model_args = None
42 | self.hf_secret_token = kwargs["secrets"].get("hf_access_token", None)
43 | self.openai_compatible = self._config["model_metadata"].get("openai_compatible", False)
44 | self.vllm_base_url = None
45 | os.environ["HF_TOKEN"] = self.hf_secret_token
46 |
47 | def _load_openai_compatible_model(self):
48 | """Load OpenAI compatible model."""
49 | self._client = httpx.AsyncClient(timeout=None)
50 | command = ["vllm", "serve", self._model_repo_id]
51 | for key, value in self._vllm_config.items():
52 | if value is True:
53 | command.append(f"--{key.replace('_', '-')}")
54 | elif value is False:
55 | continue
56 | else:
57 | command.append(f"--{key.replace('_', '-')}")
58 | command.append(str(value))
59 |
60 | logger.info(f"Starting openai compatible vLLM server with command: {command}")
61 |
62 | self._vllm_process = subprocess.Popen(
63 | command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
64 | )
65 |
66 | output_thread = threading.Thread(target=log_subprocess_output, args=(self._vllm_process,))
67 | output_thread.daemon = True
68 | output_thread.start()
69 |
70 | # Wait for 10 seconds and check if command fails
71 | time.sleep(10)
72 |
73 | if self._vllm_process.poll() is None:
74 | logger.info("Command to start vLLM server ran successfully")
75 | else:
76 | stdout, stderr = self._vllm_process.communicate()
77 | if self._vllm_process.returncode != 0:
78 | logger.error(f"Command failed with error: {stderr}")
79 | raise RuntimeError(
80 | f"Command failed with code {self._vllm_process.returncode}: {stderr}"
81 | )
82 |
83 | if self._vllm_config and "port" in self._vllm_config:
84 | self._vllm_port = self._vllm_config["port"]
85 | else:
86 | self._vllm_port = 8000
87 |
88 | self.vllm_base_url = f"http://localhost:{self._vllm_port}"
89 |
90 | # Polling to check if the server is up
91 | server_up = self._check_server_health()
92 |
93 | if not server_up:
94 | raise RuntimeError("Server failed to start within the maximum allowed time.")
95 |
96 | def _check_server_health(self):
97 | """Check if the server is up and running."""
98 | start_time = time.time()
99 | while time.time() - start_time < self.MAX_FAILED_SECONDS:
100 | try:
101 | response = httpx.get(f"{self.vllm_base_url}/health")
102 | logger.info(f"Checking server health: {response.status_code}")
103 | if response.status_code == 200:
104 | return True
105 | except httpx.RequestError as e:
106 | seconds_passed = int(time.time() - start_time)
107 | if seconds_passed % 10 == 0:
108 | logger.info(f"Server is starting for {seconds_passed} seconds: {e}")
109 | time.sleep(1) # Wait for 1 second before retrying
110 | return False
111 |
112 | def _load_non_openai_compatible_model(self):
113 | """Load non-OpenAI compatible model."""
114 | try:
115 | result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True)
116 | logger.info(result.stdout)
117 | except subprocess.CalledProcessError as e:
118 | logger.error(f"Command failed with code {e.returncode}: {e.stderr}")
119 |
120 | self.model_args = AsyncEngineArgs(model=self._model_repo_id, **self._vllm_config)
121 | self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args=self.model_args)
122 | self.tokenizer = AutoTokenizer.from_pretrained(self._model_repo_id)
123 |
124 | try:
125 | result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True)
126 | logger.info(result.stdout)
127 | except subprocess.CalledProcessError as e:
128 | logger.error(f"Command failed with code {e.returncode}: {e.stderr}")
129 |
130 | def load(self):
131 | """Load the model."""
132 | self._model_metadata = self._config["model_metadata"]
133 | self._model_repo_id = self._model_metadata["repo_id"]
134 | self._vllm_config = self._model_metadata["vllm_config"]
135 | if self._vllm_config is None:
136 | self._vllm_config = {}
137 | logger.info(f"main model: {self._model_repo_id}")
138 | logger.info(f"vllm config: {self._vllm_config}")
139 |
140 | if self.openai_compatible:
141 | self._load_openai_compatible_model()
142 | else:
143 | self._load_non_openai_compatible_model()
144 |
145 | try:
146 | run_background_vllm_health_check(
147 | self.openai_compatible,
148 | self.HEALTH_CHECK_INTERVAL,
149 | self.llm_engine,
150 | self.vllm_base_url,
151 | )
152 | except Exception as e:
153 | raise RuntimeError(f"Failed to start background health check: {e}") from e
154 |
155 | async def predict(self, model_input):
156 | """Generate output based on the input."""
157 | if "messages" not in model_input and "prompt" not in model_input:
158 | raise ValueError("Prompt or messages must be provided")
159 |
160 | stream = model_input.get("stream", False)
161 |
162 | if self.openai_compatible:
163 | return await self._predict_openai_compatible(model_input, stream)
164 | else:
165 | return await self._predict_non_openai_compatible(model_input, stream)
166 |
167 | async def _predict_openai_compatible(self, model_input, stream):
168 | """Generate output for OpenAI compatible model."""
169 | # If the request includes "metrics": true, return the vLLM /metrics endpoint instead of a completion
170 | if model_input.get("metrics", False):
171 | response = await self._client.get(f"{self.vllm_base_url}/metrics")
172 | return response.text
173 |
174 | # convenience for Baseten bridge
175 | if "model" not in model_input and self._model_repo_id:
176 | logger.info(
177 | f"model_input missing model due to Baseten bridge, using {self._model_repo_id}"
178 | )
179 | model_input["model"] = self._model_repo_id
180 |
181 | if stream:
182 |
183 | async def generator():
184 | async with self._client.stream(
185 | "POST",
186 | f"{self.vllm_base_url}/v1/chat/completions",
187 | json=model_input,
188 | ) as response:
189 | async for chunk in response.aiter_bytes():
190 | if chunk:
191 | yield chunk
192 |
193 | return generator()
194 | else:
195 | response = await self._client.post(
196 | f"{self.vllm_base_url}/v1/chat/completions",
197 | json=model_input,
198 | )
199 |
200 | return response.json()
201 |
202 | async def _predict_non_openai_compatible(self, model_input, stream):
203 | """Generate output for non-OpenAI compatible model."""
204 | # SamplingParams does not take/use argument 'model'
205 | if "model" in model_input:
206 | model_input.pop("model")
207 | if "prompt" in model_input:
208 | prompt = model_input.pop("prompt")
209 | sampling_params = SamplingParams(**model_input)
210 | idx = str(uuid.uuid4().hex)
211 | messages = [
212 | {"role": "user", "content": prompt},
213 | ]
214 | # templatize the input to the model
215 | input = self.tokenizer.apply_chat_template(
216 | messages, tokenize=False, add_generation_prompt=True
217 | )
218 | elif "messages" in model_input:
219 | messages = model_input.pop("messages")
220 | sampling_params = SamplingParams(**model_input)
221 | idx = str(uuid.uuid4().hex)
222 | # templatize the input to the model
223 | input = self.tokenizer.apply_chat_template(
224 | messages,
225 | tokenize=False,
226 | )
227 | logger.info(f"Using SamplingParams: {sampling_params}")
228 | # since we accept any valid vllm sampling parameters, we can just pass it through
229 | vllm_generator = self.llm_engine.generate(input, sampling_params, idx)
230 |
231 | async def generator():
232 | full_text = ""
233 | async for output in vllm_generator:
234 | text = output.outputs[0].text
235 | delta = text[len(full_text) :]
236 | full_text = text
237 | yield delta
238 |
239 | if stream:
240 | return generator()
241 | else:
242 | full_text = ""
243 | async for delta in generator():
244 | full_text += delta
245 | return {"text": full_text}
246 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/errors.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | from pydantic import BaseModel, Field, field_validator
4 |
5 |
6 | class FlowJudgeError(BaseModel):
7 | """Represents an error encountered during the Flow Judge evaluation process.
8 |
9 | This class encapsulates detailed error information, including the type of error,
10 | the specific message, the request ID that caused the error, and other metadata.
11 |
12 | Attributes:
13 | error_type (str): The type of error encountered (e.g., "TimeoutError").
14 | error_message (str): A detailed description of the error.
15 | request_id (str): The ID of the request that caused the error.
16 | timestamp (datetime): The time when the error occurred.
17 | retry_count (int): The number of retry attempts made before the error was raised.
18 | raw_response (Optional[str]): The raw response from Baseten or proxy, if available.
19 |
20 | Note:
21 | This class is used for both logging and error handling. Ensure that sensitive
22 | information is not included in the error_message or raw_response fields.
23 | """
24 |
25 | error_type: str = Field(..., description="Type of the error encountered")
26 | error_message: str = Field(..., description="Detailed error message")
27 | request_id: str | None = Field(
28 | default=None, description="ID of the request that caused the error"
29 | )
30 | timestamp: datetime = Field(
31 | default_factory=datetime.now, description="Time when the error occurred"
32 | )
33 | retry_count: int = Field(default=0, description="Number of retry attempts made")
34 | raw_response: str | None = Field(
35 | None, description="Raw response from Baseten or proxy, if available"
36 | )
37 |
38 | @field_validator("error_type", "error_message")
39 | @classmethod
40 | def check_non_empty_string(cls, v):
41 | """Placeholder."""
42 | if not v.strip():
43 | raise ValueError("Field must not be empty or just whitespace")
44 | return v
45 |
46 |
47 | class BasetenAPIError(Exception):
48 | """Base exception for Baseten API errors."""
49 |
50 | pass
51 |
52 |
53 | class BasetenRequestError(BasetenAPIError):
54 | """Exception for request-related errors."""
55 |
56 | pass
57 |
58 |
59 | class BasetenResponseError(BasetenAPIError):
60 | """Exception for response-related errors."""
61 |
62 | pass
63 |
64 |
65 | class BasetenRateLimitError(BasetenAPIError):
66 | """Exception for rate limit errors."""
67 |
68 | pass
69 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/gpu.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from enum import Enum
4 | from typing import Any
5 |
6 | import yaml
7 |
8 | from .util import is_interactive
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class ModelByGPU(Enum):
14 | """Select the appropriate model based on GPU."""
15 |
16 | H100_40GB = "flowaicom/Flow-Judge-v0.1-FP8"
17 | A10G = "flowaicom/Flow-Judge-v0.1-AWQ"
18 |
19 |
20 | def _get_gpu_key() -> str | None:
21 | """Fetches the BASETEN_GPU environment variable.
22 |
23 | :returns: The value of the variable (H100 is normalized to H100_40GB), or None if unset
24 | :rtype: str | None
25 | """
26 | gpu: str | None = os.environ.get("BASETEN_GPU")
27 |
28 | if gpu:
29 | if gpu.lower() not in ["h100", "a10g"]:
30 | raise ValueError("BASETEN_GPU option is incorrect." "Possible options: H100, A10G")
31 |
32 | gpu = "H100_40GB" if gpu.lower() == "h100" else gpu
33 |
34 | return gpu
35 |
36 |
37 | def _has_gpu_key() -> bool:
38 | """Verifies that the BASETEN_GPU environment variable is present.
39 |
40 | :returns: True if found, False otherwise.
41 | :rtype: bool
42 | """
43 | gpu: str | None = _get_gpu_key()
44 |
45 | return True if gpu else False
46 |
47 |
48 | def _update_config() -> bool:
49 | """Update the config.yaml file with the GPU & flow-judge model id.
50 |
51 | :returns: True if successfully updated, False if not.
52 | :rtype: bool
53 | """
54 | gpu: str = _get_gpu_key()
55 |
56 | if not gpu:
57 | return False
58 |
59 | config_path: str = os.path.join(
60 | os.path.dirname(os.path.realpath(os.path.abspath(__file__))), "deployment", "config.yaml"
61 | )
62 |
63 | try:
64 | with open(config_path) as file:
65 | data: dict[str, Any] = yaml.safe_load(file)
66 |
67 | data["resources"]["accelerator"] = ModelByGPU[gpu.upper()].name
68 | data["repo_id"] = ModelByGPU[gpu.upper()].value
69 | data["model_metadata"]["repo_id"] = ModelByGPU[gpu.upper()].value
70 |
71 | with open(config_path, "w") as file:
72 | yaml.safe_dump(data, file, default_flow_style=False)
73 |
74 | return True
75 |
76 | except FileNotFoundError:
77 | logger.error(f"Baseten config.yaml file not found on path {config_path}")
78 | return False
79 | except yaml.YAMLError as e:
80 | logger.error(f"Error: Failed to parse the Baseten config file. {e}")
81 | return False
82 | except Exception as e:
83 | logger.error(f"An unexpected error occurred with Baseten config file update: {e}")
84 | return False
85 |
86 |
87 | def ensure_gpu() -> bool:
88 | """Enable GPU selection for FlowJudge model deployment.
89 |
90 | :return: True if successfully updated, False otherwise
91 | :rtype: bool
92 | """
93 | if _has_gpu_key():
94 | return _update_config()
95 |
96 | if is_interactive():
97 | print("What GPU on Baseten should we deploy the FlowJudge model to?")
98 | print(" ➡️ H100")
99 | print(" ➡️ A10G: default")
100 | print("Would you like to switch your deployment to H100?")
101 | print("y/n?\n")
102 |
103 | else:
104 | logger.info("Non-interactive environment detected")
105 | print("What GPU on Baseten should we deploy the FlowJudge model to?")
106 | print(" ➡️ H100")
107 | print(" ➡️ A10G")
108 | print("Please set the environment variable to the appropriate value:")
109 | print("```")
110 | print('os.environ["BASETEN_GPU"] = "<>"')
111 | print("```")
112 | return False
113 |
114 | logger.info("Prompting for GPU in interactive environment")
115 | while True:
116 | upgrade: str = input()
117 | if not upgrade:
118 | logger.warning("Empty option entered")
119 | print("Input is empty, try again.")
120 | continue
121 |
122 | if upgrade.lower() not in ["yes", "y", "n", "no"]:
123 | logger.warning("Incorrect option selected")
124 | print("Incorrect option, select one from y/n")
125 | continue
126 |
127 | if upgrade.lower() in ["yes", "y"]:
128 | os.environ["BASETEN_GPU"] = ModelByGPU.H100_40GB.name
129 | return _update_config()
130 |
131 | if upgrade.lower() in ["n", "no"]:
132 | os.environ["BASETEN_GPU"] = ModelByGPU.A10G.name
133 | return _update_config()
134 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/management.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import aiohttp
4 | import requests
5 | import structlog
6 |
7 | logger = structlog.get_logger(__name__)
8 |
9 |
10 | def _get_management_base_url(model_id: str) -> str:
11 | """Get the base URL for the Management API.
12 |
13 | :param model_id: The ID of the deployed model.
14 | :returns: The URL for the management endpoints
15 | """
16 | return f"https://api.baseten.co/v1/models/{model_id}/deployments/production"
17 |
18 |
19 | def sync_set_scale_down(scale_down_delay: int, api_key: str, model_id: str) -> bool:
20 | """Syncronous update to the cooldown period for the deployed model.
21 |
22 | :param scale_down_delay: The cooldown period in seconds.
23 | :param api_key: The Baseten API key.
24 | :param model_id: The ID of the deployed model on Baseten.
25 | :returns bool: True if successful, False otherwise.
26 | """
27 | try:
28 | resp = requests.patch(
29 | f"{_get_management_base_url(model_id)}/autoscaling_settings",
30 | headers={"Authorization": f"Api-Key {api_key}"},
31 | json={"scale_down_delay": scale_down_delay},
32 | )
33 |
34 | re = resp.json()
35 | if "status" in re:
36 | return True if re["status"] in ["ACCEPTED", "UNCHANGED", "QUEUED"] else False
37 | except Exception as e:
38 | logger.error("Unexpected error occurred with Baseten scale down delay request" f" {e}")
39 | return False
40 |
41 |
42 | async def set_scale_down_delay(scale_down_delay: int, api_key: str, model_id: str) -> bool:
43 | """Dynamically updates the cooldown period for the deployed model.
44 |
45 | :param scale_down_delay: The cooldown period in seconds.
46 | :param api_key: The Baseten API key.
47 | :param model_id: The ID of the deployed model on Baseten.
48 | :returns bool: True if successful, False otherwise.
49 | """
50 | url = f"{_get_management_base_url(model_id)}/autoscaling_settings"
51 | try:
52 | async with aiohttp.ClientSession() as session:
53 | async with session.patch(
54 | url=url,
55 | headers={"Authorization": f"Api-Key {api_key}"},
56 | json={"scale_down_delay": scale_down_delay},
57 | ) as response:
58 | if response.status != 200:
59 | logger.warning(
60 | "Unable to update Baseten scale down delay attribute."
61 | f"Request failed with status code {response.status}"
62 | )
63 | return False
64 |
65 | resp = await response.json()
66 | if "status" in resp:
67 | return True if resp["status"] in ["ACCEPTED", "UNCHANGED", "QUEUED"] else False
68 | except aiohttp.ClientError as e:
69 | logger.warning("Network error with Baseten scale_down_delay" f" {e}")
70 | return False
71 | except Exception as e:
72 | logger.warning("Unexpected error occurred with Baseten scale down delay request" f" {e}")
73 | return False
74 |
75 |
76 | async def wake_deployment(model_id: str, api_key: str) -> bool:
77 | """Activates the Baseten model.
78 |
79 | :param model_id: The ID of the deployed model.
80 | :returns: True if success, False if failed.
81 | :rtype: bool
82 | """
83 | url = f"https://model-{model_id}.api.baseten.co/production/wake"
84 | try:
85 | async with aiohttp.ClientSession() as session:
86 | async with session.post(
87 | url=url, headers={"Authorization": f"Api-Key {api_key}"}, json={}
88 | ) as response:
89 | if response.status != 202:
90 | logger.warning(
91 | "Unable to activate Baseten model."
92 | f"Request failed with status code {response.status}"
93 | )
94 | return False
95 |
96 | return True
97 | except aiohttp.ClientError as e:
98 | logger.warning("Network error with Baseten model activation." f" {e}")
99 | return False
100 | except Exception as e:
101 | logger.error("Unexpected error occurred with Baseten model activation." f" {e}")
102 | return False
103 |
104 |
105 | async def get_production_deployment_status(model_id: str, api_key: str) -> str | None:
106 | """Get model production deployment_id by it's model_id.
107 |
108 | :param model_id: The ID of the deployed model.
109 | :returns: The deployment_id of the production model.
110 | :rtype: str
111 | """
112 | try:
113 | async with aiohttp.ClientSession() as session:
114 | async with session.get(
115 | _get_management_base_url(model_id),
116 | headers={"Authorization": f"Api-Key {api_key}"},
117 | ) as response:
118 | if response.status != 200:
119 | logger.warning(
120 | "Unable to get model deployment details"
121 | f"Request failed with status {response.status}"
122 | )
123 | return None
124 |
125 | re = await response.json()
126 | return re["status"]
127 | except (json.JSONDecodeError, KeyError, IndexError) as e:
128 | logger.warning("Unable to parse response for Model deployment info request." f" {e}")
129 | return None
130 | except aiohttp.ClientError as e:
131 | logger.warning("Network error with Baseten model deployment information." f" {e}")
132 | return None
133 | except Exception as e:
134 | logger.error(
135 | "Unexpected error occurred with Baseten model deployment info request." f" {e}"
136 | )
137 | return None
138 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/token_bucket.py:
--------------------------------------------------------------------------------
1 | import time
2 | from dataclasses import dataclass, field
3 |
4 |
5 | @dataclass
6 | class TokenBucket:
7 | """Implements a token bucket algorithm for rate limiting.
8 |
9 | This class manages a token bucket with a specified capacity and fill rate,
10 | allowing for controlled consumption of tokens over time.
11 |
12 | Attributes:
13 | tokens (float): Current number of tokens in the bucket.
14 | fill_rate (float): Rate at which tokens are added to the bucket (tokens per second).
15 | capacity (float): Maximum number of tokens the bucket can hold.
16 | last_update (float): Timestamp of the last token update.
17 |
18 | Note:
19 | This implementation is not thread-safe. If used in a multi-threaded environment,
20 | external synchronization mechanisms should be applied.
21 | """
22 |
23 | tokens: float
24 | fill_rate: float
25 | capacity: float
26 | last_update: float = field(default_factory=time.time)
27 |
28 | def consume(self, tokens: int = 1) -> bool:
29 | """Attempt to consume tokens from the bucket.
30 |
31 | Args:
32 | tokens (int): Number of tokens to consume. Defaults to 1.
33 |
34 | Returns:
35 | bool: True if tokens were successfully consumed, False otherwise.
36 |
37 | Note:
38 | This method updates the token count based on the time elapsed since
39 | the last update, then attempts to consume the requested number of tokens.
40 | """
41 | now = time.time()
42 | self.tokens = min(self.capacity, self.tokens + self.fill_rate * (now - self.last_update))
43 | self.last_update = now
44 | if self.tokens >= tokens:
45 | self.tokens -= tokens
46 | return True
47 | return False
48 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/util.py:
--------------------------------------------------------------------------------
1 | def is_interactive() -> bool:
2 | """Check if the current environment is interactive.
3 |
4 | :return: True if the environment is interactive, False otherwise.
5 | :rtype: bool
6 | """
7 | import sys
8 |
9 | return sys.__stdin__.isatty()
10 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/validation.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import hmac
3 | import logging
4 | import os
5 | from datetime import datetime, timezone
6 |
7 | from pydantic import BaseModel, ConfigDict, JsonValue
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 | TIMESTAMP_TOLERANCE_SECONDS = 300
12 |
13 |
14 | class AsyncPredictResult(BaseModel):
15 | """Baseten completion response format."""
16 |
17 | model_config = ConfigDict(protected_namespaces=(), extra="allow")
18 |
19 | request_id: str
20 | model_id: str
21 | deployment_id: str
22 | type: str
23 | time: datetime
24 | data: JsonValue
25 | errors: list[dict]
26 |
27 |
28 | def validate_baseten_signature(result, actual_signatures) -> bool:
29 | """Webhook signature validation from baseten."""
30 | try:
31 | webhook_secret = os.environ["BASETEN_WEBHOOK_SECRET"]
32 | except Exception as e:
33 | logger.error(
34 | "Baseten webhook secret is not in the environment."
35 | "Unable to validate baseten signature for batched requests."
36 | "Set the BASETEN_WEBHOOK_SECRET env variable to proceed."
37 | f"{e}"
38 | )
39 | return False
40 |
41 | async_predict_result = AsyncPredictResult(**result)
42 |
43 | if (
44 | datetime.now(timezone.utc) - async_predict_result.time
45 | ).total_seconds() > TIMESTAMP_TOLERANCE_SECONDS:
46 | logger.error(
47 | f"Async predict result was received after {TIMESTAMP_TOLERANCE_SECONDS} seconds"
48 | "and is considered stale, Baseten signature was not validated."
49 | )
50 | return False
51 |
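# The header may carry several comma-separated "v1=" signatures; accept the payload if any of them matches our HMAC-SHA256 digest of the re-serialized result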
52 | for actual_signature in actual_signatures.replace("v1=", "").split(","):
53 | expected_signature = hmac.digest(
54 | webhook_secret.encode("utf-8"),
55 | async_predict_result.model_dump_json().encode("utf-8"),
56 | hashlib.sha256,
57 | ).hex()
58 |
59 | if hmac.compare_digest(expected_signature, actual_signature):
60 | logger.info("Baseten signature is valid!")
61 | return True
62 |
63 | logger.error(
64 | "Baseten signature is not valid. Ensure your webhook secrets are properly configured."
65 | )
66 | return False
67 |
--------------------------------------------------------------------------------
/flow_judge/models/adapters/baseten/webhook.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import logging
3 | import os
4 | import re
5 |
6 | from .util import is_interactive
7 |
8 | logger: logging.Logger = logging.getLogger(__name__)
9 |
10 | _stored_secret_path: str = "~/.config/flow-judge/baseten_webhook_secret"
11 | _stored_skip_file: str = "~/.config/flow-judge/baseten_whsec_skipped"
12 |
13 |
14 | def _is_valid_secret(secret: str) -> bool:
15 | """Validate the Baseten webhook secret.
16 |
17 | :param secret: The secret to validate
18 | :return: True if the secret is valid, False otherwise
19 | """
20 | logger.debug("Validating webhook secret")
21 | return bool(re.match(r"^whsec_[a-zA-Z0-9]{40}$", secret))
22 |
23 |
24 | def _save_webhook_secret(secret: str) -> None:
25 | """Save the webhook secret to a file.
26 |
27 | :param secret: The secret to save
28 | """
29 | logger.debug(f"Saving webhook secret to {_stored_secret_path}")
30 | try:
31 | p: str = os.path.expanduser(_stored_secret_path)
32 | os.makedirs(os.path.dirname(p), exist_ok=True)
33 | with open(p, "w") as f:
34 | f.write(secret)
35 | logger.info("Webhook secret saved successfully")
36 | except OSError as e:
37 | logger.error(f"Failed to save webhook secret: {e}")
38 |
39 |
40 | def _get_stored_secret() -> str | None:
41 | """Retrieve the stored webhook secret.
42 |
43 | :return: The stored secret if available, None otherwise
44 | """
45 | logger.debug(f"Attempting to retrieve stored webhook secret from {_stored_secret_path}")
46 | try:
47 | p: str = os.path.expanduser(_stored_secret_path)
48 | if not os.path.exists(p):
49 | logger.info("No stored webhook secret found")
50 | return None
51 | with open(p) as f:
52 | secret: str = f.read().strip()
53 | logger.info("Stored webhook secret retrieved successfully")
54 | return secret
55 | except OSError as e:
56 | logger.error(f"Failed to retrieve stored webhook secret: {e}")
57 | return None
58 |
59 |
60 | def _save_skip_file() -> None:
61 | """Creates a flag file indicating that user requested to skip secret webhook input."""
62 | os.makedirs(os.path.expanduser(os.path.dirname(_stored_skip_file)), exist_ok=True)
63 | open(os.path.expanduser(_stored_skip_file), "a").close()
64 |
65 |
66 | def _handle_skip() -> bool:
67 | """Handles a possible user request to skip question about webhook secret.
68 |
69 | Checks for existence of BASETEN_SKIP_WEBHOOK_SECRET env variable, indicating that
70 | skip file should be saved. Skip file existence is a flag telling us to skip the
71 | question about webhook secret.
72 |
73 | :return: True if the webhook secret prompt should be skipped, False otherwise
74 | :rtype: bool
75 | """
76 | if os.environ.get("BASETEN_SKIP_WEBHOOK_SECRET"):
77 | logger.info(
78 | "User requested to skip the webhook secret prompt, skipping and saving skip file. "
79 | "Remove ~/.config/flow-judge/baseten_whsec_skipped to restore the prompt."
80 | )
81 | _save_skip_file()
82 | return True
83 |
84 | if os.path.exists(os.path.expanduser(_stored_skip_file)):
85 | logger.info(
86 | "Webhook secret not required and user skipped it before, skipping. "
87 | "Remove ~/.config/flow-judge/baseten_whsec_skipped to restore the prompt."
88 | )
89 | return True
90 |
91 | return False
92 |
93 |
94 | def _handle_env_variable_input() -> bool:
95 | """Checks if webhook secret was provided in environment variable.
96 |
97 | :return: True if the webhook secret was provided, False otherwise
98 | :rtype: bool
99 | """
100 | env_secret: str | None = os.getenv("BASETEN_WEBHOOK_SECRET")
101 | if not env_secret:
102 | return False
103 |
104 | logger.info("Found BASETEN_WEBHOOK_SECRET in environment variables")
105 | if _is_valid_secret(env_secret):
106 | logger.info("Environment variable contains a valid webhook secret")
107 | _save_webhook_secret(env_secret)
108 | else:
109 | logger.warning("Probably invalid BASETEN_WEBHOOK_SECRET in environment variable")
110 | return True
111 |
112 |
113 | def _handle_stored_secret() -> bool:
114 | """Checks if webhook secret was previously provided and stored in file.
115 |
116 | :return: True if the webhook secret was provided, False otherwise
117 | :rtype: bool
118 | """
119 | stored_secret: str | None = _get_stored_secret()
120 | if not stored_secret:
121 | return False
122 |
123 | logger.info("Found stored webhook secret")
124 | if _is_valid_secret(stored_secret):
125 | logger.info("Stored webhook secret is valid")
126 | os.environ["BASETEN_WEBHOOK_SECRET"] = stored_secret
127 | else:
128 | logger.warning("Stored webhook secret is probably invalid")
129 | return True
130 |
131 |
132 | def _prompt_interactively(optional: bool) -> str | None:
133 | """Asks user to enter webhook secret or skip if optional.
134 |
135 | :param optional: True if the input can be skipped by user
136 | :return: User-provided webhook secret or None if skipped
137 | :rtype: str | None
138 | """
139 | while True:
140 | secret: str = getpass.getpass(
141 | "Baseten webhook secret (hidden): "
142 | if not optional
143 | else "Baseten webhook secret (hidden; leave empty to skip): "
144 | )
145 |
146 | if optional and not secret:
147 | logger.info("User skipped optional webhook secret input")
148 | return None
149 |
150 | if not secret:
151 | logger.warning("Empty input received")
152 | print("Input is empty, please try again.")
153 | continue
154 |
155 | return secret
156 |
157 |
158 | def _print_general_prompt(optional: bool) -> None:
159 | """Prints general information/prompt about webhook secret input requirement.
160 |
161 | :param optional: Whether to print the information that the input is optional
162 | """
163 | print(
164 | "To run Flow Judge remotely with Baseten and enable async execution, "
165 | "you need to create and configure a webhook secret.\n"
166 | "➡️ Creating a webhook secret: https://docs.baseten.co/invoke/async-secure\n\n"
167 | "The webhook secret is used to validate that the webhook responses originated "
168 | "from Baseten. "
169 | f"It will be stored in {_stored_secret_path} for later use.\n"
170 | "For your convenience, Baseten responses are forwarded to you using Flow AI proxy.\n"
171 | "Explore what this means and the alternatives here: "
172 | "https://github.com/flowaicom/flow-judge/\n"
173 | )
174 | if optional:
175 | print("\033[1mOptional. This is only required if you plan async execution.\033[0m")
176 |
177 |
178 | def _print_noninteractive_prompt(optional: bool) -> None:
179 | """Prints a prompt for non interactive environments.
180 |
181 | :param optional: Whether to print the information that the input is optional
182 | """
183 | logger.info("Non-interactive environment detected")
184 | print(
185 | "Set the Baseten webhook secret in the BASETEN_WEBHOOK_SECRET "
186 | "environment variable and run again:\n"
187 | 'os.environ["BASETEN_WEBHOOK_SECRET"] = "«your webhook secret»"\n'
188 | "The secret should start with 'whsec_' followed by 40 alphanumeric characters.\n"
189 | )
190 | if optional:
191 | print(
192 | "If you don't want to see this message in the future set the "
193 | "BASETEN_SKIP_WEBHOOK_SECRET environment variable to any non-empty value:\n"
194 | 'os.environ["BASETEN_SKIP_WEBHOOK_SECRET"]="true"\n'
195 | )
196 |
197 |
198 | def ensure_baseten_webhook_secret(optional: bool = False) -> bool:
199 | """Ensure a Baseten webhook secret is available.
200 |
201 | This function checks for a webhook secret in the following order:
202 | 1. Environment variable
203 | 2. Stored secret file
204 | 3. User input (in interactive mode)
205 |
206 |     :param optional: Whether to allow the user to omit the webhook secret input
207 | :return: True if a secret is available (valid or not), False if no secret could be obtained
208 | """
209 |     # If input is optional, check whether the user requested to skip the prompt
210 | if optional and _handle_skip():
211 | return False
212 |
213 | logger.info("Ensuring Baseten webhook secret")
214 |
215 | # Check environment variable
216 | if _handle_env_variable_input():
217 | return True
218 |
219 | # Check stored secret
220 | if _handle_stored_secret():
221 | return True
222 |
223 | logger.info("No existing webhook secret found, prompting user")
224 | _print_general_prompt(optional)
225 |
226 | # In non-interactive environments we return early because their inputs (env vars)
227 | # are processed at the very beginning of this function.
228 | if not is_interactive():
229 | _print_noninteractive_prompt(optional)
230 | return False
231 |
232 | logger.info("Prompting user for webhook secret interactively")
233 | secret = _prompt_interactively(optional)
234 |
235 | # Secret not provided which means user asked to skip the prompt
236 | if secret is None:
237 | print("(skipped)")
238 | _save_skip_file()
239 | return False
240 |
241 | if not _is_valid_secret(secret):
242 | logger.warning("Invalid webhook secret provided")
243 | print(
244 | "Warning: The provided webhook secret is probably invalid. "
245 | "It should start with 'whsec_' followed by 40 alphanumeric characters. "
246 | "Proceeding anyway."
247 | )
248 |
249 | _save_webhook_secret(secret)
250 | os.environ["BASETEN_WEBHOOK_SECRET"] = secret
251 |
252 | return True
253 |
--------------------------------------------------------------------------------
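A minimal sketch of driving the secret-resolution order implemented above (environment variable, then stored file, then interactive prompt) non-interactively. The module path follows the repository layout and the secret value is a made-up placeholder:

```python
import os

from flow_judge.models.adapters.baseten.webhook import ensure_baseten_webhook_secret

# Exporting BASETEN_WEBHOOK_SECRET short-circuits the stored-file lookup and the
# interactive prompt. The value here is only a placeholder with the expected shape:
# "whsec_" followed by 40 alphanumeric characters.
os.environ["BASETEN_WEBHOOK_SECRET"] = "whsec_" + "a" * 40

# Returns True because the environment variable holds a well-formed secret;
# the secret is also persisted to ~/.config/flow-judge/baseten_webhook_secret.
assert ensure_baseten_webhook_secret()
```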
/flow_judge/models/baseten.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections.abc import Coroutine
3 | from typing import Any
4 |
5 | from .adapters.baseten.adapter import AsyncBasetenAPIAdapter, BaseAPIAdapter, BasetenAPIAdapter
6 | from .adapters.baseten.data_io import BatchResult
7 | from .adapters.baseten.deploy import ensure_model_deployment, get_deployed_model_id
8 | from .adapters.baseten.webhook import ensure_baseten_webhook_secret
9 | from .common import (
10 | AsyncBaseFlowJudgeModel,
11 | BaseFlowJudgeModel,
12 | ModelConfig,
13 | ModelType,
14 | VllmGenerationParams,
15 | )
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class BasetenModelConfig(ModelConfig):
21 | """Model config for the Baseten model class."""
22 |
23 | def __init__(
24 | self,
25 | generation_params: VllmGenerationParams,
26 | exec_async: bool = False,
27 | webhook_proxy_url: str | None = None,
28 | async_batch_size: int = 128,
29 | **kwargs: Any,
30 | ):
31 | """Initialize the Baseten model config.
32 |
33 | :param generation_params: VllmGenerationParams for text generation.
34 | :param exec_async: Whether to use async execution.
35 | :param webhook_proxy_url: Webhook URL for Baseten async execution.
36 | :param async_batch_size: Batch size for concurrent requests in async mode.
37 | :raises ValueError: If any input parameters are invalid.
38 | """
39 | model_id = kwargs.pop("_model_id", None)
40 | if model_id is None:
41 | model_id = get_deployed_model_id()
42 | if model_id is None:
43 | raise ValueError("Unable to retrieve Baseten's deployed model id.")
44 |
45 | model_type = ModelType.BASETEN_VLLM_ASYNC if exec_async else ModelType.BASETEN_VLLM
46 |
47 | if not isinstance(generation_params, VllmGenerationParams):
48 | raise ValueError("generation_params must be an instance of VllmGenerationParams")
49 | if async_batch_size <= 0:
50 | raise ValueError(f"async_batch_size must be > 0, got {async_batch_size}")
51 | if exec_async and webhook_proxy_url is None:
52 | raise ValueError("webhook_proxy_url is required for async execution")
53 |
54 | super().__init__(model_id, model_type, generation_params, **kwargs)
55 | self.webhook_proxy_url = webhook_proxy_url
56 | self.exec_async = exec_async
57 | self.async_batch_size = async_batch_size
58 |
59 |
60 | class Baseten(BaseFlowJudgeModel, AsyncBaseFlowJudgeModel):
61 | """Combined FlowJudge Model class for Baseten sync and webhook async operations."""
62 |
63 | def __init__(
64 | self,
65 | api_adapter: BaseAPIAdapter | None = None,
66 | webhook_proxy_url: str | None = None,
67 | exec_async: bool = False,
68 | async_batch_size: int = 128,
69 | generation_params: dict[str, Any] | None = None,
70 | **kwargs: Any,
71 | ):
72 | """Initialize the Baseten Model class.
73 |
74 | :param api_adapter: API handling class for Baseten requests.
75 | :param webhook_proxy_url: The webhook url for the proxy when exec_async is True.
76 | :param exec_async: Whether to use async webhook execution.
77 |         :param async_batch_size: Batch size for concurrent requests to Baseten in async mode.
78 | :param generation_params: Dictionary of parameters for text generation.
79 | :raises BasetenError: If Baseten deployment or model ID retrieval fails.
80 | :raises ValueError: If input parameters are invalid.
81 | """
82 | if exec_async and not webhook_proxy_url:
83 | raise ValueError("webhook_proxy_url is required for async Baseten execution.")
84 |
85 | if async_batch_size < 1:
86 | raise ValueError("async_batch_size must be greater than 0.")
87 |
88 | # Check for a custom Baseten model ID provided by the user
89 | # This allows for using a specific model deployment instead of the default
90 | model_id = kwargs.pop("_model_id", None)
91 | if model_id is None:
92 | # If no custom ID, attempt to retrieve the default deployed Flow Judge model ID
93 | # This covers the case where the model is already deployed but not specified
94 | model_id = get_deployed_model_id()
95 | if model_id is None:
96 | # No deployed model found, so try to deploy the default Flow Judge model
97 | # This step handles first-time usage or cases where the model was removed
98 | if not ensure_model_deployment():
99 | # Deployment failed, which could be due to API key issues,
100 | # network problems, or Baseten service unavailability
101 | raise BasetenError(
102 | status_code=1,
103 | message=(
104 | "Baseten deployment is not available."
105 | " This could be due to API key issues, "
106 | "network problems, or Baseten service "
107 | "unavailability. Please check your "
108 | "API key, network connection, and Baseten service status."
109 | ),
110 | )
111 | # Deployment succeeded, so attempt to retrieve the model ID again
112 | # This should now succeed unless there's an unexpected issue
113 | model_id = get_deployed_model_id()
114 | if model_id is None:
115 | # If we still can't get the model ID, it indicates a deeper problem
116 | # This could be due to a bug in the deployment process or Baseten API changes
117 | raise BasetenError(
118 | status_code=2,
119 | message=(
120 | "Unable to retrieve Baseten's deployed model id. "
121 | "Please ensure the model is deployed or provide a custom '_model_id'."
122 | ),
123 | )
124 |
125 | if exec_async and not ensure_baseten_webhook_secret():
126 | raise BasetenError(
127 | status_code=4,
128 | message=(
129 | "Unable to retrieve Baseten's webhook secret. "
130 | "Please ensure the webhook secret is provided in "
131 | "BASETEN_WEBHOOK_SECRET environment variable."
132 | ),
133 | )
134 |
135 | if api_adapter is not None and not isinstance(
136 | api_adapter, (BasetenAPIAdapter, AsyncBasetenAPIAdapter)
137 | ):
138 | raise BasetenError(
139 | status_code=3,
140 | message="Incompatible API adapter. Use BasetenAPIAdapter or AsyncBasetenAPIAdapter",
141 | )
142 |
143 | self.api_adapter = api_adapter or (
144 | AsyncBasetenAPIAdapter(model_id, webhook_proxy_url, async_batch_size)
145 | if exec_async
146 | else BasetenAPIAdapter(model_id)
147 | )
148 |
149 | generation_params = VllmGenerationParams(**(generation_params or {}))
150 | config = BasetenModelConfig(
151 | generation_params=generation_params,
152 | exec_async=exec_async,
153 | webhook_proxy_url=webhook_proxy_url,
154 | async_batch_size=async_batch_size,
155 | _model_id=model_id,
156 | )
157 | self.config = config
158 |
159 | super().__init__(model_id, config.model_type, config.generation_params, **kwargs)
160 |
161 | logger.info("Successfully initialized Baseten!")
162 |
163 | def _format_conversation(self, prompt: str) -> list[dict[str, Any]]:
164 | return [{"role": "user", "content": prompt.strip()}]
165 |
166 | def _generate(self, prompt: str) -> str:
167 | logger.info("Initiating single Baseten request")
168 |
169 | conversation = self._format_conversation(prompt)
170 | return self.api_adapter._fetch_response(conversation)
171 |
172 | def _batch_generate(
173 | self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any
174 | ) -> list[str]:
175 | logger.info("Initiating batched Baseten requests")
176 |
177 | conversations = [self._format_conversation(prompt) for prompt in prompts]
178 | return self.api_adapter._fetch_batched_response(conversations)
179 |
180 | async def _async_generate(self, prompt: str) -> str:
181 | if self.config.exec_async:
182 | return await self.api_adapter._async_fetch_response(prompt.strip())
183 | else:
184 | logger.error("Attempting to run an async request with a synchronous API adapter")
185 |
186 | async def _async_batch_generate(
187 | self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any
188 | ) -> Coroutine[Any, Any, BatchResult]:
189 | if self.config.exec_async:
190 | cleaned_prompts = [prompt.strip() for prompt in prompts]
191 | return await self.api_adapter._async_fetch_batched_response(cleaned_prompts)
192 | else:
193 | logger.error("Attempting to run an async request with a synchronous API adapter")
194 |
195 |
196 | class BasetenError(Exception):
197 | """Custom exception for Baseten-related errors."""
198 |
199 | def __init__(self, status_code: int, message: str):
200 | """Initialize with a status code and message."""
201 | self.status_code = status_code
202 | self.message = message
203 | super().__init__(self.message)
204 |
--------------------------------------------------------------------------------
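A hedged sketch of constructing the Baseten model in both modes, assuming a valid BASETEN_API_KEY and, for async execution, the webhook secret handled by webhook.py; the proxy URL below is a placeholder:

```python
from flow_judge.models.baseten import Baseten, BasetenError

try:
    # Synchronous mode: the deployed Flow Judge model id is resolved (or the default
    # deployment is created) and a BasetenAPIAdapter is built internally.
    sync_model = Baseten(generation_params={"temperature": 0.1, "max_new_tokens": 1000})

    # Async webhook mode additionally requires a webhook proxy URL and the
    # BASETEN_WEBHOOK_SECRET described above. The URL below is a placeholder.
    async_model = Baseten(
        exec_async=True,
        webhook_proxy_url="https://proxy.example.com",
        async_batch_size=64,
    )
except BasetenError as e:
    print(f"Baseten setup failed ({e.status_code}): {e.message}")
```

In normal use the resulting instance is handed to an evaluator (for example `LlamaIndexFlowJudge` in the e2e tests further below) rather than called through its private `_generate` hooks.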
/flow_judge/models/common.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from enum import Enum
3 | from typing import Any
4 |
5 | from pydantic import BaseModel, Field
6 |
7 | from .adapters.base import BaseAPIAdapter
8 |
9 |
10 | class BaseFlowJudgeModel(ABC):
11 | """Base class for all FlowJudge models."""
12 |
13 | def __init__(
14 | self, model_id: str, model_type: str, generation_params: dict[str, Any], **kwargs: Any
15 | ) -> None:
16 | """Initialize the base FlowJudge model."""
17 | self.metadata: dict[str, Any] = {
18 | "model_id": model_id,
19 | "model_type": model_type,
20 | "generation_params": generation_params,
21 | "kwargs": kwargs,
22 | }
23 |
24 | @abstractmethod
25 | def _generate(self, prompt: str) -> str:
26 | """Generate a response based on the given prompt."""
27 | pass
28 |
29 | @abstractmethod
30 | def _batch_generate(
31 | self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any
32 | ) -> list[str]:
33 | """Generate responses for multiple prompts."""
34 | pass
35 |
36 |
37 | class AsyncBaseFlowJudgeModel(ABC):
38 | """Base class for asynchronous FlowJudge models."""
39 |
40 | def __init__(
41 | self, model_id: str, model_type: str, generation_params: dict[str, Any], **kwargs: Any
42 | ) -> None:
43 | """Initialize the base asynchronous FlowJudge model."""
44 | self.metadata: dict[str, Any] = {
45 | "model_id": model_id,
46 | "model_type": model_type,
47 | "generation_params": generation_params,
48 | "kwargs": kwargs,
49 | }
50 |
51 | @abstractmethod
52 | async def _async_generate(self, prompt: str) -> str:
53 | """Generate a response based on the given prompt asynchronously."""
54 | pass
55 |
56 | @abstractmethod
57 | async def _async_batch_generate(
58 | self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any
59 | ) -> list[str]:
60 | """Generate responses for multiple prompts asynchronously."""
61 | pass
62 |
63 |
64 | class FlowJudgeRemoteModel(BaseFlowJudgeModel):
65 | """Flow judge model class for remote hosting."""
66 |
67 | def __init__(
68 | self,
69 | model_id: str,
70 | model_type: str,
71 | generation_params: dict[str, Any],
72 | api_adapter: BaseAPIAdapter,
73 | **remote_kwargs: Any,
74 | ):
75 | """Initialize the FlowJudge remote model class.
76 |
77 | :param model_id: The ID of the model.
78 | :param model_type: Type of the model based on ModelType.
79 | :param generation_params: Relevant generation params for the model type.
80 | :param remote_kwargs: Keyword arguments to initialize the parameters.
81 | """
82 | super().__init__(model_id, model_type, generation_params, **remote_kwargs)
83 |
84 | if not isinstance(api_adapter, BaseAPIAdapter):
85 | raise ValueError("Invalid Adapter type. Use BaseAPIAdapter.")
86 |
87 | self.api_adapter = api_adapter
88 |
89 | def generate(self, prompt: str) -> str:
90 | """Single generation request."""
91 | conversation = [{"role": "user", "content": prompt.strip()}]
92 | return self.api_adapter.fetch_response(conversation)
93 |
94 | def batch_generate(self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any) -> list[str]:
95 | """Batched generation request."""
96 | conversations = [[{"role": "user", "content": prompt.strip()}] for prompt in prompts]
97 | return self.api_adapter.fetch_batched_response(conversations)
98 |
99 |
100 | class GenerationParams(BaseModel):
101 | """Configuration parameters for text generation."""
102 |
103 | temperature: float = Field(default=0.1, description="Sampling temperature")
104 | top_p: float = Field(default=0.95, description="Top-p sampling parameter")
105 | max_new_tokens: int = Field(
106 | default=1000, description="Maximum number of new tokens to generate"
107 | )
108 | do_sample: bool = Field(default=True, description="Whether to use sampling for generation")
109 |
110 |
111 | class VllmGenerationParams(GenerationParams):
112 | """Configuration parameters specific to VLLM text generation."""
113 |
114 | max_tokens: int | None = None
115 | stop_token_ids: list[int] = [32007, 32001, 32000]
116 |
117 | def __init__(self, **data):
118 | """Initialize VllmGenerationParams with given data.
119 |
120 | :param data: Keyword arguments to initialize the parameters.
121 | """
122 | super().__init__(**data)
123 | self.max_tokens = self.max_new_tokens
124 | del self.max_new_tokens
125 | del self.do_sample
126 |
127 |
128 | class ModelType(Enum):
129 | """Enum for the type of model."""
130 |
131 | TRANSFORMERS = "transformers"
132 | VLLM = "vllm"
133 | VLLM_ASYNC = "vllm_async"
134 | LLAMAFILE = "llamafile"
135 | BASETEN_VLLM = "baseten_vllm"
136 | BASETEN_VLLM_ASYNC = "baseten_vllm_async"
137 |
138 |
139 | class Engine(Enum):
140 | """Enum for the type of engine used for text generation."""
141 |
142 | VLLM: str = "vllm"
143 | VLLM_ASYNC: str = "vllm_async"
144 | HF: str = "hf" # HF stands for Hugging Face (Transformers)
145 | LLAMAFILE: str = "llamafile"
146 |
147 |
148 | class ModelConfig:
149 | """Base configuration for a model."""
150 |
151 | def __init__(
152 | self,
153 | model_id: str,
154 | model_type: ModelType,
155 | generation_params: dict[str, Any],
156 | **kwargs: Any,
157 | ) -> None:
158 | """Initialize ModelConfig with model details and generation parameters.
159 |
160 | :param model_id: Identifier for the model.
161 | :param model_type: Type of the model.
162 | :param generation_params: Parameters for text generation.
163 | :param kwargs: Additional keyword arguments.
164 | """
165 | self.model_id: str = model_id
166 | self.model_type: ModelType = model_type
167 | self.generation_params: dict[str, Any] = generation_params
168 | self.kwargs: dict[str, Any] = kwargs
169 |
--------------------------------------------------------------------------------
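For reference, a short sketch of how the two parameter models above differ; the printed values follow the field defaults and the renaming done in `VllmGenerationParams.__init__`:

```python
from flow_judge.models.common import GenerationParams, VllmGenerationParams

# Defaults used by the Transformers backend.
hf_params = GenerationParams()
print(hf_params.temperature, hf_params.top_p, hf_params.max_new_tokens)  # 0.1 0.95 1000

# The vLLM variant copies max_new_tokens into max_tokens and removes the
# Transformers-only fields, so the values line up with vLLM-style APIs.
vllm_params = VllmGenerationParams(max_new_tokens=512)
print(vllm_params.max_tokens)      # 512
print(vllm_params.stop_token_ids)  # [32007, 32001, 32000]
```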
/flow_judge/models/huggingface.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import warnings
4 | from typing import Any
5 |
6 | from flow_judge.models.common import BaseFlowJudgeModel, GenerationParams, ModelConfig, ModelType
7 |
8 | try:
9 | import torch
10 | from huggingface_hub import snapshot_download
11 | from transformers import AutoModelForCausalLM, AutoTokenizer
12 |
13 | HF_AVAILABLE = True
14 | except ImportError:
15 | HF_AVAILABLE = False
16 |
17 | from tqdm import tqdm
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class HfConfig(ModelConfig):
23 | """Configuration class for Hugging Face models."""
24 |
25 | _DEFAULT_MODEL_ID = "flowaicom/Flow-Judge-v0.1"
26 |
27 | def __init__(
28 | self,
29 | generation_params: GenerationParams,
30 | device_map: str = "auto",
31 | torch_dtype: str = "bfloat16",
32 | flash_attn: bool = True,
33 | **kwargs: Any,
34 | ):
35 | """Initialize HfConfig with model details and Hugging Face specific parameters.
36 |
37 | :param generation_params: Parameters for text generation.
38 | :param device_map: Device mapping strategy.
39 | :param torch_dtype: PyTorch data type for the model.
40 | :param flash_attn: Whether to use flash attention.
41 | :param kwargs: Additional keyword arguments.
42 | """
43 | model_id = kwargs.pop("_model_id", self._DEFAULT_MODEL_ID)
44 | super().__init__(model_id, ModelType.TRANSFORMERS, generation_params.model_dump(), **kwargs)
45 | self.device_map = device_map
46 | self.torch_dtype = torch_dtype
47 | self.flash_attn = flash_attn
48 | self.kwargs = kwargs
49 |
50 |
51 | class Hf(BaseFlowJudgeModel):
52 | """FlowJudge model class for Hugging Face Transformers."""
53 |
54 | _DEFAULT_MODEL_ID = "flowaicom/Flow-Judge-v0.1"
55 |
56 | def __init__(
57 | self,
58 | generation_params: dict[str, Any] | None = None,
59 | flash_attn: bool = True,
60 | **kwargs: Any,
61 | ):
62 | """Initialize the FlowJudge Hugging Face Transformers model.
63 |
64 | :param generation_params: Dictionary of parameters for text generation.
65 | :param flash_attn: Whether to use flash attention.
66 | :param kwargs: Additional keyword arguments, including:
67 | - _model_id: Identifier for the model. If None, uses the default model.
68 | // ... other kwargs ...
69 | """
70 | if not HF_AVAILABLE:
71 | raise HfError(
72 | status_code=1,
73 | message="The required Hugging Face packages are not installed. "
74 | "Please install them by adding 'hf' to your extras:\n"
75 | "pip install flow-judge[hf]",
76 | )
77 |
78 | model_id = kwargs.pop("_model_id", self._DEFAULT_MODEL_ID)
79 |
80 | if model_id != self._DEFAULT_MODEL_ID:
81 | warnings.warn(
82 | f"The model '{model_id}' is not officially supported. "
83 | f"This library is designed for the '{self._DEFAULT_MODEL_ID}' model. "
84 | "Using other models may lead to unexpected behavior, and we do not handle "
85 | "GitHub issues for unsupported models. Proceed with caution.",
86 | UserWarning,
87 | stacklevel=2,
88 | )
89 |
90 | generation_params = GenerationParams(**(generation_params or {}))
91 |
92 | config = HfConfig(
93 | generation_params=generation_params, flash_attn=flash_attn, _model_id=model_id, **kwargs
94 | )
95 |
96 | super().__init__(model_id, "transformers", config.generation_params, **kwargs)
97 |
98 | try:
99 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
100 |
101 | logger.info(
102 | "Downloading the model from Hugging Face Hub using hf-transfer"
103 | "for faster downloads...",
104 | )
105 | snapshot_download(repo_id=model_id)
106 |
107 | model_kwargs = {
108 | "device_map": config.device_map,
109 | "torch_dtype": getattr(torch, config.torch_dtype),
110 | }
111 | if config.flash_attn:
112 | model_kwargs["attn_implementation"] = "flash_attention_2"
113 |
114 | # Include any additional kwargs that might be relevant for model initialization
115 | for key, value in config.kwargs.items():
116 | if (
117 | key not in model_kwargs
118 | and key in AutoModelForCausalLM.from_pretrained.__code__.co_varnames
119 | ):
120 | model_kwargs[key] = value
121 |
122 | self.model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
123 | self.tokenizer = AutoTokenizer.from_pretrained(model_id)
124 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
125 | self.generation_params = generation_params.model_dump()
126 | self.config = config
127 |
128 | if self.device == "cpu":
129 | logger.warning("Running Hf on CPU may result in longer inference times.")
130 |
131 | self.batch_size = 1 # Default to 1, will be updated in batch_generate
132 |
133 | except Exception as e:
134 | raise HfError(
135 | status_code=2,
136 | message=f"An error occurred while initializing the Hugging Face model: {str(e)}\n"
137 | "Please make sure you have installed all required dependencies by adding 'hf' "
138 | "to your extras:\npip install flow-judge[hf]",
139 | ) from e
140 |
141 | def _determine_batch_size(self, prompts: list[str]) -> int:
142 | """Determine an appropriate batch size based on available GPU memory and eval_inputs.
143 |
144 | This method attempts to find the largest batch size that can be processed without
145 | running out of GPU memory.
146 |
147 | :param prompts: List of input prompts to be processed.
148 | :return: The determined optimal batch size.
149 | """
150 | if self.device == "cpu":
151 | return 1 # Default to 1 for CPU
152 |
153 | batch_size = 1
154 | max_length = self.generation_params.get("max_model_len", 8192)
155 | max_new_tokens = self.generation_params.get("max_new_tokens", 1024)
156 |
157 | while True:
158 | try:
159 | # Prepare a batch of inputs using the longest input
160 | longest_input = max(prompts, key=lambda x: len(self.tokenizer.encode(x)))
161 | # Check if the longest input exceeds max_length
162 | input_length = len(self.tokenizer.encode(longest_input))
163 | if input_length > max_length:
164 | logger.warning(
165 | f"Input length {input_length} exceeds max_length {max_length}. "
166 | f"Truncated inputs can result in suboptimal performance."
167 | )
168 |
169 | inputs = [longest_input] * batch_size
170 | encoded_inputs = self.tokenizer(
171 | inputs,
172 | return_tensors="pt",
173 | padding=True,
174 | truncation=True,
175 | max_length=max_length,
176 | ).to(self.device)
177 |
178 | # Simulate generation
179 | with torch.no_grad():
180 | _ = self.model.generate(
181 | **encoded_inputs, max_new_tokens=max_new_tokens, do_sample=False
182 | )
183 |
184 | # If successful, double the batch size and try again
185 | batch_size *= 2
186 | torch.cuda.empty_cache()
187 | del encoded_inputs
188 | except RuntimeError as e:
189 | if "out of memory" in str(e).lower():
190 | optimal_batch_size = max(1, batch_size // 2)
191 | logger.info(f"Automatically determined batch size: {optimal_batch_size}")
192 | torch.cuda.empty_cache()
193 | return optimal_batch_size
194 | else:
195 | raise
196 |
197 | def _prepare_generation_kwargs(self, **kwargs: Any) -> dict[str, Any]:
198 | """Combines generation params, passed kwargs, and relevant config kwargs.
199 |
200 | :param kwargs: Additional keyword arguments for generation.
201 | :return: A dictionary of prepared generation kwargs.
202 | """
203 | generation_kwargs = {**self.generation_params, **kwargs}
204 | for key, value in self.config.kwargs.items():
205 | if key not in generation_kwargs and key in self.model.generate.__code__.co_varnames:
206 | generation_kwargs[key] = value
207 | return generation_kwargs
208 |
209 | def _generate(self, prompt: str) -> str:
210 | """Generate a response using the FlowJudge Hugging Face Transformers model."""
211 | chat_prompt = self.tokenizer.apply_chat_template(
212 | [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True
213 | )
214 | inputs = self.tokenizer(chat_prompt, return_tensors="pt").to(self.device)
215 | input_length = inputs.input_ids.shape[1]
216 |
217 | generation_kwargs = self._prepare_generation_kwargs()
218 | outputs = self.model.generate(**inputs, **generation_kwargs)
219 |
220 | generated_text = self.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
221 | return generated_text.strip()
222 |
223 | def _batch_generate(
224 | self, prompts: list[str], use_tqdm: bool = True, **kwargs: Any
225 | ) -> list[str]:
226 | """Generate responses for multiple prompts using batching."""
227 | all_results = []
228 |
229 | self.batch_size = self._determine_batch_size(prompts)
230 |
231 | batches = [
232 | prompts[i : i + self.batch_size] for i in range(0, len(prompts), self.batch_size)
233 | ]
234 |
235 | generation_kwargs = self._prepare_generation_kwargs(**kwargs)
236 |
237 | for batch in tqdm(batches, disable=not use_tqdm, desc="Processing batches"):
238 | chat_prompts = [
239 | self.tokenizer.apply_chat_template(
240 | [{"role": "user", "content": prompt}],
241 | tokenize=False,
242 | add_generation_prompt=True,
243 | )
244 | for prompt in batch
245 | ]
246 |
247 | inputs = self.tokenizer(
248 | chat_prompts, return_tensors="pt", padding=True, truncation=True
249 | ).to(self.device)
250 | input_tok_lens = [len(input) for input in inputs["input_ids"]]
251 |
252 | outputs = self.model.generate(**inputs, **generation_kwargs)
253 |
254 | batch_results = []
255 | for output, input_tok_len in zip(outputs, input_tok_lens, strict=True):
256 | result = self.tokenizer.decode(output[input_tok_len:], skip_special_tokens=False)
257 | batch_results.append(result.strip())
258 |
259 | all_results.extend(batch_results)
260 |
261 | return all_results
262 |
263 |
264 | class HfError(Exception):
265 | """Custom exception for Hugging Face-related errors."""
266 |
267 | def __init__(self, status_code: int, message: str):
268 | """Initialize an HfError with a status code and message."""
269 | self.status_code = status_code
270 | self.message = message
271 | super().__init__(self.message)
272 |
--------------------------------------------------------------------------------
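A usage sketch for the Hugging Face backend; it assumes the optional `hf` extras are installed and a GPU is available, and the prompt string is a placeholder:

```python
from flow_judge.models.huggingface import Hf, HfError

try:
    # Downloads flowaicom/Flow-Judge-v0.1 on first use (via hf-transfer).
    # flash_attn can be disabled on hardware without FlashAttention 2 support.
    model = Hf(
        generation_params={"temperature": 0.1, "max_new_tokens": 1000},
        flash_attn=False,
    )
    # _generate/_batch_generate are the internal hooks exercised by FlowJudge;
    # called here only to illustrate the expected input and output types.
    print(model._generate("Evaluate the following response ..."))
except HfError as e:
    print(f"Hugging Face backend unavailable ({e.status_code}): {e.message}")
```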
/flow_judge/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flowaicom/flow-judge/e56d199db4d79f184dac1e9ab2da83992acda14d/flow_judge/utils/__init__.py
--------------------------------------------------------------------------------
/flow_judge/utils/prompt_formatter.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from flow_judge.metrics.metric import RubricItem
4 |
5 | USER_PROMPT_TEMPLATE = """# GOAL
6 | Your job is to evaluate a task carried out by an AI system powered by a large \
7 | language model.
8 |
9 | You will be provided with the inputs and output of the task, as well as the evaluation criteria \
10 | and scoring rubric. Your task is to evaluate the output of the AI system based on the evaluation \
11 | criteria and scoring rubric provided.
12 |
13 | # INPUT
14 | Below are the inputs required for performing the task:
15 | <inputs>
16 | {INPUTS}
17 | </inputs>
18 |
19 | # OUTPUT
20 | Below is the output of the task:
21 | <output>
22 | {OUTPUT}
23 | </output>
24 |
25 | # EVALUATION CRITERIA AND SCORING RUBRIC
26 | Here are the evaluation criteria and the rubric that you need to use for evaluating the task:
27 | <evaluation_criteria>
28 | {EVALUATION_CRITERIA}
29 | </evaluation_criteria>
30 |
31 | <scoring_rubric>
32 | {RUBRIC}
33 | </scoring_rubric>
34 |
35 | # INSTRUCTIONS FOR THE EVALUATION
36 | 1. Understand the task and criteria: Familiarize yourself with the task to be evaluated. \
37 | Review the evaluation criteria and scoring rubric to understand the different levels of \
38 | performance and the descriptions for each score.
39 | 2. Review the inputs and output: Look at the inputs provided for the task. Examine the output \
40 | generated from completing the task.
41 | 3. Compare output to score descriptions: Compare the output against the criteria and score \
42 | descriptions in the scoring rubric. For each criterion, decide which description best matches the \
43 | output.
44 | 4. After comparing the output to the score descriptions, pay attention to the small details that \
45 | might impact the final score that you assign. Sometimes a small difference can dictate the final \
46 | score.
47 | 5. Write verbal feedback justifying your evaluation that includes a detailed rationale, referring \
48 | to specific aspects of the output and comparing them to the rubric.
49 | 6. Assign a final score based on the scoring rubric.
50 |
51 | ## FORMAT FOR THE EVALUATION
52 | - Write the verbal feedback inside <feedback> tags without any additional surrounding text.
53 | - Write the numeric score inside <score> tags, without any additional surrounding text and always \
54 | after the feedback.
55 |
56 | Please accurately evaluate the task. Strictly adhere to the evaluation criteria and rubric."""
57 |
58 |
59 | USER_PROMPT_NO_INPUTS_TEMPLATE = """# GOAL
60 | Your job is to evaluate a task carried out by an AI system powered by a large language model.
61 |
62 | You will be provided the output of the task, as well as the evaluation criteria \
63 | and scoring rubric. Your task is to evaluate the output of the AI system based on the evaluation \
64 | criteria and scoring rubric provided.
65 |
66 | # OUTPUT
67 | Below is the output of the task:
68 | <output>
69 | {OUTPUT}
70 | </output>
71 |
72 | # EVALUATION CRITERIA AND SCORING RUBRIC
73 | Here are the evaluation criteria and the rubric that you need to use for evaluating the task:
74 | <evaluation_criteria>
75 | {EVALUATION_CRITERIA}
76 | </evaluation_criteria>
77 |
78 | <scoring_rubric>
79 | {RUBRIC}
80 | </scoring_rubric>
81 |
82 | # INSTRUCTIONS FOR THE EVALUATION
83 | 1. Understand the task and criteria: Familiarize yourself with the task to be evaluated. \
84 | Review the evaluation criteria and scoring rubric to understand the different levels of \
85 | performance and the descriptions for each score.
86 | 2. Review the output: Examine the output generated from completing the task.
87 | 3. Compare output to score descriptions: Compare the output against the criteria and score \
88 | descriptions in the scoring rubric. For each criterion, decide which description best matches the \
89 | output.
90 | 4. After comparing the output to the score descriptions, pay attention to the small details that \
91 | might impact the final score that you assign. Sometimes a small difference can dictate the final \
92 | score.
93 | 5. Write verbal feedback justifying your evaluation that includes a detailed rationale, referring \
94 | to specific aspects of the output and comparing them to the rubric.
95 | 6. Assign a final score based on the scoring rubric.
96 |
97 | ## FORMAT FOR THE EVALUATION
98 | - Write the verbal feedback inside <feedback> tags without any additional surrounding text.
99 | - Write the numeric score inside <score> tags, without any additional surrounding text and always \
100 | after the feedback.
101 |
102 | Please accurately evaluate the task. Strictly adhere to the evaluation criteria and rubric."""
103 |
104 |
105 | def format_vars(variables: list[dict[str, str]]) -> str:
106 | """Format variables for the prompt."""
107 | var_strs = []
108 | for var in variables:
109 | for key, value in var.items():
110 | var_tag = key.lower().replace(" ", "_")
111 |             var_strs.append(f"<{var_tag}>\n{value}\n</{var_tag}>")
112 | return "\n".join(var_strs)
113 |
114 |
115 | def format_rubric(rubric: list[RubricItem]) -> str:
116 | """Format the rubric for the prompt."""
117 | rubric_strs = []
118 |
119 | # Sort rubric items by score, lowest to highest
120 | sorted_rubric = sorted(rubric, key=lambda x: x.score)
121 |
122 | for item in sorted_rubric:
123 | rubric_strs.append(f"- Score {item.score}: {item.description}")
124 | return "\n".join(rubric_strs)
125 |
126 |
127 | def format_user_prompt(prompt_variables: dict[str, Any]) -> str:
128 | """Format the user prompt based on provided variables."""
129 | if prompt_variables["INPUTS"]:
130 | return USER_PROMPT_TEMPLATE.format(**prompt_variables)
131 | else:
132 | return USER_PROMPT_NO_INPUTS_TEMPLATE.format(**prompt_variables)
133 |
--------------------------------------------------------------------------------
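A small sketch of how these helpers compose into a full user prompt; the variable names and values are illustrative, and `RubricItem` is constructed as in the e2e tests:

```python
from flow_judge.metrics.metric import RubricItem
from flow_judge.utils.prompt_formatter import format_rubric, format_user_prompt, format_vars

inputs = [{"query": "What is the capital of France?"}]
output = {"response": "Paris is the capital of France."}
rubric = [
    RubricItem(score=1, description="The response is incorrect."),
    RubricItem(score=5, description="The response is fully correct."),
]

# format_vars wraps each key/value pair in <key>...</key> tags; format_rubric renders
# the items lowest score first; format_user_prompt picks the template with or without
# the INPUTS section depending on whether INPUTS is non-empty.
prompt = format_user_prompt(
    {
        "INPUTS": format_vars(inputs),
        "OUTPUT": format_vars([output]),
        "EVALUATION_CRITERIA": "Is the response factually correct?",
        "RUBRIC": format_rubric(rubric),
    }
)
print(prompt)
```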
/flow_judge/utils/result_writer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import re
4 | from datetime import datetime, timezone
5 | from enum import Enum
6 | from pathlib import Path
7 | from typing import Any
8 |
9 | from pydantic import BaseModel
10 |
11 | import flow_judge
12 | from flow_judge.eval_data_types import EvalInput, EvalOutput
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def write_results_to_disk(
18 | eval_inputs: list[EvalInput],
19 | eval_outputs: list[EvalOutput],
20 | model_metadata: dict[str, Any],
21 | metric_name: str,
22 | output_dir: str | Path,
23 | append: bool = False,
24 | ) -> None:
25 | """Write evaluation results, inputs, and metadata to separate JSONL files.
26 |
27 | This function processes evaluation data and writes it to disk in a structured format.
28 | It creates separate files for metadata and results, organizing them in directories
29 | based on the metric name and model ID.
30 |
31 | Args:
32 | eval_inputs: List of evaluation inputs.
33 | eval_outputs: List of evaluation outputs.
34 | model_metadata: Dictionary containing model metadata.
35 | metric_name: Name of the metric being evaluated.
36 | output_dir: Directory to write output files.
37 | append: If True, append results to existing file. If False, overwrite. Default is False.
38 |
39 | Raises:
40 | ValueError: If inputs are invalid, empty, or lists have different lengths.
41 | KeyError: If required keys are missing from model_metadata.
42 | OSError: If there are file system related errors during writing.
43 |
44 | Note:
45 | - Ensures eval_inputs and eval_outputs have the same length.
46 | - Creates necessary directories if they don't exist.
47 | - Handles special characters in metric_name and model_id for file naming.
48 |         - When append is False, overwrites existing files with the same name without warning.
49 | """
50 | _validate_inputs(eval_inputs, eval_outputs, model_metadata, metric_name)
51 |
52 | fmt_metric_name = _format_name(metric_name)
53 | fmt_model_id = _format_name(model_metadata["model_id"])
54 | timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S.%f")[:-3]
55 | base_filename = f"{fmt_metric_name}_{fmt_model_id}_{model_metadata['model_type']}_{timestamp}"
56 | paths = _prepare_file_paths(output_dir, fmt_metric_name, fmt_model_id, base_filename)
57 | metadata = _prepare_metadata(model_metadata, timestamp)
58 |
59 | try:
60 | _write_json_file(paths["metadata"], metadata)
61 |
62 | mode = "a" if append else "w"
63 | with paths["results"].open(mode, encoding="utf-8") as f:
64 | for eval_input, eval_output in zip(eval_inputs, eval_outputs, strict=True):
65 | result = {
66 | "sample": eval_input.model_dump(),
67 | "feedback": eval_output.feedback,
68 | "score": eval_output.score,
69 | }
70 | f.write(json.dumps(result, ensure_ascii=False) + "\n")
71 |
72 | logger.info(f"Results {'appended to' if append else 'saved to'} {paths['results']}")
73 | except OSError as e:
74 | logger.error(f"Error writing files: {e}")
75 | raise
76 |
77 |
78 | def _validate_inputs(
79 | eval_inputs: list[EvalInput],
80 | eval_outputs: list[EvalOutput],
81 | model_metadata: dict[str, Any],
82 | metric_name: str,
83 | ) -> None:
84 | """Validate input parameters for the write_results_to_disk function.
85 |
86 | Args:
87 | eval_inputs: List of evaluation inputs.
88 | eval_outputs: List of evaluation outputs.
89 | model_metadata: Dictionary containing model metadata.
90 | metric_name: Name of the metric being evaluated.
91 |
92 | Raises:
93 | ValueError: If eval_inputs or eval_outputs are empty, have different lengths,
94 | or if metric_name is empty or only whitespace.
95 | KeyError: If required keys ('model_id', 'model_type') are missing from
96 | model_metadata.
97 |
98 | Note:
99 | This function does not validate the content of eval_inputs or eval_outputs,
100 | only their presence and length.
101 | """
102 | if not eval_inputs or not eval_outputs:
103 | raise ValueError("eval_inputs and eval_outputs cannot be empty")
104 | if len(eval_inputs) != len(eval_outputs):
105 | raise ValueError("eval_inputs and eval_outputs must have the same length")
106 | if not metric_name or not metric_name.strip():
107 | raise ValueError("metric_name cannot be empty or only whitespace")
108 | required_keys = {"model_id", "model_type"}
109 | missing_keys = required_keys - set(model_metadata.keys())
110 | if missing_keys:
111 | raise KeyError(f"model_metadata missing required keys: {missing_keys}")
112 |
113 |
114 | def _format_name(name: str) -> str:
115 | """Format a name for use in file paths by removing special characters.
116 |
117 | Args:
118 | name: The name to format.
119 |
120 | Returns:
121 | A formatted string safe for use in file paths.
122 |
123 | Note:
124 | This function replaces spaces with underscores, removes non-alphanumeric
125 | characters (except underscore and hyphen), and replaces non-ASCII
126 | characters with underscores.
127 | """
128 | # Replace spaces with underscores
129 | name = name.replace(" ", "_")
130 | # Remove any character that is not alphanumeric, underscore, or hyphen
131 | name = re.sub(r"[^\w\-]", "", name)
132 | # Replace any non-ASCII character with underscore
133 | name = re.sub(r"[^\x00-\x7F]", "_", name)
134 | return name
135 |
136 |
137 | def _prepare_file_paths(
138 | output_dir: str | Path,
139 | fmt_metric_name: str,
140 | fmt_model_id: str,
141 | base_filename: str,
142 | ) -> dict[str, Path]:
143 | """Prepare file paths for metadata and results files.
144 |
145 | Args:
146 | output_dir: Base output directory.
147 | fmt_metric_name: Formatted metric name.
148 | fmt_model_id: Formatted model ID.
149 | base_filename: Base filename for output files.
150 |
151 | Returns:
152 | A dictionary containing paths for metadata and results files.
153 |
154 | Note:
155 | This function creates the necessary directories if they don't exist.
156 | It does not check if the resulting file paths already exist.
157 | """
158 | output_dir = Path(output_dir)
159 | metric_folder = output_dir / fmt_metric_name
160 | metadata_folder = metric_folder / f"metadata_{fmt_metric_name}_{fmt_model_id}"
161 | metadata_folder.mkdir(parents=True, exist_ok=True)
162 |
163 | return {
164 | "metadata": metadata_folder / f"metadata_{base_filename}.json",
165 | "results": metric_folder / f"results_{base_filename}.jsonl",
166 | }
167 |
168 |
169 | def _prepare_metadata(model_metadata: dict[str, Any], timestamp: str) -> dict[str, Any]:
170 | """Prepare metadata dictionary for writing.
171 |
172 | Args:
173 | model_metadata: Dictionary containing model metadata.
174 | timestamp: Timestamp string.
175 |
176 | Returns:
177 | A dictionary containing prepared metadata.
178 |
179 | Note:
180 | - Adds 'library_version' and 'timestamp' to the metadata.
181 | - Converts Pydantic BaseModel instances to dictionaries.
182 | - Converts Enum instances to their values.
183 | - Does not deep copy the input model_metadata.
184 | """
185 | metadata = {
186 | "library_version": f"{flow_judge.__version__}",
187 | "timestamp": timestamp,
188 | **model_metadata,
189 | }
190 | for key, item in metadata.items():
191 | if isinstance(item, BaseModel):
192 | metadata[key] = item.model_dump()
193 | elif isinstance(item, Enum):
194 | metadata[key] = item.value
195 | return metadata
196 |
197 |
198 | def _write_json_file(path: Path, data: dict[str, Any]) -> None:
199 | """Write data to a JSON file.
200 |
201 | Args:
202 | path: Path to the output file.
203 | data: Data to write to the file.
204 |
205 | Raises:
206 | OSError: If there's an error writing to the file.
207 |
208 | Note:
209 | - Uses UTF-8 encoding.
210 | - Overwrites the file if it already exists.
211 | - Ensures non-ASCII characters are preserved in the output.
212 | """
213 | with path.open("w", encoding="utf-8") as f:
214 | json.dump(data, f, ensure_ascii=False, indent=2)
215 |
216 |
217 | def _write_results_file(
218 | path: Path, eval_inputs: list[EvalInput], eval_outputs: list[EvalOutput], append: bool = False
219 | ) -> None:
220 | """Write results to a JSONL file.
221 |
222 | Args:
223 | path: Path to the output file.
224 | eval_inputs: List of evaluation inputs.
225 | eval_outputs: List of evaluation outputs.
226 | append: If True, append to the file. If False, overwrite. Default is False.
227 |
228 | Raises:
229 | OSError: If there's an error writing to the file.
230 | ValueError: If eval_inputs and eval_outputs have different lengths.
231 |
232 | Note:
233 | - Uses UTF-8 encoding.
234 | - Appends to the file if append is True, otherwise overwrites.
235 | - Each line in the file is a JSON object representing one result.
236 | - Ensures non-ASCII characters are preserved in the output.
237 | """
238 | if len(eval_inputs) != len(eval_outputs):
239 | raise ValueError("eval_inputs and eval_outputs must have the same length")
240 |
241 | mode = "a" if append else "w"
242 | with path.open(mode, encoding="utf-8") as f:
243 | for input_data, eval_output in zip(eval_inputs, eval_outputs, strict=True):
244 | result = {
245 | "sample": input_data.model_dump(),
246 | "feedback": eval_output.feedback,
247 | "score": eval_output.score,
248 | }
249 | f.write(json.dumps(result, ensure_ascii=False) + "\n")
250 |
--------------------------------------------------------------------------------
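A sketch of the writer's on-disk layout. The `EvalInput`/`EvalOutput` constructor calls are assumptions based on how their fields are used in this module and in validators.py; the exact schemas live in eval_data_types.py, which is not reproduced here:

```python
from flow_judge.eval_data_types import EvalInput, EvalOutput
from flow_judge.utils.result_writer import write_results_to_disk

# Assumed field names: inputs/output on EvalInput, feedback/score on EvalOutput.
eval_inputs = [EvalInput(inputs=[{"query": "What is 2 + 2?"}], output={"response": "4"})]
eval_outputs = [EvalOutput(feedback="Correct and concise.", score=5)]

write_results_to_disk(
    eval_inputs=eval_inputs,
    eval_outputs=eval_outputs,
    model_metadata={"model_id": "flowaicom/Flow-Judge-v0.1", "model_type": "transformers"},
    metric_name="correctness",
    output_dir="output",
)
# Produces, under output/correctness/:
#   metadata_correctness_<model>/metadata_<metric>_<model>_<type>_<timestamp>.json
#   results_<metric>_<model>_<type>_<timestamp>.jsonl  (one JSON object per line)
```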
/flow_judge/utils/validators.py:
--------------------------------------------------------------------------------
1 | from flow_judge.eval_data_types import EvalInput
2 | from flow_judge.metrics.metric import CustomMetric, Metric
3 |
4 |
5 | def validate_eval_input(eval_input: EvalInput, metric: Metric | CustomMetric):
6 | """Validate that the EvalInput matches the required inputs and output in the metric."""
7 | input_keys = {list(input_dict.keys())[0] for input_dict in eval_input.inputs}
8 | output_key = list(eval_input.output.keys())[0]
9 | required_inputs = set(metric.required_inputs)
10 |
11 | if input_keys != required_inputs:
12 | raise ValueError(f"Input keys {input_keys} do not match required inputs {required_inputs}")
13 |
14 | if metric.required_output:
15 | if not hasattr(eval_input, "output"):
16 | raise ValueError(f"Required output '{metric.required_output}' is missing")
17 | elif metric.required_output != output_key:
18 | raise ValueError(
19 | f"""Output key '{output_key}' does not match \
20 | required output '{metric.required_output}'"""
21 | )
22 |
--------------------------------------------------------------------------------
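A short sketch of the validator's behaviour; the `EvalInput` construction is assumed as in the writer example above, while `CustomMetric` follows its use in the e2e tests:

```python
from flow_judge.eval_data_types import EvalInput
from flow_judge.metrics import CustomMetric, RubricItem
from flow_judge.utils.validators import validate_eval_input

metric = CustomMetric(
    name="correctness",
    criteria="Is the response factually correct?",
    rubric=[
        RubricItem(score=1, description="Incorrect."),
        RubricItem(score=5, description="Fully correct."),
    ],
    required_inputs=["query"],
    required_output="response",
)

# Passes silently: input keys match required_inputs and the output key matches required_output.
validate_eval_input(EvalInput(inputs=[{"query": "What is 2 + 2?"}], output={"response": "4"}), metric)

# Raises ValueError: output key 'answer' does not match the required output 'response'.
validate_eval_input(EvalInput(inputs=[{"query": "What is 2 + 2?"}], output={"answer": "4"}), metric)
```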
/img/flow_judge_banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flowaicom/flow-judge/e56d199db4d79f184dac1e9ab2da83992acda14d/img/flow_judge_banner.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "flow-judge"
7 | version = "0.1.2"
8 | description = "A small yet powerful LM Judge"
9 | readme = "README.md"
10 | authors = [
11 | {name = "Bernardo Garcia", email = "bernardo@flow-ai.com"},
12 | {name = "Karolus Sariola", email = "karolus@flow-ai.com"},
13 | {name = "Minaam Shahid", email = "minaam@flow-ai.com"},
14 | {name = "Tiina Vaahtio", email = "tiina@flow-ai.com"},
15 | {name = "Alex Wegrzyn", email = "alex.wegrzyn@flow-ai.com"},
16 | ]
17 | license = {file = "LICENSE"}
18 | classifiers = [
19 | "Programming Language :: Python :: 3",
20 | "License :: OSI Approved :: Apache Software License",
21 | "Operating System :: OS Independent",
22 | ]
23 | keywords = ["LM-judge", "evaluation", "LLMs", "AI", "benchmarking"]
24 | requires-python = ">=3.10"
25 | dependencies = [
26 | "pydantic>=2.9.1",
27 | "requests>=2.32.3",
28 | "hf-transfer>=0.1.1",
29 | "ipykernel>=6.29.0",
30 | "ipywidgets>=8.1.0",
31 | "tqdm>=4.66.1",
32 | "structlog",
33 | ]
34 |
35 | [project.optional-dependencies]
36 | dev = [
37 | "pytest",
38 | "pre-commit",
39 | "ruff",
40 | "black",
41 | "isort",
42 | "pytest-cov",
43 | "codecov",
44 | "mypy>=1.11.2",
45 | "types-requests",
46 | "types-tqdm",
47 | "memray>=1.14.0",
48 | "pytest-memray>=1.7.0",
49 | "pytest-asyncio>=0.23.6, <0.24.0",
50 | "hypothesis"
51 | ]
52 | integrations-test = [
53 | "llama-index",
54 | "llama-index-embeddings-huggingface"
55 | ]
56 | hf = [
57 | "transformers>=4.45.0",
58 | "torch>=2.3.0",
59 | "bitsandbytes>=0.41.0,<=0.42.0",
60 | "accelerate>=0.34.2",
61 | ]
62 | vllm = ["vllm==0.6.2"]
63 | llamafile = [
64 | "torch>=2.3.0",
65 | "openai>=1.51.0",
66 | ]
67 | baseten = [
68 | "truss>=0.9.44",
69 | "openai>=1.51.0",
70 | "aiohttp>=3.10.5"
71 | ]
72 |
73 | [project.urls]
74 | Homepage = "https://github.com/flowaicom/flow-judge"
75 |
76 | [tool.setuptools]
77 | packages = ["flow_judge", "flow_judge.integrations", "flow_judge.metrics", "flow_judge.models", "flow_judge.utils"]
78 |
79 | [tool.setuptools.package-data]
80 | "flow_judge.models" = ["adapters/baseten/**/*.yaml"]
81 |
82 | [tool.setuptools_scm]
83 | version_scheme = "python-simplified-semver"
84 |
85 | [tool.ruff]
86 | line-length = 100
87 | include = ["flow_judge/**/*.py", "tests/**/*.py", "setup.py"]
88 |
89 | [tool.ruff.lint]
90 | select = ["E", "F", "I", "N", "W", "B", "C", "D"]
91 | ignore = ["D100", "D104"]
92 |
93 | [tool.ruff.lint.per-file-ignores]
94 | "__init__.py" = ["F401"]
95 |
96 | [tool.ruff.lint.pydocstyle]
97 | convention = "google"
98 |
99 | [tool.black]
100 | line-length = 100
101 | target-version = ['py311']
102 | include = '(flow_judge/.*\.py$|tests/.*\.py$|setup\.py)'
103 |
104 | [tool.isort]
105 | profile = "black"
106 | line_length = 100
107 | src_paths = ["flow_judge", "tests"]
108 |
109 | [tool.mypy]
110 | warn_unused_configs = true
111 | warn_redundant_casts = true
112 | warn_unused_ignores = true
113 | strict_equality = true
114 | check_untyped_defs = true
115 | disallow_any_generics = true
116 | disallow_untyped_defs = false
117 | disallow_incomplete_defs = false
118 |
119 | [tool.bdist_wheel]
120 | universal = true
121 |
122 | [tool.pytest.ini_options]
123 | asyncio_mode = "auto"
124 | markers = [
125 | "asyncio: mark test as an asyncio coroutine",
126 | "memray: marks tests to be run with memray profiling",
127 | "e2e: marks end-to-end tests",
128 | ]
129 |
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Tests for Flow Judge
3 |
4 | This directory contains the test suite for the Flow Judge project.
5 |
6 | ## Test Coverage
7 |
8 | Below is the current test coverage visualization for the Flow Judge project:
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | ## Running Tests
17 |
18 | To run the entire test suite:
19 | ```sh
20 | pytest
21 | ```
22 | To run a specific test file:
23 | ```sh
24 | pytest tests/unit/test_flow_judge.py
25 | ```
26 | To run tests with coverage report:
27 | ```sh
28 | pytest --cov=flow_judge --cov-report=term-missing
29 | ```
30 |
31 | ## Contributing
32 |
33 | When adding new features or modifying existing ones, please make sure to add or update the corresponding tests. This helps maintain the project's reliability and makes it easier to catch potential issues early.
34 |
35 | ## Continuous Integration
36 |
37 | Our CI pipeline automatically runs these tests on every pull request and push to the main branch. You can check the status of the latest runs in the GitHub Actions tab of the repository.
38 |
--------------------------------------------------------------------------------
/tests/e2e-cloud-gpu/models/adapters/test_baseten_e2e.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import statistics
4 | from collections import Counter
5 | from pathlib import Path
6 | from typing import Any
7 |
8 | import pytest
9 | from llama_index.core import VectorStoreIndex
10 | from llama_index.core.evaluation import BatchEvalRunner
11 | from llama_index.core.llama_dataset import download_llama_dataset
12 | from pydantic import BaseModel
13 |
14 | from flow_judge import Baseten
15 | from flow_judge.integrations.llama_index import LlamaIndexFlowJudge
16 | from flow_judge.metrics import CustomMetric, RubricItem
17 |
18 | pytest_plugins = ("pytest_asyncio",)
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | class TestConfig(BaseModel):
24 | """Configuration for Baseten e2e tests."""
25 |
26 | api_key: str = os.getenv("BASETEN_API_KEY")
27 | model_id: str = os.getenv("BASETEN_MODEL_ID")
28 | webhook_url: str = os.getenv("BASETEN_WEBHOOK_URL")
29 | webhook_secret: str = os.getenv("BASETEN_WEBHOOK_SECRET")
30 |
31 |
32 | @pytest.fixture(scope="module")
33 | def test_config() -> TestConfig:
34 | """Fixture to load test configuration from environment variables.
35 |
36 | Returns:
37 | TestConfig: Configuration object for Baseten e2e tests.
38 |
39 | Raises:
40 | ValueError: If any required environment variable is missing.
41 | """
42 | try:
43 | return TestConfig()
44 | except ValueError as e:
45 | pytest.fail(f"Missing required environment variable: {str(e)}")
46 |
47 |
48 | @pytest.fixture(scope="module")
49 | def test_cache_dir() -> Path:
50 | """Create a temporary directory for test cache.
51 |
52 | Returns:
53 | Path: Path object pointing to the temporary directory.
54 | """
55 | import tempfile
56 |
57 | with tempfile.TemporaryDirectory(prefix="flow-judge-baseten-test-") as tmpdir:
58 | temp_path = Path(tmpdir)
59 | logger.info(f"Created temporary test cache directory: {temp_path}")
60 | yield temp_path
61 | logger.info(f"Cleaned up temporary test cache directory: {temp_path}")
62 |
63 |
64 | @pytest.fixture
65 | def correctness_metric() -> CustomMetric:
66 | """Creates a CustomMetric for evaluating the correctness of generated answers.
67 |
68 | Returns:
69 | CustomMetric: A metric object with evaluation criteria and rubric for
70 | assessing answer correctness.
71 | """
72 | evaluation_criteria = "Is the generated answer relevant to the user query and reference answer?"
73 | rubric = [
74 | RubricItem(
75 | score=1,
76 | description="The generated answer is not relevant to the user query "
77 | "and reference answer.",
78 | ),
79 | RubricItem(
80 | score=2,
81 | description="The generated answer is according to reference answer but"
82 | " not relevant to user query.",
83 | ),
84 | RubricItem(
85 | score=3,
86 | description="The generated answer is relevant to the user query and "
87 | "reference answer but contains mistakes.",
88 | ),
89 | RubricItem(
90 | score=4,
91 | description="The generated answer is relevant to the user query and "
92 | "has the exact same metrics as the reference answer, but it is not as concise.",
93 | ),
94 | RubricItem(
95 | score=5,
96 | description="The generated answer is relevant to the user query and "
97 | "fully correct according to the reference answer.",
98 | ),
99 | ]
100 | return CustomMetric(
101 | name="correctness",
102 | criteria=evaluation_criteria,
103 | rubric=rubric,
104 | required_inputs=["query", "reference"],
105 | required_output="response",
106 | )
107 |
108 |
109 | def get_scores_distribution(scores: list[float]) -> dict[float, str]:
110 | """Calculates the distribution of scores as percentages.
111 |
112 | Args:
113 | scores (List[float]): A list of numerical scores.
114 |
115 | Returns:
116 | Dict[float, str]: A dictionary mapping scores to their percentage occurrence.
117 | """
118 | score_counts = Counter(scores)
119 | total_scores = len(scores)
120 | return {score: f"{(count / total_scores) * 100:.1f}%" for score, count in score_counts.items()}
121 |
122 |
123 | def compare_distributions(
124 | actual: dict[float, str],
125 | expected: dict[float, str],
126 | tolerance: float = 20.0,
127 | ) -> bool:
128 | """Compares two score distributions within a given tolerance.
129 |
130 | Args:
131 | actual (Dict[float, str]): The actual score distribution.
132 | expected (Dict[float, str]): The expected score distribution.
133 | tolerance (float, optional): The maximum allowed difference between
134 | percentages. Defaults to 20.0.
135 |
136 | Returns:
137 | bool: True if the distributions are within the tolerance, False otherwise.
138 | """
139 | for score in set(actual.keys()) | set(expected.keys()):
140 | actual_pct = float(actual.get(score, "0%").rstrip("%"))
141 | expected_pct = float(expected.get(score, "0%").rstrip("%"))
142 | if abs(actual_pct - expected_pct) > tolerance:
143 | return False
144 | return True
145 |
146 |
147 | async def batch_eval_runner(
148 | evaluators: dict[str, LlamaIndexFlowJudge],
149 | query_engine: Any,
150 | questions: list[str],
151 | reference: list[str] | None = None,
152 | num_workers: int = 2,
153 | ) -> dict[str, list[Any]]:
154 | """Runs batch evaluation using the provided evaluators and query engine.
155 |
156 | Args:
157 | evaluators (Dict[str, LlamaIndexFlowJudge]): Dictionary of evaluators.
158 | query_engine (Any): The query engine to use for generating responses.
159 | questions (List[str]): List of questions to evaluate.
160 | reference (Optional[List[str]], optional): List of reference answers.
161 | Defaults to None.
162 | num_workers (int, optional): Number of workers for parallel processing.
163 | Defaults to 2.
164 |
165 | Returns:
166 | Dict[str, List[Any]]: Evaluation results for each evaluator.
167 | """
168 | batch_runner = BatchEvalRunner(evaluators, workers=num_workers, show_progress=True)
169 | return await batch_runner.aevaluate_queries(
170 | query_engine, queries=questions, reference=reference
171 | )
172 |
173 |
174 | @pytest.mark.asyncio
175 | async def test_baseten_correctness_evaluation(
176 | test_config: TestConfig,
177 | correctness_metric: CustomMetric,
178 | test_cache_dir: Path,
179 | ) -> None:
180 | """Tests the correctness evaluation of Baseten model using LlamaIndexFlowJudge.
181 |
182 | Args:
183 | test_config (TestConfig): Test configuration object.
184 | correctness_metric (CustomMetric): The metric used for evaluation.
185 | test_cache_dir (Path): Temporary directory for test cache.
186 |
187 | Raises:
188 | AssertionError: If the evaluation score is outside the expected range or
189 | feedback is missing.
190 | """
191 | os.environ["HF_HOME"] = str(test_cache_dir)
192 | model = Baseten(
193 | _model_id=test_config.model_id,
194 | exec_async=True,
195 | webhook_proxy_url=test_config.webhook_url,
196 | )
197 | flow_judge_evaluator = LlamaIndexFlowJudge(model=model, metric=correctness_metric)
198 |
199 | # Download and prepare the dataset
200 | rag_dataset, documents = download_llama_dataset(
201 | "MiniTruthfulQADataset", str(test_cache_dir / "mini_truthful_qa")
202 | )
203 |
204 | # Select a single example for evaluation
205 | example = rag_dataset.examples[0]
206 | query, reference = example.query, example.reference_answer
207 |
208 | # Generate response using Baseten model
209 | response = await model._async_generate(query)
210 |
211 | result = await flow_judge_evaluator.aevaluate(
212 | query=query, reference=reference, response=response
213 | )
214 |
215 | assert result is not None, "Evaluation result is None"
216 | assert 2 <= int(result.score) <= 5, f"Score {result.score} is out of expected range"
217 | assert result.feedback is not None, "Feedback is missing"
218 |
219 | logger.info(f"Evaluation score: {result.score}")
220 | logger.info(f"Evaluation feedback: {result.feedback}")
221 |
222 |
223 | @pytest.mark.asyncio
224 | async def test_baseten_batch_evaluation(
225 | test_config: TestConfig,
226 | correctness_metric: CustomMetric,
227 | test_cache_dir: Path,
228 | ) -> None:
229 | """Performs a batch evaluation of queries using Baseten model and analyzes results.
230 |
231 | Args:
232 | test_config (TestConfig): Test configuration object.
233 | correctness_metric (CustomMetric): The metric used for evaluation.
234 | test_cache_dir (Path): Temporary directory for test cache.
235 |
236 | Raises:
237 | AssertionError: If the evaluation results do not meet expected criteria.
238 | """
239 | os.environ["HF_HOME"] = str(test_cache_dir)
240 | model = Baseten(
241 | _model_id=test_config.model_id,
242 | exec_async=True,
243 | webhook_proxy_url=test_config.webhook_url,
244 | )
245 | logger.info("Starting test_baseten_batch_evaluation")
246 |
247 | flow_judge_correctness = LlamaIndexFlowJudge(model=model, metric=correctness_metric)
248 |
249 | # Download and prepare the dataset
250 | rag_dataset, documents = download_llama_dataset(
251 | "MiniTruthfulQADataset", str(test_cache_dir / "mini_truthful_qa")
252 | )
253 |
254 | # Create the index and query engine
255 | index = VectorStoreIndex.from_documents(documents=documents)
256 | query_engine = index.as_query_engine()
257 |
258 | # Prepare queries and references
259 | rag_subset = rag_dataset.examples[:10]
260 | queries = [example.query for example in rag_subset]
261 | references = [example.reference_answer for example in rag_subset]
262 |
263 | logger.info(f"Evaluating {len(queries)} queries")
264 |
265 | evaluators = {"correctness": flow_judge_correctness}
266 |
267 | eval_results = await batch_eval_runner(
268 | evaluators=evaluators,
269 | query_engine=query_engine,
270 | questions=queries,
271 | reference=references,
272 | )
273 |
274 | # Check results
275 | assert "correctness" in eval_results, "Correctness evaluator results missing"
276 | assert len(eval_results["correctness"]) == len(queries), "Incomplete evaluation results"
277 |
278 | for result in eval_results["correctness"]:
279 | assert result.score is not None, "Evaluation score is missing"
280 | assert result.feedback is not None, "Evaluation feedback is missing"
281 |
282 | # Calculate score distribution
283 | scores = [result.score for result in eval_results["correctness"]]
284 | actual_distribution = get_scores_distribution(scores)
285 | logger.info(f"Actual score distribution: {actual_distribution}")
286 |
287 | # Calculate average score
288 | average_score = statistics.mean(scores)
289 | logger.info(f"Average score: {average_score:.2f}")
290 |
291 | # Assert that the average score is within an acceptable range
292 | assert (
293 | 3.0 <= average_score <= 4.5
294 | ), f"Average score {average_score:.2f} is outside the expected range of 3.0 to 4.5"
295 |
296 | # Check that we have a variety of scores
297 | unique_scores = set(scores)
298 | assert (
299 | len(unique_scores) >= 3
300 | ), f"Expected at least 3 different score values, but got {len(unique_scores)}"
301 |
302 | logger.info("test_baseten_batch_evaluation completed successfully")
303 |
304 |
305 | if __name__ == "__main__":
306 | pytest.main([__file__, "-v", "-s"])
307 |
--------------------------------------------------------------------------------
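A minimal sketch of how the two distribution helpers defined in the file above compose, assuming both functions are in scope; the scores and the expected distribution are illustrative values, not fixture data from the test suite.

# Hypothetical usage of get_scores_distribution and compare_distributions.
scores = [3, 4, 4, 5, 3, 4, 5, 5, 4, 3]
actual = get_scores_distribution(scores)          # {3: "30.0%", 4: "40.0%", 5: "30.0%"}
expected = {3: "25.0%", 4: "50.0%", 5: "25.0%"}   # made-up target distribution
assert compare_distributions(actual, expected, tolerance=20.0)
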
/tests/e2e-local/models/test_llamafile_e2e.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import tempfile
4 | from pathlib import Path
5 |
6 | import pytest
7 |
8 | from flow_judge.models.llamafile import Llamafile
9 |
10 | # Set up logging with more detail
11 | logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
12 | logger = logging.getLogger(__name__)
13 |
14 | # Define MEMORY_LIMIT (10 GB in bytes)
15 | MEMORY_LIMIT = 10 * 1024 * 1024 * 1024 # 10 GB
16 |
17 | """
18 | End-to-end test suite for the Llamafile class.
19 |
20 | This test suite performs comprehensive tests on the Llamafile class,
21 | including initialization, file operations, server management, command-line
22 | argument passing, and text generation. These tests interact with the actual
23 | Llamafile system and may take longer to run compared to unit tests.
24 | """
25 |
26 |
27 | @pytest.fixture(scope="module")
28 | def test_cache_dir():
29 | """Create a temporary directory for test cache.
30 |
31 | This fixture creates a temporary directory that is guaranteed to be
32 | writable and cleaned up after the tests.
33 |
34 | :yield: Path object pointing to the temporary directory
35 | :rtype: pathlib.Path
36 | """
37 | with tempfile.TemporaryDirectory(prefix="flow-judge-test-") as tmpdir:
38 | temp_path = Path(tmpdir)
39 | logger.info(f"Created temporary test cache directory: {temp_path}")
40 | yield temp_path
41 | logger.info(f"Cleaned up temporary test cache directory: {temp_path}")
42 |
43 |
44 | @pytest.mark.memray(threshold=MEMORY_LIMIT)
45 | def test_llamafile_initialization(test_cache_dir):
46 | """Test Llamafile initialization with various parameters.
47 |
48 | This test verifies that a Llamafile instance is correctly initialized
49 | with custom configuration parameters, including port, context size,
50 | GPU layers, thread count, batch size, concurrent requests, and
51 | generation parameters.
52 |
53 | :param test_cache_dir: Path to the test cache directory
54 | :type test_cache_dir: pathlib.Path
55 | """
56 | logger.info("Starting Llamafile initialization test")
57 | llamafile = Llamafile(
58 | cache_dir=str(test_cache_dir),
59 | port=9000,
60 | context_size=4096,
61 | gpu_layers=20,
62 | thread_count=4,
63 | batch_size=16,
64 | max_concurrent_requests=2,
65 | generation_params={
66 | "temperature": 0.8,
67 | "top_p": 0.9,
68 | "max_new_tokens": 500,
69 | },
70 | )
71 |
72 | assert llamafile.config.model_id == "flowaicom/Flow-Judge-v0.1-Llamafile"
73 | assert llamafile.config.port == 9000
74 | assert llamafile.config.context_size == 4096
75 | assert llamafile.config.gpu_layers == 20
76 | assert llamafile.config.thread_count == 4
77 | assert llamafile.config.batch_size == 16
78 | assert llamafile.config.max_concurrent_requests == 2
79 | assert llamafile.config.generation_params.temperature == 0.8
80 | assert llamafile.config.generation_params.top_p == 0.9
81 | assert llamafile.config.generation_params.max_new_tokens == 500
82 | logger.info("Llamafile initialization test completed successfully")
83 |
84 |
85 | @pytest.mark.memray(threshold=MEMORY_LIMIT)
86 | def test_download_llamafile(test_cache_dir):
87 | """Test downloading of Llamafile.
88 |
89 | This test ensures that the Llamafile is correctly downloaded to the
90 | specified cache directory and that the downloaded file has the correct
91 | permissions (executable).
92 |
93 | :param test_cache_dir: Path to the test cache directory
94 | :type test_cache_dir: pathlib.Path
95 | """
96 | logger.info("Starting Llamafile download test")
97 | llamafile = Llamafile(cache_dir=str(test_cache_dir))
98 | llamafile_path = llamafile.download_llamafile()
99 |
100 | assert os.path.exists(llamafile_path), f"Llamafile not found at {llamafile_path}"
101 | assert os.access(llamafile_path, os.X_OK), f"Llamafile at {llamafile_path} is not executable"
102 | logger.info(f"Llamafile successfully downloaded to {llamafile_path}")
103 |
104 |
105 | @pytest.mark.memray(threshold=MEMORY_LIMIT)
106 | def test_build_llamafile_command(test_cache_dir):
107 | """Test building of Llamafile command.
108 |
109 | This test verifies that the Llamafile command is correctly constructed
110 | based on the provided configuration parameters. It checks for the
111 | presence of essential command-line arguments in the built command.
112 |
113 | :param test_cache_dir: Path to the test cache directory
114 | :type test_cache_dir: pathlib.Path
115 | """
116 | logger.info("Starting Llamafile command building test")
117 | llamafile = Llamafile(
118 | cache_dir=str(test_cache_dir),
119 | port=9000,
120 | context_size=4096,
121 | gpu_layers=20,
122 | thread_count=4,
123 | batch_size=16,
124 | max_concurrent_requests=2,
125 | generation_params={"temperature": 0.8, "max_new_tokens": 500},
126 | quantized_kv=True,
127 | flash_attn=True,
128 | )
129 |
130 | llamafile_path = llamafile.download_llamafile()
131 | command = llamafile._build_llamafile_command(llamafile_path)
132 |
133 | expected_args = [
134 | llamafile_path,
135 | "--port 9000",
136 | "-c 4096",
137 | "-ngl 20",
138 | "--threads 4",
139 | "-b 16",
140 | "--parallel 2",
141 | "--temp 0.8",
142 | "-n 500",
143 | "-ctk q4_0",
144 | "-ctv q4_0",
145 | "-fa",
146 | ]
147 | for arg in expected_args:
148 | assert arg in command, f"Expected argument '{arg}' not found in command"
149 | logger.info("Llamafile command successfully built and verified")
150 |
151 |
152 | @pytest.mark.memray(threshold=MEMORY_LIMIT)
153 | def test_start_stop_server(test_cache_dir):
154 | """Test starting and stopping of Llamafile server.
155 |
156 | This test verifies that the Llamafile server can be started and stopped
157 | correctly, and that the is_server_running method accurately reflects
158 | the server's state.
159 |
160 | :param test_cache_dir: Path to the test cache directory
161 | :type test_cache_dir: pathlib.Path
162 | """
163 | logger.info("Starting Llamafile server start/stop test")
164 | llamafile = Llamafile(cache_dir=str(test_cache_dir))
165 |
166 | try:
167 | llamafile.start_llamafile_server()
168 | assert llamafile.is_server_running(), "Server should be running but is not"
169 | logger.info("Llamafile server successfully started")
170 |
171 | llamafile.stop_llamafile_server()
172 | assert not llamafile.is_server_running(), "Server should not be running but is"
173 | logger.info("Llamafile server successfully stopped")
174 | except Exception as e:
175 | logger.error(f"Error during server start/stop test: {str(e)}")
176 | if llamafile.llamafile_process:
177 | stdout, stderr = llamafile.llamafile_process.communicate()
178 | logger.error(f"Process stdout: {stdout}")
179 | logger.error(f"Process stderr: {stderr}")
180 | raise
181 |
182 |
183 | @pytest.mark.memray(threshold=MEMORY_LIMIT)
184 | def test_generate(test_cache_dir):
185 | """Test text generation.
186 |
187 | This test verifies that the Llamafile can generate text responses
188 | to a given prompt. It checks that the generated response is a
189 | non-empty string.
190 |
191 | :param test_cache_dir: Path to the test cache directory
192 | :type test_cache_dir: pathlib.Path
193 | """
194 | logger.info("Starting text generation test")
195 | llamafile = Llamafile(cache_dir=str(test_cache_dir))
196 |
197 | with llamafile:
198 | try:
199 | response = llamafile._generate("Hello, world!")
200 | assert isinstance(response, str), f"Response should be a string, got {type(response)}"
201 | assert len(response) > 0, "Response should not be empty"
202 | logger.info(f"Generated response: {response}")
203 | logger.info("Text generation test completed successfully")
204 | except Exception as e:
205 | logger.error(f"Error during text generation: {str(e)}")
206 | raise
207 |
208 |
209 | @pytest.mark.memray(threshold=MEMORY_LIMIT)
210 | def test_batch_generate(test_cache_dir):
211 | """Test batch text generation.
212 |
213 | This test verifies that the Llamafile can generate text responses
214 | for multiple prompts in a batch. It checks that the number of
215 | responses matches the number of prompts and that each response
216 | is a non-empty string.
217 |
218 | :param test_cache_dir: Path to the test cache directory
219 | :type test_cache_dir: pathlib.Path
220 | """
221 | logger.info("Starting batch text generation test")
222 | llamafile = Llamafile(cache_dir=str(test_cache_dir))
223 |
224 | with llamafile:
225 | try:
226 | prompts = ["Hello, world!", "How are you?", "What's the weather like?"]
227 | responses = llamafile._batch_generate(prompts)
228 | assert len(responses) == len(
229 | prompts
230 | ), f"Expected {len(prompts)} responses, got {len(responses)}"
231 | for i, response in enumerate(responses):
232 | assert isinstance(
233 | response, str
234 | ), f"Response {i+1} should be a string, got {type(response)}"
235 | assert len(response) > 0, f"Response {i+1} should not be empty"
236 | logger.info(f"Response {i+1}: {response}")
237 | logger.info("Batch text generation test completed successfully")
238 | except Exception as e:
239 | logger.error(f"Error during batch text generation: {str(e)}")
240 | raise
241 |
--------------------------------------------------------------------------------
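The expected_args list in test_build_llamafile_command above encodes how the Llamafile configuration maps onto llama.cpp-style command-line flags. The sketch below restates only that mapping; the function name and signature are illustrative, and this is not the library's _build_llamafile_command.

def sketch_llamafile_args(
    path: str,
    port: int,
    context_size: int,
    gpu_layers: int,
    thread_count: int,
    batch_size: int,
    max_concurrent_requests: int,
    temperature: float,
    max_new_tokens: int,
    quantized_kv: bool,
    flash_attn: bool,
) -> list[str]:
    # Mapping taken from the arguments the e2e test asserts on.
    args = [
        path,
        f"--port {port}",
        f"-c {context_size}",
        f"-ngl {gpu_layers}",
        f"--threads {thread_count}",
        f"-b {batch_size}",
        f"--parallel {max_concurrent_requests}",
        f"--temp {temperature}",
        f"-n {max_new_tokens}",
    ]
    if quantized_kv:
        args += ["-ctk q4_0", "-ctv q4_0"]  # quantized KV cache
    if flash_attn:
        args.append("-fa")  # flash attention
    return args
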
/tests/unit/models/adapters/gpu.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import mock_open, patch
2 |
3 | import pytest
4 | import yaml
5 |
6 | from flow_judge.models.adapters.baseten.gpu import (
7 | _get_gpu_key,
8 | _has_gpu_key,
9 | _update_config,
10 | ensure_gpu,
11 | )
12 |
13 |
14 | @pytest.fixture
15 | def mock_env_setup(monkeypatch):
16 | """Mock the environment variables for testing.
17 |
18 | :returns: None
19 | :rtype: None
20 | """
21 | monkeypatch.delenv("BASETEN_GPU", raising=False)
22 |
23 |
24 | @pytest.mark.parametrize(
25 | "env_value, expected",
26 | [("H100", "H100"), ("h100", "h100"), ("A10G", "A10G"), ("a10g", "a10g"), (None, None)],
27 | )
28 | def test_get_gpu_key(mock_env_setup, env_value: str | None, expected: str | None, monkeypatch):
29 | """Test the _get_gpu_key function.
30 |
31 | :param mock_env_setup: Fixture to mock environment variables.
32 | :type mock_env_setup: None
33 | :param env_value: The value to set for the BASETEN_GPU environment variable.
34 | :type env_value: str | None
35 | :param expected: The expected output.
36 | :type expected: str | None
37 | :param monkeypatch: Pytest monkeypatch fixture.
38 | :type monkeypatch: pytest.MonkeyPatch
39 | :returns: None
40 | :rtype: None
41 | """
42 | if env_value is not None:
43 | monkeypatch.setenv("BASETEN_GPU", env_value)
44 | assert _get_gpu_key() == expected
45 |
46 |
47 | @pytest.mark.parametrize(
48 | "env_value, expected",
49 | [("H100", True), ("h100", True), ("A10G", True), ("a10g", True), (None, False)],
50 | )
51 | def test_has_gpu_key(mock_env_setup, env_value: str | None, expected: bool, monkeypatch):
52 | """Test the _has_gpu_key function.
53 |
54 | :param mock_env_setup: Fixture to mock environment variables.
55 | :type mock_env_setup: None
56 | :param env_value: The value to set for the BASETEN_GPU environment variable.
57 | :type env_value: str | None
58 | :param expected: The expected output.
59 | :type expected: bool
60 | :param monkeypatch: Pytest monkeypatch fixture.
61 | :type monkeypatch: pytest.MonkeyPatch
62 | :returns: None
63 | :rtype: None
64 | """
65 | if env_value is not None:
66 | monkeypatch.setenv("BASETEN_GPU", env_value)
67 | assert _has_gpu_key() == expected
68 |
69 |
70 | @pytest.mark.parametrize(
71 | "mock_file_contents, mock_env_value, expected",
72 | [
73 | (
74 | {
75 | "resources": {"accelerator": "test"},
76 | "repo_id": "test",
77 | "model_metadata": {"repo_id": "test"},
78 | },
79 | "H100",
80 | True,
81 | ),
82 | (
83 | {
84 | "resources": {"accelerator": "test"},
85 | "repo_id": "test",
86 | "model_metadata": {"repo_id": "test"},
87 | },
88 | "A10G",
89 | True,
90 | ),
91 | ({}, "H100", False),
92 | ({}, "A10G", False),
93 | ],
94 | )
95 | def test_update_config(monkeypatch, mock_file_contents: dict, mock_env_value: str, expected: bool):
96 | """Test the _update_config function.
97 |
98 | :param monkeypatch: Pytest monkeypatch fixture.
99 | :type monkeypatch: pytest.MonkeyPatch
100 | :param mock_file_contents: The contents to be mocked for the config.yaml file.
101 | :type mock_file_contents: dict
102 | :param mock_env_value: The value to set for the BASETEN_GPU environment variable.
103 | :type mock_env_value: str
104 | :param expected: The expected output.
105 | :type expected: bool
106 | :returns: None
107 | :rtype: None
108 | """
109 | monkeypatch.setenv("BASETEN_GPU", mock_env_value)
110 | mock_open_obj = mock_open(read_data=yaml.dump(mock_file_contents))
111 | with patch("builtins.open", mock_open_obj):
112 | assert _update_config() == expected
113 |
114 |
115 | @pytest.mark.parametrize(
116 | "mock_env_value, mock_input_value, expected",
117 | [
118 | ("H100", b"y\n", True),
119 | ("H100", b"n\n", True),
120 | (None, b"y\n", False),
121 | (None, b"n\n", False),
122 | (None, b"\n", False),
123 | (None, KeyboardInterrupt, False),
124 | ],
125 | )
126 | def test_ensure_gpu(monkeypatch, mock_env_value: str | None, mock_input_value, expected: bool):
127 | """Test the ensure_gpu function.
128 |
129 | :param monkeypatch: Pytest monkeypatch fixture.
130 | :type monkeypatch: pytest.MonkeyPatch
131 | :param mock_env_value: The value to set for the BASETEN_GPU environment variable.
132 | :type mock_env_value: str | None
133 | :param mock_input_value: The value to mock for user input.
134 | :type mock_input_value: bytes | KeyboardInterrupt
135 | :param expected: The expected output.
136 | :type expected: bool
137 | :returns: None
138 | :rtype: None
139 | """
140 | if mock_env_value is not None:
141 | monkeypatch.setenv("BASETEN_GPU", mock_env_value)
142 | with patch("builtins.input", side_effect=[mock_input_value]):
143 | with patch("flow_judge.models.adapters.baseten.gpu._update_config") as mock_update_config:
144 | mock_update_config.return_value = True
145 | assert ensure_gpu() == expected
146 |
--------------------------------------------------------------------------------
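For context, the tests above treat BASETEN_GPU as the environment variable that selects the accelerator ("H100" or "A10G") before deployment. The snippet below only illustrates setting that variable; whether ensure_gpu prompts or rewrites config.yaml in a real run depends on the adapter code, which these tests mock.

import os

from flow_judge.models.adapters.baseten.gpu import ensure_gpu

os.environ["BASETEN_GPU"] = "H100"  # or "A10G"; illustrative value
gpu_ready = ensure_gpu()            # the unit tests expect a truthy result when the variable is set
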
/tests/unit/models/adapters/validation.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import hmac
3 | from datetime import datetime, timedelta, timezone
4 | from unittest.mock import patch
5 |
6 | import pytest
7 |
8 | from flow_judge.models.adapters.baseten.validation import (
9 | TIMESTAMP_TOLERANCE_SECONDS,
10 | validate_baseten_signature,
11 | )
12 |
13 | # Mock test data for AsyncPredictResult
14 | mock_async_predict_result_data = {
15 | "model_config": {},
16 | "request_id": "mock_request_id",
17 | "model_id": "mock_model_id",
18 | "deployment_id": "mock_deployment_id",
19 | "type": "mock_type",
20 | "time": datetime.now(timezone.utc),
21 | "data": {"mock_data": "mock_value"},
22 | "errors": [{"mock_error": "mock_error_value"}],
23 | }
24 |
25 | # Mock webhook secret
26 | mock_webhook_secret = "mock_webhook_secret"
27 |
28 |
29 | @pytest.fixture(autouse=True)
30 | def mock_os_environ():
31 | """Mock os.environ for BASETEN_WEBHOOK_SECRET."""
32 | with patch.dict("os.environ", {"BASETEN_WEBHOOK_SECRET": mock_webhook_secret}):
33 | yield
34 |
35 |
36 | class TestValidateBasetenSignature:
37 | """Test suite for validate_baseten_signature."""
38 |
39 | def test_valid_signature(self):
40 | """Test for a valid Baseten signature."""
41 | with patch(
42 | "flow_judge.models.adapters.baseten.validation.AsyncPredictResult"
43 | ) as mock_async_predict_result:
44 | mock_async_predict_result_instance = mock_async_predict_result.return_value
45 | mock_async_predict_result_instance.model_dump_json.return_value = "mock_model_dump_json"
46 | mock_async_predict_result_instance.__dict__.update(mock_async_predict_result_data)
47 |
48 | # Mock valid signature
49 | valid_signature = hmac.digest(
50 | mock_webhook_secret.encode("utf-8"),
51 | b"mock_model_dump_json",
52 | hashlib.sha256,
53 | ).hex()
54 |
55 | assert (
56 | validate_baseten_signature(
57 | mock_async_predict_result_instance, f"v1={valid_signature}"
58 | )
59 | is True
60 | )
61 |
62 | def test_invalid_signature(self):
63 | """Test for an invalid Baseten signature."""
64 | with patch(
65 | "flow_judge.models.adapters.baseten.validation.AsyncPredictResult"
66 | ) as mock_async_predict_result:
67 | mock_async_predict_result_instance = mock_async_predict_result.return_value
68 | mock_async_predict_result_instance.model_dump_json.return_value = "mock_model_dump_json"
69 | mock_async_predict_result_instance.__dict__.update(mock_async_predict_result_data)
70 |
71 | # Mock invalid signature
72 | invalid_signature = "invalid_signature"
73 |
74 | assert (
75 | validate_baseten_signature(
76 | mock_async_predict_result_instance, f"v1={invalid_signature}"
77 | )
78 | is False
79 | )
80 |
81 | def test_stale_timestamp(self):
82 | """Test for a stale timestamp."""
83 | stale_timestamp = datetime.now(timezone.utc) - timedelta(
84 | seconds=TIMESTAMP_TOLERANCE_SECONDS + 1
85 | )
86 |
87 | with patch(
88 | "flow_judge.models.adapters.baseten.validation.AsyncPredictResult"
89 | ) as mock_async_predict_result:
90 | mock_async_predict_result_instance = mock_async_predict_result.return_value
91 | mock_async_predict_result_instance.model_dump_json.return_value = "mock_model_dump_json"
92 | mock_async_predict_result_instance.__dict__.update(mock_async_predict_result_data)
93 | mock_async_predict_result_instance.time = stale_timestamp
94 |
95 | # Mock valid signature
96 | valid_signature = hmac.digest(
97 | mock_webhook_secret.encode("utf-8"),
98 | b"mock_model_dump_json",
99 | hashlib.sha256,
100 | ).hex()
101 |
102 | assert (
103 | validate_baseten_signature(
104 | mock_async_predict_result_instance, f"v1={valid_signature}"
105 | )
106 | is False
107 | )
108 |
109 | def test_missing_webhook_secret(self, monkeypatch):
110 | """Test for a missing BASETEN_WEBHOOK_SECRET environment variable."""
111 | monkeypatch.delenv("BASETEN_WEBHOOK_SECRET", raising=False)
112 |
113 | with patch(
114 | "flow_judge.models.adapters.baseten.validation.AsyncPredictResult"
115 | ) as mock_async_predict_result:
116 | mock_async_predict_result_instance = mock_async_predict_result.return_value
117 | mock_async_predict_result_instance.model_dump_json.return_value = "mock_model_dump_json"
118 | mock_async_predict_result_instance.__dict__.update(mock_async_predict_result_data)
119 |
120 | # Mock valid signature
121 | valid_signature = hmac.digest(
122 | mock_webhook_secret.encode("utf-8"),
123 | b"mock_model_dump_json",
124 | hashlib.sha256,
125 | ).hex()
126 |
127 | assert (
128 | validate_baseten_signature(
129 | mock_async_predict_result_instance, f"v1={valid_signature}"
130 | )
131 | is False
132 | )
133 |
--------------------------------------------------------------------------------
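A condensed sketch of the verification scheme these tests exercise: the webhook payload JSON is HMAC-SHA256 signed with BASETEN_WEBHOOK_SECRET, hex-encoded, and compared against the "v1="-prefixed header, and stale events are rejected. The function name, the default tolerance, and the single-signature header parsing are simplifications; the real logic is validate_baseten_signature in flow_judge.models.adapters.baseten.validation.

import hashlib
import hmac
import os
from datetime import datetime, timezone


def sketch_verify_signature(
    payload_json: str,
    event_time: datetime,
    signature_header: str,
    tolerance_seconds: int = 300,  # illustrative; the real constant is TIMESTAMP_TOLERANCE_SECONDS
) -> bool:
    secret = os.environ.get("BASETEN_WEBHOOK_SECRET")
    if secret is None:
        return False  # missing secret -> reject (test_missing_webhook_secret)
    age = (datetime.now(timezone.utc) - event_time).total_seconds()
    if age > tolerance_seconds:
        return False  # stale event -> reject (test_stale_timestamp)
    expected = hmac.digest(
        secret.encode("utf-8"), payload_json.encode("utf-8"), hashlib.sha256
    ).hex()
    provided = signature_header.removeprefix("v1=")
    return hmac.compare_digest(expected, provided)
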
/tests/unit/models/test_llamafile_unit.py:
--------------------------------------------------------------------------------
1 | import signal
2 | import subprocess
3 | from unittest.mock import MagicMock, patch
4 |
5 | import pytest
6 |
7 | from flow_judge.models.llamafile import cleanup_llamafile
8 |
9 | """
10 | Test suite for the cleanup_llamafile function.
11 |
12 | This module contains unit tests that verify the behavior of the cleanup_llamafile
13 | function under various scenarios, including normal termination, forced termination,
14 | error handling, and edge cases.
15 | """
16 |
17 |
18 | @pytest.fixture
19 | def mock_process():
20 | """Create a mock process for testing.
21 |
22 | This fixture creates a MagicMock object that simulates a process,
23 | primarily for use in testing the cleanup_llamafile function.
24 |
25 | :return: A MagicMock object representing a process
26 | :rtype: unittest.mock.MagicMock
27 |
28 | The mock process has the following attributes:
29 | - pid: Set to 12345
30 |
31 | Usage:
32 | def test_example(mock_process):
33 | # Use mock_process in your test
34 |
35 | Note:
36 |         This fixture is function-scoped by default. If you need a different
37 |         scope, you can modify the fixture decorator, e.g.,
38 |         @pytest.fixture(scope="module").
39 | """
40 | process = MagicMock()
41 | process.pid = 12345
42 | return process
43 |
44 |
45 | @patch("os.getpgid")
46 | @patch("os.killpg")
47 | @patch("subprocess.Popen")
48 | def test_normal_termination(mock_popen, mock_killpg, mock_getpgid, mock_process):
49 | """Test the normal termination scenario of cleanup_llamafile.
50 |
51 | This test verifies that when a process terminates normally:
52 | 1. The process group ID is correctly retrieved.
53 | 2. A SIGTERM signal is sent to the process group.
54 | 3. The function waits for the process to terminate with a timeout.
55 |
56 | It uses mocking to simulate system calls and process behavior.
57 | """
58 | mock_getpgid.return_value = 54321
59 | mock_popen.return_value = mock_process
60 |
61 | cleanup_llamafile(lambda: mock_process)
62 |
63 | mock_getpgid.assert_called_once_with(12345)
64 | mock_killpg.assert_called_once_with(54321, signal.SIGTERM)
65 | mock_process.wait.assert_called_once_with(timeout=5)
66 |
67 |
68 | @patch("os.getpgid")
69 | @patch("os.killpg")
70 | @patch("subprocess.Popen")
71 | def test_force_kill(mock_popen, mock_killpg, mock_getpgid, mock_process):
72 | """Test the force kill scenario of cleanup_llamafile.
73 |
74 | This test ensures that when a process doesn't terminate after SIGTERM:
75 | 1. A SIGTERM signal is initially sent to the process group.
76 | 2. After a timeout, a SIGKILL signal is sent to forcefully terminate the process.
77 |
78 | It simulates a process that doesn't respond to SIGTERM, requiring SIGKILL.
79 | """
80 | mock_getpgid.return_value = 54321
81 | mock_process.wait.side_effect = subprocess.TimeoutExpired(cmd="test", timeout=5)
82 | mock_popen.return_value = mock_process
83 |
84 | cleanup_llamafile(lambda: mock_process)
85 |
86 | mock_killpg.assert_any_call(54321, signal.SIGTERM)
87 | mock_killpg.assert_any_call(54321, signal.SIGKILL)
88 |
89 |
90 | @patch("os.getpgid")
91 | @patch("subprocess.Popen")
92 | def test_os_error_fallback(mock_popen, mock_getpgid, mock_process):
93 | """Test the OS error fallback scenario of cleanup_llamafile.
94 |
95 | This test verifies the function's behavior when it encounters an OSError:
96 | 1. It attempts to get the process group ID, which raises an OSError.
97 | 2. The function falls back to terminating the individual process.
98 | 3. It waits for the process to terminate with a timeout.
99 |
100 | This test ensures the function has a proper fallback mechanism for OS-level errors.
101 | """
102 | mock_getpgid.side_effect = OSError()
103 | mock_popen.return_value = mock_process
104 |
105 | cleanup_llamafile(lambda: mock_process)
106 |
107 | mock_process.terminate.assert_called_once()
108 | mock_process.wait.assert_called_once_with(timeout=5)
109 |
110 |
111 | def test_already_terminated(mock_process):
112 | """Test the scenario where the process is already terminated.
113 |
114 | This test ensures that the cleanup_llamafile function handles the case
115 | where the process reference is None (indicating an already terminated process)
116 | without raising any errors.
117 |
118 | No assertions are made as the function should complete without any action or error.
119 | """
120 | mock_process.__bool__.return_value = False
121 |
122 | cleanup_llamafile(lambda: None)
123 |
124 | # No assertions needed, function should complete without error
125 |
--------------------------------------------------------------------------------
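Taken together, the cases above pin down a shutdown sequence: resolve the process reference, send SIGTERM to its process group, wait up to five seconds, escalate to SIGKILL on timeout, and fall back to terminating the single process if the group lookup fails. The sketch below restates that sequence; it is not the library's cleanup_llamafile, and the helper name is illustrative.

import os
import signal
import subprocess


def sketch_cleanup(process_ref) -> None:
    process = process_ref()
    if process is None:
        return  # already terminated: nothing to do (test_already_terminated)
    try:
        pgid = os.getpgid(process.pid)
        os.killpg(pgid, signal.SIGTERM)  # polite shutdown of the whole group
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            os.killpg(pgid, signal.SIGKILL)  # force kill (test_force_kill)
    except OSError:
        process.terminate()  # group lookup failed: fall back to the single process
        process.wait(timeout=5)
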
/tests/unit/test_flow_judge.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from unittest.mock import patch
3 |
4 | import pytest
5 |
6 | from flow_judge.eval_data_types import EvalInput, EvalOutput
7 | from flow_judge.flow_judge import FlowJudge
8 | from flow_judge.metrics import RESPONSE_CORRECTNESS_BINARY, CustomMetric, RubricItem
9 | from flow_judge.models.common import BaseFlowJudgeModel
10 | from flow_judge.utils.prompt_formatter import USER_PROMPT_TEMPLATE, format_rubric, format_vars
11 |
12 |
13 | class MockFlowJudgeModel(BaseFlowJudgeModel):
14 | """Mock model for testing."""
15 |
16 | def __init__(self, model_id, model_type, generation_params):
17 | """Initialize the mock model."""
18 | super().__init__(model_id, model_type, generation_params)
19 |
20 | def _generate(self, prompt):
21 |         """Generate a mock response in the judge's <feedback>/<score> output format."""
22 |         return "<feedback>Test feedback</feedback>\n<score>1</score>"
23 |
24 | def _batch_generate(self, prompts, use_tqdm=True):
25 |         """Generate mock responses, one per prompt, in the <feedback>/<score> format."""
26 |         return ["<feedback>Test feedback</feedback>\n<score>1</score>" for _ in prompts]
27 |
28 | def generate(self, prompt):
29 | """Generate a mock response."""
30 | return self._generate(prompt)
31 |
32 | def batch_generate(self, prompts, use_tqdm=True):
33 | """Generate mock responses for a list of prompts."""
34 | return self._batch_generate(prompts, use_tqdm=use_tqdm)
35 |
36 |
37 | @pytest.fixture
38 | def mock_model():
39 | """Fixture to create a mock model for testing."""
40 | return MockFlowJudgeModel("test-model", "mock", {"temperature": 0.7})
41 |
42 |
43 | def test_flow_judge_initialization(mock_model):
44 | """Test the initialization of FlowJudge."""
45 | judge = FlowJudge(metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model)
46 | assert isinstance(judge, FlowJudge)
47 | assert judge.metric == RESPONSE_CORRECTNESS_BINARY
48 | assert judge.model == mock_model
49 |
50 |
51 | def test_flow_judge_initialization_invalid_metric(mock_model):
52 |     """Test that FlowJudge initialization raises a ValueError for an invalid metric."""
53 | with pytest.raises(ValueError):
54 | FlowJudge(metric="invalid_metric", model=mock_model)
55 |
56 |
57 | def test_flow_judge_evaluate(mock_model):
58 | """Test the evaluate method of FlowJudge."""
59 | judge = FlowJudge(metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model)
60 | eval_input = EvalInput(
61 | inputs=[{"query": "Test query"}, {"reference_answer": "Test reference"}],
62 | output={"response": "Test response"},
63 | )
64 | result = judge.evaluate(eval_input)
65 | assert isinstance(result, EvalOutput)
66 | assert result.feedback == "Test feedback"
67 | assert result.score == 1
68 |
69 |
70 | def test_flow_judge_batch_evaluate(mock_model):
71 | """Test the batch_evaluate method of FlowJudge."""
72 | judge = FlowJudge(metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model)
73 | eval_inputs = [
74 | EvalInput(
75 | inputs=[{"query": "Test query 1"}, {"reference_answer": "Test reference 1"}],
76 | output={"response": "Test response 1"},
77 | ),
78 | EvalInput(
79 | inputs=[{"query": "Test query 2"}, {"reference_answer": "Test reference 2"}],
80 | output={"response": "Test response 2"},
81 | ),
82 | ]
83 | results = judge.batch_evaluate(eval_inputs, save_results=False)
84 | assert len(results) == 2
85 | for result in results:
86 | assert isinstance(result, EvalOutput)
87 | assert result.feedback == "Test feedback"
88 | assert result.score == 1
89 |
90 |
91 | @pytest.mark.parametrize("save_results", [True, False])
92 | def test_flow_judge_evaluate_save_results(mock_model, tmp_path, save_results):
93 | """Test saving results in the evaluate method."""
94 | judge = FlowJudge(
95 | metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model, output_dir=str(tmp_path)
96 | )
97 | eval_input = EvalInput(
98 | inputs=[{"query": "Test query"}, {"reference_answer": "Test reference"}],
99 | output={"response": "Test response"},
100 | )
101 | with patch("flow_judge.flow_judge.write_results_to_disk") as mock_write:
102 | judge.evaluate(eval_input, save_results=save_results)
103 | if save_results:
104 | mock_write.assert_called_once()
105 | else:
106 | mock_write.assert_not_called()
107 |
108 |
109 | def test_custom_metric():
110 | """Test creating and using a custom metric."""
111 | custom_metric = CustomMetric(
112 | name="custom_metric",
113 | criteria="Custom criteria",
114 | rubric=[RubricItem(score=0, description="Bad"), RubricItem(score=1, description="Good")],
115 | required_inputs=["custom_input"],
116 | required_output="custom_output",
117 | )
118 | assert custom_metric.name == "custom_metric"
119 | assert custom_metric.criteria == "Custom criteria"
120 | assert len(custom_metric.rubric) == 2
121 | assert custom_metric.required_inputs == ["custom_input"]
122 | assert custom_metric.required_output == "custom_output"
123 |
124 |
125 | def test_eval_input_validation(mock_model):
126 | """Test EvalInput validation."""
127 | judge = FlowJudge(metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model)
128 |
129 | # Valid input
130 | valid_input = EvalInput(
131 | inputs=[{"query": "Test query"}, {"reference_answer": "Test reference"}],
132 | output={"response": "Test response"},
133 | )
134 | assert judge.evaluate(valid_input)
135 |
136 | # Invalid input - missing required input
137 | invalid_input = EvalInput(
138 | inputs=[{"query": "Test query"}], output={"response": "Test response"}
139 | )
140 | with pytest.raises(ValueError):
141 | judge.evaluate(invalid_input)
142 |
143 | # Invalid input - wrong output key
144 | invalid_output = EvalInput(
145 | inputs=[{"query": "Test query"}, {"reference_answer": "Test reference"}],
146 | output={"wrong_key": "Test response"},
147 | )
148 | with pytest.raises(ValueError):
149 | judge.evaluate(invalid_output)
150 |
151 |
152 | def test_format_vars():
153 | """Test format_vars function."""
154 | variables = [{"question": "What is 2+2?"}, {"context": "Math basics"}]
155 | formatted = format_vars(variables)
156 |     expected = """<question>
157 | What is 2+2?
158 | </question>
159 | <context>
160 | Math basics
161 | </context>"""
162 | assert expected == formatted
163 |
164 |
165 | def test_format_rubric():
166 | """Test format_rubric function."""
167 | rubric = [RubricItem(score=1, description="Good"), RubricItem(score=0, description="Poor")]
168 | formatted = format_rubric(rubric)
169 | expected = """- Score 0: Poor
170 | - Score 1: Good"""
171 | assert expected == formatted
172 |
173 |
174 | def test_format_prompt(mock_model):
175 | """Test FlowJudge._format_prompt."""
176 | eval_input = EvalInput(
177 | inputs=[{"query": "Test query"}, {"reference_answer": "Test reference"}],
178 | output={"response": "Test response"},
179 | )
180 |
181 | judge = FlowJudge(metric=RESPONSE_CORRECTNESS_BINARY, model=mock_model)
182 | prompt = judge._format_prompt(eval_input)
183 |
184 | expected_prompt = USER_PROMPT_TEMPLATE.format(
185 | INPUTS=format_vars(eval_input.inputs),
186 | OUTPUT=format_vars([eval_input.output]),
187 | EVALUATION_CRITERIA=RESPONSE_CORRECTNESS_BINARY.criteria,
188 | RUBRIC=format_rubric(RESPONSE_CORRECTNESS_BINARY.rubric),
189 | )
190 | assert prompt == expected_prompt
191 |
192 |
193 | def test_eval_output_parse_fail_on_parse_error():
194 | """Test EvalOutput.parse with fail_on_parse_error."""
195 | # Invalid response without proper tags
196 | invalid_response = "This is an invalid response without proper tags"
197 |
198 | # Test with fail_on_parse_error=False (default behavior)
199 | result = EvalOutput.parse(invalid_response)
200 | assert isinstance(result, EvalOutput)
201 | assert result.feedback == "Error"
202 | assert result.score == -1
203 |
204 | # Test with fail_on_parse_error=True
205 | with pytest.raises(ValueError):
206 | EvalOutput.parse(invalid_response, fail_on_parse_error=True)
207 |
208 | # Test with valid response
209 |     valid_response = "<feedback>Good job!</feedback><score>5</score>"
210 | result = EvalOutput.parse(valid_response)
211 | assert isinstance(result, EvalOutput)
212 | assert result.feedback == "Good job!"
213 | assert result.score == 5
214 |
215 |
216 | @pytest.fixture(autouse=True)
217 | def cleanup(request, tmp_path):
218 | """Cleanup files and directories created during the test."""
219 | yield
220 | shutil.rmtree(tmp_path, ignore_errors=True)
221 |
--------------------------------------------------------------------------------
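The mock model and the parsing tests above rely on the judge answering with a feedback block followed by a score block, i.e. <feedback>...</feedback> and <score>...</score>. The sketch below is a minimal parser consistent with those assertions; it is not the library's EvalOutput.parse, and the function name is illustrative.

import re


def sketch_parse_judge_response(response: str) -> tuple[str, int]:
    feedback = re.search(r"<feedback>(.*?)</feedback>", response, re.DOTALL)
    score = re.search(r"<score>\s*(-?\d+)\s*</score>", response)
    if feedback is None or score is None:
        raise ValueError(f"Could not parse judge response: {response!r}")
    return feedback.group(1).strip(), int(score.group(1))


# e.g. sketch_parse_judge_response("<feedback>Good job!</feedback><score>5</score>")
# returns ("Good job!", 5), matching test_eval_output_parse_fail_on_parse_error.
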
/tests/unit/test_metrics.py:
--------------------------------------------------------------------------------
1 | from flow_judge.metrics import RESPONSE_CORRECTNESS_BINARY, CustomMetric, RubricItem
2 |
3 |
4 | def test_response_correctness_binary():
5 | """Test the RESPONSE_CORRECTNESS_BINARY metric."""
6 | metric = RESPONSE_CORRECTNESS_BINARY
7 | assert metric.name == "Response Correctness (Binary)"
8 | assert len(metric.rubric) == 2
9 | assert metric.required_inputs == ["query", "reference_answer"]
10 | assert metric.required_output == "response"
11 |
12 |
13 | def test_custom_metric():
14 | """Test the CustomMetric class."""
15 | custom_metric = CustomMetric(
16 | name="Test Metric",
17 | criteria="Test criteria",
18 | rubric=[RubricItem(score=0, description="Bad"), RubricItem(score=1, description="Good")],
19 | required_inputs=["test_input"],
20 | required_output="test_output",
21 | )
22 | assert custom_metric.name == "Test Metric"
23 | assert custom_metric.criteria == "Test criteria"
24 | assert len(custom_metric.rubric) == 2
25 | assert custom_metric.required_inputs == ["test_input"]
26 | assert custom_metric.required_output == "test_output"
27 |
--------------------------------------------------------------------------------
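The preset test above only asserts on the metric's name, rubric length, required inputs, and required output. A hypothetical equivalent built from those fields is sketched below; the criteria text and rubric descriptions are illustrative, not the library's wording.

from flow_judge.metrics import CustomMetric, RubricItem

response_correctness_binary_sketch = CustomMetric(
    name="Response Correctness (Binary)",
    criteria="Is the response correct with respect to the reference answer?",  # illustrative
    rubric=[
        RubricItem(score=0, description="The response is incorrect."),  # illustrative
        RubricItem(score=1, description="The response is correct."),    # illustrative
    ],
    required_inputs=["query", "reference_answer"],
    required_output="response",
)
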
/tests/unit/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from flow_judge.eval_data_types import EvalInput
4 | from flow_judge.metrics import CustomMetric, RubricItem
5 | from flow_judge.utils.prompt_formatter import (
6 | USER_PROMPT_NO_INPUTS_TEMPLATE,
7 | USER_PROMPT_TEMPLATE,
8 | format_rubric,
9 | format_user_prompt,
10 | format_vars,
11 | )
12 | from flow_judge.utils.validators import validate_eval_input
13 |
14 |
15 | def test_format_vars():
16 | """Test the format_vars function."""
17 | variables = [{"question": "What is 2+2?"}, {"context": "Math basics"}]
18 | formatted = format_vars(variables)
19 |     expected = "<question>\nWhat is 2+2?\n</question>\n<context>\nMath basics\n</context>"
20 | assert formatted == expected
21 |
22 |
23 | def test_format_rubric():
24 | """Test the format_rubric function."""
25 | rubric = [RubricItem(score=0, description="Bad"), RubricItem(score=1, description="Good")]
26 | formatted = format_rubric(rubric)
27 | expected = "- Score 0: Bad\n- Score 1: Good"
28 | assert formatted == expected
29 |
30 |
31 | def test_format_user_prompt():
32 | """Test the format_user_prompt function."""
33 | # Test with inputs
34 | variables_with_inputs = {
35 | "INPUTS": "Test input",
36 | "OUTPUT": "Test output",
37 | "EVALUATION_CRITERIA": "Test criteria",
38 | "RUBRIC": "Test rubric",
39 | }
40 | formatted_with_inputs = format_user_prompt(variables_with_inputs)
41 | assert USER_PROMPT_TEMPLATE.format(**variables_with_inputs) == formatted_with_inputs
42 |
43 | # Test without inputs
44 | variables_without_inputs = {
45 | "INPUTS": "",
46 | "OUTPUT": "Test output",
47 | "EVALUATION_CRITERIA": "Test criteria",
48 | "RUBRIC": "Test rubric",
49 | }
50 | formatted_without_inputs = format_user_prompt(variables_without_inputs)
51 | assert (
52 | USER_PROMPT_NO_INPUTS_TEMPLATE.format(**variables_without_inputs)
53 | == formatted_without_inputs
54 | )
55 |
56 |
57 | def test_validate_eval_input():
58 | """Test the validate_eval_input function."""
59 | metric = CustomMetric(
60 | name="Test Metric",
61 | criteria="Test criteria",
62 | rubric=[RubricItem(score=0, description="Bad"), RubricItem(score=1, description="Good")],
63 | required_inputs=["test_input"],
64 | required_output="test_output",
65 | )
66 | valid_input = EvalInput(
67 | inputs=[{"test_input": "Test value"}], output={"test_output": "Test output"}
68 | )
69 | validate_eval_input(valid_input, metric) # Should not raise an exception
70 |
71 | invalid_input = EvalInput(
72 | inputs=[{"wrong_input": "Test value"}], output={"test_output": "Test output"}
73 | )
74 | with pytest.raises(ValueError):
75 | validate_eval_input(invalid_input, metric)
76 |
--------------------------------------------------------------------------------
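For reference, a condensed sketch of the checks test_validate_eval_input exercises above: every key in metric.required_inputs must appear among the EvalInput.inputs dictionaries and metric.required_output must be the key of the output dictionary, otherwise a ValueError is raised. The helper name is illustrative; the real implementation is flow_judge.utils.validators.validate_eval_input.

from flow_judge.eval_data_types import EvalInput
from flow_judge.metrics import CustomMetric


def sketch_validate_eval_input(eval_input: EvalInput, metric: CustomMetric) -> None:
    provided_keys = {key for item in eval_input.inputs for key in item}
    missing = [name for name in metric.required_inputs if name not in provided_keys]
    if missing:
        raise ValueError(f"Missing required inputs: {missing}")
    if metric.required_output not in eval_input.output:
        raise ValueError(f"Missing required output key: {metric.required_output!r}")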