├── docs
│   ├── .nojekyll
│   ├── _navbar.md
│   ├── _coverpage.md
│   ├── _sidebar.md
│   ├── index.html
│   ├── README.md
│   ├── installation.md
│   ├── PLANNING_WRAPPER.md
│   ├── quick-start.md
│   ├── benchmarking.md
│   ├── algorithms.md
│   └── development.md
├── its_hub
│   ├── integration
│   │   ├── __init__.py
│   │   └── reward_hub.py
│   ├── __init__.py
│   ├── algorithms
│   │   ├── __init__.py
│   │   ├── bon.py
│   │   └── beam_search.py
│   ├── utils.py
│   ├── base.py
│   ├── types.py
│   └── error_handling.py
├── tests
│   ├── mocks
│   │   ├── __init__.py
│   │   ├── reward_models.py
│   │   ├── test_data.py
│   │   └── language_models.py
│   ├── test_particle_gibbs_resampling.py
│   ├── conftest.py
│   └── test_reward_hub_integration.py
├── .claude
│   └── settings.json
├── .jupytext.yml
├── ruff.toml
├── .github
│   └── workflows
│       ├── tests.yaml
│       ├── sync-notebooks.yaml
│       └── release.yaml
├── .devcontainer
│   ├── devcontainer.json
│   ├── Dockerfile
│   └── init-firewall.sh
├── scripts
│   └── test_math_example.py
├── pyproject.toml
├── README.md
├── .gitignore
├── notebooks
│   ├── self-consistency.py
│   └── self-consistency.ipynb
├── CLAUDE.md
└── LICENSE
/docs/.nojekyll:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/its_hub/integration/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/mocks/__init__.py:
--------------------------------------------------------------------------------
1 | """Mock objects for testing."""
2 |
--------------------------------------------------------------------------------
/.claude/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "permissions": {
3 | "allow": [
4 | "Bash(rg:*)"
5 | ],
6 | "deny": []
7 | }
8 | }
--------------------------------------------------------------------------------
/.jupytext.yml:
--------------------------------------------------------------------------------
1 | # Jupytext configuration
2 | formats: "py:percent,ipynb"
3 | notebook_metadata_filter: "all"
4 | cell_metadata_filter: "-all"
--------------------------------------------------------------------------------
/its_hub/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | A Python library for inference-time scaling LLMs
3 | """
4 |
5 | from importlib.metadata import version
6 |
7 | __version__ = version("its_hub")
8 |
--------------------------------------------------------------------------------
/docs/_navbar.md:
--------------------------------------------------------------------------------
1 | - Links
2 | - [GitHub](https://github.com/Red-Hat-AI-Innovation-Team/its_hub)
3 | - [PyPI](https://pypi.org/project/its_hub/)
4 |   - [Tests](https://github.com/Red-Hat-AI-Innovation-Team/its_hub/actions/workflows/tests.yaml)
5 | - [Coverage](https://codecov.io/gh/Red-Hat-AI-Innovation-Team/its_hub)
--------------------------------------------------------------------------------
/docs/_coverpage.md:
--------------------------------------------------------------------------------
1 | # its-hub
2 |
3 | > A Python library for inference-time scaling of LLMs
4 |
5 | - 🔬 Multiple scaling algorithms (Particle Filtering, Best-of-N, Beam Search, Self-Consistency)
6 | - 🚀 OpenAI-compatible API with Inference-as-a-Service (IaaS)
7 | - ⚡ Async generation with concurrency limits and error handling
8 | - 📊 Comprehensive benchmarking tools
9 |
10 | [GitHub](https://github.com/Red-Hat-AI-Innovation-Team/its_hub)
11 | [Get Started](quick-start.md)
12 |
--------------------------------------------------------------------------------
/docs/_sidebar.md:
--------------------------------------------------------------------------------
1 | - [Overview](README.md)
2 | - [Installation](installation.md)
3 | - [Quick Start Guide](quick-start.md)
4 | - [IaaS Service Guide](iaas-service.md)
5 | - [Algorithms](algorithms.md)
6 | - [Particle Filtering](algorithms.md#particle-filtering)
7 | - [Best-of-N](algorithms.md#best-of-n)
8 | - [Beam Search](algorithms.md#beam-search)
9 | - [Self-Consistency](algorithms.md#self-consistency)
10 | - [Benchmarking](benchmarking.md)
11 | - [Development](development.md)
--------------------------------------------------------------------------------
/ruff.toml:
--------------------------------------------------------------------------------
1 | line-length = 88
2 | target-version = "py311"
3 | extend-exclude = ["_version.py"]
4 |
5 | [lint]
6 | select = [
7 | "E", # pycodestyle errors
8 | "W", # pycodestyle warnings
9 | "F", # pyflakes
10 | "I", # isort
11 | "N", # pep8-naming
12 | "UP", # pyupgrade
13 | "B", # flake8-bugbear
14 | "C4", # flake8-comprehensions
15 | "SIM", # flake8-simplify
16 | "TID", # flake8-tidy-imports
17 | "RUF", # Ruff-specific rules
18 | ]
19 | ignore = [
20 | "E501", # line too long (handled by formatter)
21 | "B008", # do not perform function calls in argument defaults
22 | "B905", # `zip()` without an explicit `strict=` parameter
23 | ]
24 |
25 | [lint.per-file-ignores]
26 | "__init__.py" = ["F401"] # unused imports in __init__ files
27 | "tests/*" = ["S101"] # use of assert in tests
28 |
29 | [lint.isort]
30 | known-first-party = ["its_hub"]
31 |
32 | [format]
33 | quote-style = "double"
34 | indent-style = "space"
35 | skip-magic-trailing-comma = false
36 | line-ending = "auto"
--------------------------------------------------------------------------------
/.github/workflows/tests.yaml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | code-quality:
11 | runs-on: ubuntu-latest
12 | continue-on-error: true
13 | steps:
14 | - uses: actions/checkout@v4
15 |
16 | - name: Install uv
17 | uses: astral-sh/setup-uv@v4
18 | with:
19 | version: "latest"
20 |
21 | - name: Set up Python
22 | run: uv python install 3.11
23 |
24 | - name: Install dependencies
25 | run: uv sync --extra dev
26 |
27 | - name: Run linting checks
28 | run: uv run ruff check its_hub/
29 |
30 | - name: Run formatting checks
31 | run: uv run ruff format --check its_hub/
32 |
33 | test:
34 | runs-on: ubuntu-latest
35 | strategy:
36 | matrix:
37 | python-version: ["3.10", "3.11", "3.12"]
38 |
39 | steps:
40 | - uses: actions/checkout@v4
41 |
42 | - name: Install uv
43 | uses: astral-sh/setup-uv@v4
44 | with:
45 | version: "latest"
46 |
47 | - name: Set up Python
48 | run: uv python install ${{ matrix.python-version }}
49 |
50 | - name: Install dependencies
51 | run: uv sync --extra dev
52 |
53 | - name: Run tests
54 | run: uv run pytest tests/ --cov=its_hub --cov-report=xml
55 |
56 | - name: Upload coverage to Codecov
57 | uses: codecov/codecov-action@v4
58 | with:
59 | token: ${{ secrets.CODECOV_TOKEN }}
60 | file: ./coverage.xml
61 | fail_ci_if_error: false
62 |
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Claude Code Sandbox",
3 | "build": {
4 | "dockerfile": "Dockerfile",
5 | "args": {
6 | "TZ": "${localEnv:TZ:America/New_York}"
7 | }
8 | },
9 | "runArgs": [
10 | "--cap-add=NET_ADMIN",
11 | "--cap-add=NET_RAW"
12 | ],
13 | "customizations": {
14 | "vscode": {
15 | "extensions": [
16 | "dbaeumer.vscode-eslint",
17 | "esbenp.prettier-vscode",
18 | "eamodio.gitlens",
19 | "ms-python.python"
20 | ],
21 | "settings": {
22 | "editor.formatOnSave": true,
23 | "editor.defaultFormatter": "esbenp.prettier-vscode",
24 | "editor.codeActionsOnSave": {
25 | "source.fixAll.eslint": "explicit"
26 | },
27 | "terminal.integrated.defaultProfile.linux": "zsh",
28 | "terminal.integrated.profiles.linux": {
29 | "bash": {
30 | "path": "bash",
31 | "icon": "terminal-bash"
32 | },
33 | "zsh": {
34 | "path": "zsh"
35 | }
36 | }
37 | }
38 | }
39 | },
40 | "remoteUser": "node",
41 | "mounts": [
42 | "source=claude-code-bashhistory,target=/commandhistory,type=volume",
43 | "source=claude-code-config,target=/home/node/.claude,type=volume"
44 | ],
45 | "remoteEnv": {
46 | "NODE_OPTIONS": "--max-old-space-size=4096",
47 | "CLAUDE_CONFIG_DIR": "/home/node/.claude",
48 | "POWERLEVEL9K_DISABLE_GITSTATUS": "true"
49 | },
50 | "workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
51 | "workspaceFolder": "/workspace",
52 | "postCreateCommand": "sudo /usr/local/bin/init-firewall.sh"
53 | }
54 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
[HTML markup stripped during extraction; only the page title "its_hub - Inference-Time Scaling for LLMs" is recoverable.]
--------------------------------------------------------------------------------
/.github/workflows/sync-notebooks.yaml:
--------------------------------------------------------------------------------
1 | name: Sync Notebooks
2 |
3 | on:
4 | pull_request:
5 | branches: [ main ]
6 | paths: [ 'notebooks/**/*.py', 'notebooks/**/*.ipynb' ]
7 |
8 | jobs:
9 | sync-notebooks:
10 | runs-on: ubuntu-latest
11 | permissions:
12 | contents: write
13 |
14 | steps:
15 | - name: Checkout repository
16 | uses: actions/checkout@v4
17 | with:
18 | fetch-depth: 0
19 | token: ${{ secrets.GITHUB_TOKEN }}
20 |
21 | - name: Set up Python
22 | uses: actions/setup-python@v4
23 | with:
24 | python-version: '3.11'
25 |
26 | - name: Install dependencies
27 | run: |
28 | pip install jupytext jupyter
29 |
30 | - name: Sync Python files to notebooks
31 | run: |
32 | # Convert .py files to .ipynb
33 | find . -name "*.py" -path "*/notebooks/*" -exec jupytext --to ipynb {} \;
34 |
35 | - name: Sync notebooks to Python files
36 | run: |
37 | # Convert .ipynb files to .py (in case someone committed a notebook directly)
38 | find . -name "*.ipynb" -path "*/notebooks/*" -exec jupytext --to py:percent {} \;
39 |
40 | - name: Check for changes
41 | id: verify-changed-files
42 | run: |
43 | if [ -n "$(git status --porcelain)" ]; then
44 | echo "changed=true" >> $GITHUB_OUTPUT
45 | else
46 | echo "changed=false" >> $GITHUB_OUTPUT
47 | fi
48 |
49 | - name: Commit and push changes to PR branch
50 | if: steps.verify-changed-files.outputs.changed == 'true'
51 | run: |
52 | git config --local user.email "action@github.com"
53 | git config --local user.name "GitHub Action"
54 | git add .
55 | git commit -m "Auto-sync notebooks and Python files" || exit 0
56 | git push origin HEAD:${{ github.head_ref }}
--------------------------------------------------------------------------------
/its_hub/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | from .beam_search import BeamSearch, BeamSearchResult
2 | from .bon import BestOfN, BestOfNResult
3 | from .particle_gibbs import (
4 | EntropicParticleFiltering,
5 | ParticleFiltering,
6 | ParticleFilteringResult,
7 | ParticleGibbs,
8 | ParticleGibbsResult,
9 | )
10 | from .self_consistency import SelfConsistency, SelfConsistencyResult
11 |
12 | __all__ = [
13 | "BeamSearch",
14 | "BeamSearchResult",
15 | "BestOfN",
16 | "BestOfNResult",
17 | "EntropicParticleFiltering",
18 | "MetropolisHastings",
19 | "MetropolisHastingsResult",
20 | "ParticleFiltering",
21 | "ParticleFilteringResult",
22 | "ParticleGibbs",
23 | "ParticleGibbsResult",
24 | "SelfConsistency",
25 | "SelfConsistencyResult",
26 | ]
27 |
28 | ###
29 |
30 | from typing import Union
31 |
32 | from its_hub.base import (
33 | AbstractLanguageModel,
34 | AbstractOutcomeRewardModel,
35 | AbstractScalingAlgorithm,
36 | AbstractScalingResult,
37 | )
38 | from its_hub.lms import StepGeneration
39 | from its_hub.types import ChatMessage, ChatMessages
40 |
41 |
42 | class MetropolisHastingsResult(AbstractScalingResult):
43 | pass
44 |
45 |
46 | class MetropolisHastings(AbstractScalingAlgorithm):
47 | def __init__(
48 | self, step_generation: StepGeneration, orm: AbstractOutcomeRewardModel
49 | ):
50 | self.step_generation = step_generation
51 | self.orm = orm
52 |
53 | def infer(
54 | self,
55 | lm: AbstractLanguageModel,
56 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
57 | budget: int,
58 | show_progress: bool = False,
59 | return_response_only: bool = True,
60 | ) -> str | MetropolisHastingsResult:
61 | # TODO: Implement Metropolis-Hastings algorithm
62 | # Will need to convert prompt_or_messages to ChatMessages format when implemented
63 | raise NotImplementedError("Metropolis-Hastings algorithm not yet implemented")
64 |
--------------------------------------------------------------------------------
/scripts/test_math_example.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Example script demonstrating the use of its_hub for math problem solving.
4 | This script tests the Qwen math model with various mathematical problems
5 | using particle filtering for improved solution quality.
6 | """
7 |
8 | import os
9 |
10 | from its_hub.algorithms import ParticleFiltering
11 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel
12 | from its_hub.lms import OpenAICompatibleLanguageModel, StepGeneration
13 | from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT
14 |
15 |
16 | def main():
17 | # Get GPU ID from environment variable or default to 0
18 | gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
19 |
20 | # Initialize the language model
21 | # Note: The endpoint port (8100) must match the port used when starting the vLLM server
22 | lm = OpenAICompatibleLanguageModel(
23 | endpoint="http://localhost:8100/v1", # Make sure this matches your vLLM server port
24 | api_key="NO_API_KEY",
25 | model_name="Qwen/Qwen2.5-Math-1.5B-Instruct",
26 | system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,
27 | )
28 |
29 | # Test prompts
30 | test_prompts = [
31 | "What is 2+2? Show your steps.",
32 | "Solve the quadratic equation x^2 + 5x + 6 = 0. Show your steps.",
33 | "Find the derivative of f(x) = x^2 + 3x + 2. Show your steps.",
34 | "Let a be a positive real number such that all the roots of x^3 + ax^2 + ax + 1 = 0 are real. Find the smallest possible value of a.",
35 | ]
36 |
37 | # Initialize step generation and reward model
38 | sg = StepGeneration(step_token="\n\n", max_steps=32, stop_token=r"\boxed")
39 | prm = LocalVllmProcessRewardModel(
40 | model_name="Qwen/Qwen2.5-Math-PRM-7B",
41 | device=f"cuda:{gpu_id}", # Use the same GPU as the vLLM server
42 | aggregation_method="prod",
43 | )
44 | scaling_alg = ParticleFiltering(sg, prm)
45 |
46 | # Run tests
47 | print("Testing Qwen Math Model with different approaches...")
48 | print(f"Using GPU {gpu_id} with memory optimization settings\n")
49 |
50 | for prompt in test_prompts:
51 | print(f"\nTesting: {prompt}")
52 | print("Response:", scaling_alg.infer(lm, prompt, budget=8))
53 |
54 |
55 | if __name__ == "__main__":
56 | main()
57 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # its-hub
2 |
3 | [![Tests](https://github.com/Red-Hat-AI-Innovation-Team/its_hub/actions/workflows/tests.yaml/badge.svg)](https://github.com/Red-Hat-AI-Innovation-Team/its_hub/actions/workflows/tests.yaml)
4 | [![Coverage](https://codecov.io/gh/Red-Hat-AI-Innovation-Team/its_hub/graph/badge.svg)](https://codecov.io/gh/Red-Hat-AI-Innovation-Team/its_hub)
5 |
6 | **its-hub** provides inference-time scaling for LLMs through multiple approaches:
7 |
8 | 1. **Direct Library Usage** - For Python integration
9 | 2. **Inference-as-a-Service (IaaS) API** - OpenAI-compatible HTTP API (⚠️ Alpha)
10 |
11 | ## What is Inference-Time Scaling?
12 |
13 | Inference-time scaling improves LLM performance by using computational resources during inference to generate better responses. Unlike training-time scaling, which requires more parameters or training data, inference-time scaling algorithms can improve any pre-trained model's performance by:
14 |
15 | - **Generating multiple candidate responses** and selecting the best one
16 | - **Using step-by-step reasoning** with reward models to guide generation
17 | - **Applying probabilistic methods** like particle filtering for better exploration
18 |
19 | ## Key Features
20 |
21 | - 🔬 **Multiple Algorithms**: Particle Filtering, Best-of-N, Beam Search, Self-Consistency
22 | - 🚀 **OpenAI-Compatible API**: Easy integration with existing applications
23 | - 🧮 **Math-Optimized**: Built for mathematical reasoning with specialized prompts and evaluation
24 | - 📊 **Benchmarking Tools**: Compare algorithms on standard datasets like MATH500 and AIME-2024
25 | - ⚡ **Async Support**: Concurrent generation with limits and error handling
26 |
27 | ## Supported Algorithms
28 |
29 | | Algorithm | Budget Interpretation | Snippet |
30 | |-----------|----------------------|---------|
31 | | **Self-Consistency** | Number of parallel generations | `SelfConsistency()` |
32 | | **Best-of-N** | Number of candidates to generate | `BestOfN(rm)` |
33 | | **Beam Search** | Total generations ÷ beam width | `BeamSearch(sg, prm, beam_width=4)` |
34 | | **Particle Filtering** | Number of particles to maintain | `ParticleFiltering(sg, prm)` |
35 | | **Planning Enhancement** | Enhances any algorithm with planning | `PlanningWrapper(base_algorithm)` |
36 |
37 | ### Planning Enhancement
38 |
39 | The **PlanningWrapper** can enhance any ITS algorithm with a planning phase that generates multiple solution approaches before execution. See [PLANNING_WRAPPER.md](PLANNING_WRAPPER.md) for detailed documentation.
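40 | 
41 | ### Quick Example
42 | 
43 | All of the snippets above plug into the same `infer(lm, prompt, budget)` call. The example below is an illustrative sketch only (the endpoint, API key, and model name are placeholders; see the top-level README for complete Best-of-N and Particle Filtering examples):
44 | 
45 | ```python
46 | from its_hub.lms import OpenAICompatibleLanguageModel
47 | from its_hub.algorithms import SelfConsistency
48 | 
49 | # Any OpenAI-compatible endpoint works; these values are placeholders.
50 | lm = OpenAICompatibleLanguageModel(
51 |     endpoint="https://api.openai.com/v1",
52 |     api_key="your-api-key",
53 |     model_name="gpt-4o-mini",
54 | )
55 | 
56 | # For Self-Consistency, budget = number of parallel generations (see table above).
57 | scaling_alg = SelfConsistency()
58 | print(scaling_alg.infer(lm, "What is 15% of 80?", budget=8))
59 | ```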
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | [build-system]
4 | requires = ["setuptools>=42", "wheel", "setuptools_scm>=8.0.0"]
5 | build-backend = "setuptools.build_meta"
6 |
7 | [project]
8 | name = "its_hub"
9 | description = "A Python library for inference-time scaling LLMs"
10 | authors = [
11 | {name = "Kai Xu and the Red Hat AI Innovation Team", email = "xuk@redhat.com"}
12 | ]
13 | readme = "README.md"
14 | requires-python = ">=3.10"
15 | license = "Apache-2.0"
16 | classifiers = [
17 | "Programming Language :: Python :: 3",
18 | "Programming Language :: Python :: 3.11",
19 | "Programming Language :: Python :: 3.12",
20 | "Operating System :: OS Independent",
21 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
22 | ]
23 | dynamic = ["version"]
24 | dependencies = [
25 | "openai>=1.68.2",
26 | "tqdm>=4.65.0",
27 | "typing-extensions>=4.12.2",
28 | "reward-hub>=0.1.8",
29 | "transformers>=4.53.2",
30 | "backoff>=2.2.0",
31 | "click>=8.1.0",
32 | "fastapi>=0.115.5",
33 | "pydantic>=2.7.2",
34 | "numpy",
35 | "uvicorn",
36 | "requests",
37 | "aiohttp>=3.9.0",
38 | "litellm>=1.70.0,<1.75.0"
39 | ]
40 |
41 | [project.scripts]
42 | its-iaas = "its_hub.integration.iaas:main"
43 |
44 | [tool.setuptools_scm]
45 | version_file = "its_hub/_version.py"
46 | # do not include +gREV local version, required for Test PyPI upload
47 | local_scheme = "no-local-version"
48 |
49 | [project.urls]
50 | Homepage = "https://ai-innovation.team/its_hub"
51 |
52 | [project.optional-dependencies]
53 |
54 | vllm = [
55 | "reward-hub[vllm]>=0.1.8",
56 | ]
57 |
58 | dev = [
59 | "pytest>=7.0.0",
60 | "pytest-asyncio>=0.21.0",
61 | "pytest-cov>=4.1.0",
62 | "ruff>=0.10.0",
63 | "jupytext>=1.15.0",
64 | "jupyter>=1.0.0",
65 | "reward-hub[vllm]>=0.1.8",
66 | ]
67 |
68 | prm = [
69 | "reward-hub[prm]>=0.1.8"
70 | ]
71 |
72 | research = [
73 | "math-verify>=0.1.0", # For mathematical reasoning evaluation in benchmark scripts
74 | "datasets>=2.0.0", # For loading benchmark datasets (MATH500, AIME)
75 | "matplotlib>=3.5.0", # For visualization scripts
76 | ]
77 | cloud = [
78 | "boto3>=1.28.0", # For AWS Bedrock support
79 | "google-cloud-aiplatform>=1.38.0", # For Vertex AI support
80 | ]
81 |
82 | [tool.setuptools]
83 | package-dir = {"" = "."}
84 |
85 | [tool.setuptools.packages.find]
86 | where = ["."]
87 | include = [
88 | "its_hub",
89 | "its_hub.algorithms",
90 | "its_hub.integration",
91 | ]
92 |
93 |
--------------------------------------------------------------------------------
/.devcontainer/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM node:20
2 |
3 | ARG TZ
4 | ENV TZ="$TZ"
5 |
6 | # Install basic development tools and iptables/ipset
7 | RUN apt update && apt install -y less \
8 | git \
9 | procps \
10 | sudo \
11 | fzf \
12 | zsh \
13 | man-db \
14 | unzip \
15 | gnupg2 \
16 | gh \
17 | iptables \
18 | ipset \
19 | iproute2 \
20 | dnsutils \
21 | aggregate \
22 | jq \
23 | curl
24 |
25 | # Ensure default node user has access to /usr/local/share
26 | RUN mkdir -p /usr/local/share/npm-global && \
27 | chown -R node:node /usr/local/share
28 |
29 | ARG USERNAME=node
30 |
31 | # Persist bash history.
32 | RUN SNIPPET="export PROMPT_COMMAND='history -a' && export HISTFILE=/commandhistory/.bash_history" \
33 | && mkdir /commandhistory \
34 | && touch /commandhistory/.bash_history \
35 | && chown -R $USERNAME /commandhistory
36 |
37 | # Set `DEVCONTAINER` environment variable to help with orientation
38 | ENV DEVCONTAINER=true
39 |
40 | # Create workspace and config directories and set permissions
41 | RUN mkdir -p /workspace /home/node/.claude && \
42 | chown -R node:node /workspace /home/node/.claude
43 |
44 | WORKDIR /workspace
45 |
46 | RUN ARCH=$(dpkg --print-architecture) && \
47 | wget "https://github.com/dandavison/delta/releases/download/0.18.2/git-delta_0.18.2_${ARCH}.deb" && \
48 | sudo dpkg -i "git-delta_0.18.2_${ARCH}.deb" && \
49 | rm "git-delta_0.18.2_${ARCH}.deb"
50 |
51 | # Set up non-root user
52 | USER node
53 |
54 | # Install global packages
55 | ENV NPM_CONFIG_PREFIX=/usr/local/share/npm-global
56 | ENV PATH=$PATH:/usr/local/share/npm-global/bin
57 |
58 | # Set the default shell to zsh rather than sh
59 | ENV SHELL=/bin/zsh
60 |
61 | # Default powerline10k theme
62 | RUN sh -c "$(wget -O- https://github.com/deluan/zsh-in-docker/releases/download/v1.2.0/zsh-in-docker.sh)" -- \
63 | -p git \
64 | -p fzf \
65 | -a "source /usr/share/doc/fzf/examples/key-bindings.zsh" \
66 | -a "source /usr/share/doc/fzf/examples/completion.zsh" \
67 | -a "export PROMPT_COMMAND='history -a' && export HISTFILE=/commandhistory/.bash_history" \
68 | -x
69 |
70 | # Install uv for Python package management
71 | RUN curl -LsSf https://astral.sh/uv/install.sh | sh
72 | ENV PATH="/home/node/.cargo/bin:$PATH"
73 |
74 | # Install Claude
75 | RUN npm install -g @anthropic-ai/claude-code
76 |
77 | # Copy and set up firewall script
78 | COPY init-firewall.sh /usr/local/bin/
79 | USER root
80 | RUN chmod +x /usr/local/bin/init-firewall.sh && \
81 | echo "node ALL=(root) NOPASSWD: /usr/local/bin/init-firewall.sh" > /etc/sudoers.d/node-firewall && \
82 | chmod 0440 /etc/sudoers.d/node-firewall
83 | USER node
84 |
--------------------------------------------------------------------------------
/its_hub/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | # the system prompt for step-by-step reasoning taken from https://github.com/huggingface/search-and-learn
4 | SAL_STEP_BY_STEP_SYSTEM_PROMPT = "Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n## Step 1: [Concise description]\n[Brief explanation and calculations]\n\n## Step 2: [Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem."
5 |
6 | QWEN_SYSTEM_PROMPT = (
7 | "Please reason step by step, and put your final answer within \\boxed{}."
8 | )
9 |
10 |
11 | def extract_content_from_lm_response(message: dict) -> str:
12 | """
13 | Extract content from a single LM response message object.
14 |
15 | Args:
16 | message: A message dict returned by fetch_single_response.
17 |
18 | Returns:
19 | The content string. If the message contains tool calls, returns the content
20 | if available, otherwise returns an empty string.
21 | """
22 | # TODO: This conversion to text is not ideal as it involves manually formatting
23 | # tool calls and neglects images in multi-modal content. Consider refactoring
24 | # to work with structured message objects instead of flattening to strings.
25 |
26 | # Extract text content (handle both string and list[dict] formats)
27 | raw_content = message.get("content")
28 |
29 | if raw_content is None:
30 | content = ""
31 | elif isinstance(raw_content, str):
32 | content = raw_content
33 | elif isinstance(raw_content, list):
34 | # Multi-modal content: extract text parts (images are ignored)
35 | text_parts = [
36 | item.get("text", "")
37 | for item in raw_content
38 | if isinstance(item, dict) and item.get("type") == "text"
39 | ]
40 | content = " ".join(text_parts)
41 | else:
42 | raise ValueError(
43 | f"Invalid content type: {type(raw_content)}, expected str, list[dict], or None"
44 | )
45 |
46 | # If there are tool calls, add tool-calls to the content
47 | if message.get("tool_calls"):
48 | tool_calls = message.get("tool_calls", [])
49 | tool_descriptions = []
50 | for tc in tool_calls:
51 | if isinstance(tc, dict) and "function" in tc:
52 | func = tc["function"]
53 | func_name = func.get("name", "unknown")
54 | tool_descriptions.append(
55 | f"[Tool call: {func_name} Tool args: {json.dumps(func.get('arguments', {}))}]"
56 | )
57 | else:
58 | raise ValueError(
59 | f"Invalid tool call: {tc}, expected a dict with a 'function' key"
60 | )
61 | content += " ".join(tool_descriptions)
62 |
63 | return content
64 |
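65 | 
66 | # Illustrative usage (a sketch, not part of the library; shows how the helper
67 | # above flattens a message dict that carries tool calls):
68 | #
69 | #     message = {
70 | #         "content": "Let me check that.",
71 | #         "tool_calls": [
72 | #             {"function": {"name": "calculator", "arguments": {"expr": "2+2"}}}
73 | #         ],
74 | #     }
75 | #     extract_content_from_lm_response(message)
76 | #     # -> 'Let me check that.[Tool call: calculator Tool args: {"expr": "2+2"}]'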
--------------------------------------------------------------------------------
/its_hub/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from .types import ChatMessage, ChatMessages
4 |
5 |
6 | class AbstractLanguageModel(ABC):
7 | """abstract base class for (autoregressive) language models"""
8 |
9 | @abstractmethod
10 | async def agenerate(
11 | self,
12 | messages: list[ChatMessage] | list[list[ChatMessage]],
13 | stop: str | None = None,
14 | ) -> str | list[str]:
15 | """generate a response from the model asynchronously"""
16 | pass
17 |
18 | @abstractmethod
19 | def generate(
20 | self,
21 | messages: list[ChatMessage] | list[list[ChatMessage]],
22 | stop: str | None = None,
23 | ) -> str | list[str]:
24 | """generate a response from the model synchronously"""
25 | pass
26 |
27 | def evaluate(self, prompt: str, generation: str) -> list[float]:
28 | """evaluate the likelihoods of the generation synchronously"""
29 | raise NotImplementedError("evaluate method not implemented")
30 |
31 |
32 | class AbstractScalingResult(ABC):
33 | """abstract base class for scaling result"""
34 |
35 | @property
36 | @abstractmethod
37 | def the_one(self) -> str:
38 | """the selected response"""
39 | pass
40 |
41 |
42 | class AbstractScalingAlgorithm(ABC):
43 | """abstract base class for inference-time scaling algorithms"""
44 |
45 | @abstractmethod
46 | async def ainfer(
47 | self,
48 | lm: AbstractLanguageModel,
49 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
50 | budget: int,
51 | return_response_only: bool = True,
52 | tools: list[dict] | None = None,
53 | tool_choice: str | dict | None = None,
54 | ) -> str | AbstractScalingResult:
55 | """
56 | Run inference asynchronously with the given language model and prompt.
57 |
58 | This is the primary method that subclasses must implement.
59 | """
60 | pass
61 |
62 | def infer(
63 | self,
64 | lm: AbstractLanguageModel,
65 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
66 | budget: int,
67 | return_response_only: bool = True,
68 | tools: list[dict] | None = None,
69 | tool_choice: str | dict | None = None,
70 | ) -> str | AbstractScalingResult:
71 | """
72 | Run inference synchronously with the given language model and prompt.
73 |
74 | Default implementation wraps ainfer() using asyncio.run().
75 | """
76 | import asyncio
77 |
78 | return asyncio.run(
79 | self.ainfer(
80 | lm, prompt_or_messages, budget, return_response_only, tools, tool_choice
81 | )
82 | )
83 |
84 |
85 | class AbstractOutcomeRewardModel(ABC):
86 | """abstract base class for outcome reward models"""
87 |
88 | @abstractmethod
89 | async def ascore(
90 | self, prompt_or_messages: str | list[ChatMessage] | ChatMessages, response: str
91 | ) -> float:
92 | """score a response asynchronously"""
93 | pass
94 |
95 | @abstractmethod
96 | def score(
97 | self, prompt_or_messages: str | list[ChatMessage] | ChatMessages, response: str
98 | ) -> float:
99 | """score a response synchronously"""
100 | pass
101 |
102 |
103 | # TODO(GX) deal with aggregation of PRM scores somehow in a common place, e.g. here
104 | class AbstractProcessRewardModel(ABC):
105 | """abstract base class for process reward models"""
106 |
107 | @abstractmethod
108 | async def ascore(
109 | self,
110 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
111 | steps: list[str],
112 | ) -> list[float]:
113 | """score steps asynchronously"""
114 | pass
115 |
116 | @abstractmethod
117 | def score(
118 | self,
119 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
120 | steps: list[str],
121 | ) -> list[float]:
122 | """score steps synchronously"""
123 | pass
124 |
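125 | 
126 | # Illustrative sketch (not part of the library): a minimal concrete algorithm.
127 | # Subclasses only need to implement `ainfer`; the base class's `infer` wraps it
128 | # with asyncio.run(). The class below simply returns the first of `budget`
129 | # candidates, ignoring tools, purely to show the contract.
130 | #
131 | #     class FirstCandidate(AbstractScalingAlgorithm):
132 | #         async def ainfer(self, lm, prompt_or_messages, budget,
133 | #                          return_response_only=True, tools=None, tool_choice=None):
134 | #             chat_messages = ChatMessages.from_prompt_or_messages(prompt_or_messages)
135 | #             responses = await lm.agenerate(chat_messages.to_batch(budget))
136 | #             return responses[0]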
--------------------------------------------------------------------------------
/its_hub/algorithms/bon.py:
--------------------------------------------------------------------------------
1 | from pydantic.dataclasses import dataclass
2 |
3 | from its_hub.base import (
4 | AbstractLanguageModel,
5 | AbstractOutcomeRewardModel,
6 | AbstractScalingAlgorithm,
7 | AbstractScalingResult,
8 | )
9 | from its_hub.types import ChatMessage, ChatMessages
10 | from its_hub.utils import extract_content_from_lm_response
11 |
12 |
13 | def _dedupe_with_inverse(seq: list[str]) -> tuple[list[str], list[int]]:
14 | """
15 | Deduplicate a sequence while preserving order and tracking original indices.
16 |
17 | Returns (uniques, inverse_idx) where:
18 | - uniques: list of unique items in order of first appearance
19 | - inverse_idx: for each item in seq, its index in the uniques list
20 |
21 | Example:
22 | seq = ["a", "b", "a", "c", "b"]
23 | returns (["a", "b", "c"], [0, 1, 0, 2, 1])
24 | """
25 | uniques: list[str] = []
26 | index_of: dict[str, int] = {}
27 | inverse_idx: list[int] = []
28 |
29 | for item in seq:
30 | j = index_of.get(item)
31 | if j is None:
32 | j = len(uniques)
33 | index_of[item] = j
34 | uniques.append(item)
35 | inverse_idx.append(j)
36 |
37 | return uniques, inverse_idx
38 |
39 |
40 | @dataclass
41 | class BestOfNResult(AbstractScalingResult):
42 | responses: list[dict] # Keep original message format with tool calls
43 | scores: list[float]
44 | selected_index: int
45 |
46 | @property
47 | def the_one(self) -> dict:
48 | return self.responses[self.selected_index]
49 |
50 |
51 | class BestOfN(AbstractScalingAlgorithm):
52 | def __init__(self, orm: AbstractOutcomeRewardModel):
53 | self.orm = orm
54 |
55 | async def ainfer(
56 | self,
57 | lm: AbstractLanguageModel,
58 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
59 | budget: int,
60 | return_response_only: bool = True,
61 | tools: list[dict] | None = None,
62 | tool_choice: str | dict | None = None,
63 | ) -> dict | BestOfNResult:
64 | """run inference asynchronously with best-of-n"""
65 | chat_messages = ChatMessages.from_prompt_or_messages(prompt_or_messages)
66 |
67 | # generate responses
68 | responses = await lm.agenerate(
69 | chat_messages.to_batch(budget), tools=tools, tool_choice=tool_choice
70 | )
71 |
72 | # extract content from message dict responses
73 | response_contents = [extract_content_from_lm_response(r) for r in responses]
74 |
75 | # deduplicate responses to avoid redundant scoring
76 | unique_responses, inverse_idx = _dedupe_with_inverse(response_contents)
77 |
78 | # early return if all responses are identical - no need to score
79 | if len(unique_responses) == 1:
80 | scores = [1] * len(responses)
81 | result = BestOfNResult(
82 | responses=responses,
83 | scores=scores,
84 | selected_index=0,
85 | )
86 | return result.the_one if return_response_only else result
87 |
88 | # score only unique responses
89 | # TODO: make batched a configurable parameter or remove non-batched branch
90 | # Currently hardcoded to True, will be addressed in future PR
91 | batched = True
92 | if batched:
93 | unique_scores = await self.orm.ascore(chat_messages, unique_responses)
94 | else:
95 | unique_scores = []
96 | for r in unique_responses:
97 | unique_scores.append(await self.orm.ascore(chat_messages, r))
98 |
99 | # map scores back to original response indices
100 | scores = [unique_scores[idx] for idx in inverse_idx]
101 |
102 | # select the best response
103 | selected_index = scores.index(max(scores))
104 |
105 | # return the result - preserve original message format with tool calls
106 | result = BestOfNResult(
107 | responses=responses, # Keep original dict format with tool calls
108 | scores=scores,
109 | selected_index=selected_index,
110 | )
111 | return result.the_one if return_response_only else result
112 |
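113 | 
114 | # Illustrative usage (a sketch, not part of this module). Any language model and
115 | # AbstractOutcomeRewardModel pair works; `lm` and `orm` below are placeholders:
116 | #
117 | #     best = BestOfN(orm).infer(lm, "Explain quantum entanglement simply", budget=4)
118 | #     result = BestOfN(orm).infer(lm, "Explain quantum entanglement simply",
119 | #                                 budget=4, return_response_only=False)
120 | #     result.scores, result.selected_index  # per-response scores and the argmax index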
--------------------------------------------------------------------------------
/.devcontainer/init-firewall.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -euo pipefail # Exit on error, undefined vars, and pipeline failures
3 | IFS=$'\n\t' # Stricter word splitting
4 |
5 | # Flush existing rules and delete existing ipsets
6 | iptables -F
7 | iptables -X
8 | iptables -t nat -F
9 | iptables -t nat -X
10 | iptables -t mangle -F
11 | iptables -t mangle -X
12 | ipset destroy allowed-domains 2>/dev/null || true
13 |
14 | # First allow DNS and localhost before any restrictions
15 | # Allow outbound DNS
16 | iptables -A OUTPUT -p udp --dport 53 -j ACCEPT
17 | # Allow inbound DNS responses
18 | iptables -A INPUT -p udp --sport 53 -j ACCEPT
19 | # Allow outbound SSH
20 | iptables -A OUTPUT -p tcp --dport 22 -j ACCEPT
21 | # Allow inbound SSH responses
22 | iptables -A INPUT -p tcp --sport 22 -m state --state ESTABLISHED -j ACCEPT
23 | # Allow localhost
24 | iptables -A INPUT -i lo -j ACCEPT
25 | iptables -A OUTPUT -o lo -j ACCEPT
26 |
27 | # Create ipset with CIDR support
28 | ipset create allowed-domains hash:net
29 |
30 | # Fetch GitHub meta information and aggregate + add their IP ranges
31 | echo "Fetching GitHub IP ranges..."
32 | gh_ranges=$(curl -s https://api.github.com/meta)
33 | if [ -z "$gh_ranges" ]; then
34 | echo "ERROR: Failed to fetch GitHub IP ranges"
35 | exit 1
36 | fi
37 |
38 | if ! echo "$gh_ranges" | jq -e '.web and .api and .git' >/dev/null; then
39 | echo "ERROR: GitHub API response missing required fields"
40 | exit 1
41 | fi
42 |
43 | echo "Processing GitHub IPs..."
44 | while read -r cidr; do
45 | if [[ ! "$cidr" =~ ^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}/[0-9]{1,2}$ ]]; then
46 | echo "ERROR: Invalid CIDR range from GitHub meta: $cidr"
47 | exit 1
48 | fi
49 | echo "Adding GitHub range $cidr"
50 | ipset add allowed-domains "$cidr"
51 | done < <(echo "$gh_ranges" | jq -r '(.web + .api + .git)[]' | aggregate -q)
52 |
53 | # Resolve and add other allowed domains
54 | for domain in \
55 | "registry.npmjs.org" \
56 | "api.anthropic.com" \
57 | "sentry.io" \
58 | "statsig.anthropic.com" \
59 | "statsig.com"; do
60 | echo "Resolving $domain..."
61 | ips=$(dig +short A "$domain")
62 | if [ -z "$ips" ]; then
63 | echo "ERROR: Failed to resolve $domain"
64 | exit 1
65 | fi
66 |
67 | while read -r ip; do
68 | if [[ ! "$ip" =~ ^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$ ]]; then
69 | echo "ERROR: Invalid IP from DNS for $domain: $ip"
70 | exit 1
71 | fi
72 | echo "Adding $ip for $domain"
73 | ipset add allowed-domains "$ip"
74 | done < <(echo "$ips")
75 | done
76 |
77 | # Get host IP from default route
78 | HOST_IP=$(ip route | grep default | cut -d" " -f3)
79 | if [ -z "$HOST_IP" ]; then
80 | echo "ERROR: Failed to detect host IP"
81 | exit 1
82 | fi
83 |
84 | HOST_NETWORK=$(echo "$HOST_IP" | sed "s/\.[0-9]*$/.0\/24/")
85 | echo "Host network detected as: $HOST_NETWORK"
86 |
87 | # Set up remaining iptables rules
88 | iptables -A INPUT -s "$HOST_NETWORK" -j ACCEPT
89 | iptables -A OUTPUT -d "$HOST_NETWORK" -j ACCEPT
90 |
91 | # Set default policies to DROP first
92 | iptables -P INPUT DROP
93 | iptables -P FORWARD DROP
94 | iptables -P OUTPUT DROP
95 |
96 | # First allow established connections for already approved traffic
97 | iptables -A INPUT -m state --state ESTABLISHED,RELATED -j ACCEPT
98 | iptables -A OUTPUT -m state --state ESTABLISHED,RELATED -j ACCEPT
99 |
100 | # Then allow only specific outbound traffic to allowed domains
101 | iptables -A OUTPUT -m set --match-set allowed-domains dst -j ACCEPT
102 |
103 | echo "Firewall configuration complete"
104 | echo "Verifying firewall rules..."
105 | if curl --connect-timeout 5 https://example.com >/dev/null 2>&1; then
106 | echo "ERROR: Firewall verification failed - was able to reach https://example.com"
107 | exit 1
108 | else
109 | echo "Firewall verification passed - unable to reach https://example.com as expected"
110 | fi
111 |
112 | # Verify GitHub API access
113 | if ! curl --connect-timeout 5 https://api.github.com/zen >/dev/null 2>&1; then
114 | echo "ERROR: Firewall verification failed - unable to reach https://api.github.com"
115 | exit 1
116 | else
117 | echo "Firewall verification passed - able to reach https://api.github.com as expected"
118 | fi
119 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `its-hub`: A Python library for inference-time scaling
2 |
3 | [![Tests](https://github.com/Red-Hat-AI-Innovation-Team/its_hub/actions/workflows/tests.yaml/badge.svg)](https://github.com/Red-Hat-AI-Innovation-Team/its_hub/actions/workflows/tests.yaml)
4 | [![Coverage](https://codecov.io/gh/Red-Hat-AI-Innovation-Team/its_hub/graph/badge.svg)](https://codecov.io/gh/Red-Hat-AI-Innovation-Team/its_hub)
5 | [![PyPI version](https://badge.fury.io/py/its-hub.svg)](https://badge.fury.io/py/its-hub)
6 |
7 | **its_hub** is a Python library for inference-time scaling of LLMs, focusing on mathematical reasoning tasks.
8 |
9 | ## 📚 Documentation
10 |
11 | For comprehensive documentation, including installation guides, tutorials, and API reference, visit:
12 |
13 | **[https://ai-innovation.team/its_hub](https://ai-innovation.team/its_hub)**
14 |
15 | ## Installation
16 |
17 | Choose the installation option based on which algorithms you need:
18 |
19 | ```bash
20 | # Core installation - includes:
21 | # - Best-of-N with LLM Judge
22 | # - Self-Consistency
23 | # - OpenAI-compatible language models
24 | pip install its_hub
25 |
26 | # Process Reward Model installation - adds:
27 | # - Particle Filtering Algorithms
28 | # - Beam Search
29 | # - LocalVllmProcessRewardModel
30 | # - Required for step-by-step reasoning with process reward models
31 | pip install its_hub[prm]
32 |
33 | # Development installation
34 | git clone https://github.com/Red-Hat-AI-Innovation-Team/its_hub.git
35 | cd its_hub
36 | pip install -e ".[dev]"
37 | ```
38 |
39 | ## Quick Start
40 |
41 | ### Example 1: Best-of-N with LLM Judge
42 |
43 | **Installation required:** `pip install its_hub` (core)
44 |
45 | Use Best-of-N algorithm with an LLM judge for response selection - works with any OpenAI-compatible API:
46 |
47 | ```python
48 | from its_hub.lms import OpenAICompatibleLanguageModel
49 | from its_hub.algorithms import BestOfN
50 | from its_hub.integration.reward_hub import LLMJudgeRewardModel
51 |
52 | # Initialize language model
53 | lm = OpenAICompatibleLanguageModel(
54 | endpoint="https://api.openai.com/v1",
55 | api_key="your-api-key",
56 | model_name="gpt-4o-mini",
57 | )
58 |
59 | # Set up LLM judge for scoring
60 | judge = LLMJudgeRewardModel(
61 | model="gpt-4o-mini",
62 | criterion="overall_quality",
63 | judge_type="groupwise",
64 | api_key="your-api-key",
65 | )
66 | scaling_alg = BestOfN(judge)
67 |
68 | # Generate multiple responses and select the best
69 | result = scaling_alg.infer(lm, "Explain quantum entanglement in simple terms", budget=4)
70 | print(result)
71 | ```
72 |
73 | ### Example 2: Particle Filtering with Process Reward Model
74 |
75 | **Installation required:** `pip install its_hub[prm]`
76 |
77 | Use Particle Filtering for step-by-step reasoning with process reward models:
78 |
79 | ```python
80 | from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT
81 | from its_hub.lms import OpenAICompatibleLanguageModel, StepGeneration
82 | from its_hub.algorithms import ParticleFiltering
83 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel
84 |
85 | # Initialize language model (requires vLLM server running)
86 | lm = OpenAICompatibleLanguageModel(
87 | endpoint="http://localhost:8100/v1",
88 | api_key="NO_API_KEY",
89 | model_name="Qwen/Qwen2.5-Math-1.5B-Instruct",
90 | system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,
91 | )
92 |
93 | # Set up step generation and process reward model
94 | sg = StepGeneration(step_token="\n\n", max_steps=32, stop_token=r"\boxed")
95 | prm = LocalVllmProcessRewardModel(
96 | model_name="Qwen/Qwen2.5-Math-PRM-7B",
97 | device="cuda:0",
98 | aggregation_method="prod"
99 | )
100 | scaling_alg = ParticleFiltering(sg, prm)
101 |
102 | # Solve with step-by-step reasoning
103 | result = scaling_alg.infer(lm, "Solve x^2 + 5x + 6 = 0", budget=8)
104 | print(result)
105 | ```
106 |
107 | ## Key Features
108 |
109 | - 🔬 **Multiple Algorithms**: Particle Filtering, Best-of-N, Beam Search, Self-Consistency
110 | - 🚀 **OpenAI-Compatible API**: Easy integration with existing applications
111 | - 🧮 **Math-Optimized**: Built for mathematical reasoning with specialized prompts
112 | - 📊 **Benchmarking Tools**: Compare algorithms on MATH500 and AIME-2024 datasets
113 | - ⚡ **Async Support**: Concurrent generation with limits and error handling
114 |
115 |
116 | For detailed documentation, visit: [https://ai-innovation.team/its_hub](https://ai-innovation.team/its_hub)
117 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Auto generated
2 | its_hub/_version.py
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints/
83 |
84 | # Ignore notebooks by default
85 | *.ipynb
86 | 
87 | # But keep the generated notebooks in the notebooks/ directory
88 | # (we want both .py and .ipynb versions; the negation must come after the ignore rule)
89 | !notebooks/**/*.ipynb
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | # For a library or package, you might want to ignore these files since the code is
97 | # intended to run in multiple environments; otherwise, check them in:
98 | # .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # UV
108 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
109 | # This is especially recommended for binary packages to ensure reproducibility, and is more
110 | # commonly ignored for libraries.
111 | uv.lock
112 |
113 | # poetry
114 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
115 | # This is especially recommended for binary packages to ensure reproducibility, and is more
116 | # commonly ignored for libraries.
117 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
118 | #poetry.lock
119 |
120 | # pdm
121 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
122 | #pdm.lock
123 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
124 | # in version control.
125 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
126 | .pdm.toml
127 | .pdm-python
128 | .pdm-build/
129 |
130 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
131 | __pypackages__/
132 |
133 | # Celery stuff
134 | celerybeat-schedule
135 | celerybeat.pid
136 |
137 | # SageMath parsed files
138 | *.sage.py
139 |
140 | # Environments
141 | .env
142 | .venv
143 | env/
144 | venv/
145 | ENV/
146 | env.bak/
147 | venv.bak/
148 |
149 | # Spyder project settings
150 | .spyderproject
151 | .spyproject
152 |
153 | # Rope project settings
154 | .ropeproject
155 |
156 | # mkdocs documentation
157 | /site
158 |
159 | # mypy
160 | .mypy_cache/
161 | .dmypy.json
162 | dmypy.json
163 |
164 | # Pyre type checker
165 | .pyre/
166 |
167 | # pytype static type analyzer
168 | .pytype/
169 |
170 | # Cython debug symbols
171 | cython_debug/
172 |
173 | # PyCharm
174 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
175 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
176 | # and can be added to the global gitignore or merged into this file. For a more nuclear
177 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
178 | #.idea/
179 |
180 | # Ruff stuff:
181 | .ruff_cache/
182 |
183 | # PyPI configuration file
184 | .pypirc
185 |
186 | # local dev
187 | notebooks/dev.ipynb
188 | results/
--------------------------------------------------------------------------------
/tests/mocks/reward_models.py:
--------------------------------------------------------------------------------
1 | """Mock reward models for testing."""
2 |
3 | from its_hub.base import AbstractOutcomeRewardModel
4 |
5 |
6 | class MockOutcomeRewardModel(AbstractOutcomeRewardModel):
7 | """Mock outcome reward model with configurable scores."""
8 |
9 | def __init__(self, scores: list[float] | float):
10 | if isinstance(scores, float):
11 | self.scores = [scores]
12 | else:
13 | self.scores = scores
14 | self.call_count = 0 # tracks total number of individual scores returned
15 | self.score_call_count = 0 # tracks number of times score() is called
16 |
17 | async def ascore(self, prompt: str, response: str | list[str]) -> float | list[float]:
18 | return self.score(prompt, response)
19 |
20 | def score(self, prompt: str, response: str | list[str]) -> float | list[float]:
21 | self.score_call_count += 1
22 | if isinstance(response, list):
23 | scores = []
24 | for i in range(len(response)):
25 | score_idx = (self.call_count + i) % len(self.scores)
26 | scores.append(self.scores[score_idx])
27 | self.call_count += len(response)
28 | return scores
29 | else:
30 | score = self.scores[self.call_count % len(self.scores)]
31 | self.call_count += 1
32 | return score
33 |
34 |
35 | class MockProcessRewardModel:
36 | """Mock process reward model with configurable scores."""
37 |
38 | def __init__(self, scores: list[float] | list[list[float]]):
39 | if isinstance(scores[0], list):
40 | # Flatten nested lists
41 | self.scores = [score for sublist in scores for score in sublist]
42 | else:
43 | self.scores = scores
44 | self.call_count = 0
45 |
46 | async def ascore(self, prompt: str, response: str | list[str]) -> float | list[float]:
47 | return self.score(prompt, response)
48 |
49 | def score(self, prompt: str, response: str | list[str]) -> float | list[float]:
50 | if isinstance(response, list):
51 | scores = []
52 | for i in range(len(response)):
53 | score_idx = (self.call_count + i) % len(self.scores)
54 | scores.append(self.scores[score_idx])
55 | self.call_count += len(response)
56 | return scores
57 | else:
58 | score = self.scores[self.call_count % len(self.scores)]
59 | self.call_count += 1
60 | return score
61 |
62 |
63 | class HighVarianceRewardModel:
64 | """Mock reward model with high variance for testing edge cases."""
65 |
66 | def __init__(self):
67 | self.scores = [0.0, 1.0, 0.5, 0.1, 0.9, 0.3, 0.7, 0.2, 0.8, 0.4]
68 | self.call_count = 0
69 |
70 | async def ascore(self, prompt: str, response: str | list[str]) -> float | list[float]:
71 | return self.score(prompt, response)
72 |
73 | def score(self, prompt: str, response: str | list[str]) -> float | list[float]:
74 | if isinstance(response, list):
75 | scores = []
76 | for i in range(len(response)):
77 | score_idx = (self.call_count + i) % len(self.scores)
78 | scores.append(self.scores[score_idx])
79 | self.call_count += len(response)
80 | return scores
81 | else:
82 | score = self.scores[self.call_count % len(self.scores)]
83 | self.call_count += 1
84 | return score
85 |
86 |
87 | class ErrorRewardModel:
88 | """Mock reward model that can simulate errors."""
89 |
90 | def __init__(self, scores: list[float], error_on_calls: list[int] | None = None):
91 | self.scores = scores
92 | self.error_on_calls = error_on_calls or []
93 | self.call_count = 0
94 |
95 | async def ascore(self, prompt: str, response: str | list[str]) -> float | list[float]:
96 | return self.score(prompt, response)
97 |
98 | def score(self, prompt: str, response: str | list[str]) -> float | list[float]:
99 | if self.call_count in self.error_on_calls:
100 | self.call_count += 1
101 | raise Exception("Simulated reward model error")
102 |
103 | if isinstance(response, list):
104 | scores = []
105 | for i in range(len(response)):
106 | if (self.call_count + i) in self.error_on_calls:
107 | raise Exception("Simulated reward model error in batch")
108 | score_idx = (self.call_count + i) % len(self.scores)
109 | scores.append(self.scores[score_idx])
110 | self.call_count += len(response)
111 | return scores
112 | else:
113 | score = self.scores[self.call_count % len(self.scores)]
114 | self.call_count += 1
115 | return score
116 |
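117 | 
118 | # Behavior sketch (derived from MockOutcomeRewardModel above, for reference):
119 | # scores are served cyclically, so a short list can drive long tests.
120 | #
121 | #     rm = MockOutcomeRewardModel([0.1, 0.9])
122 | #     rm.score("prompt", ["a", "b", "c"])  # -> [0.1, 0.9, 0.1]; call_count == 3
123 | #     rm.score("prompt", "d")              # -> 0.9, continuing the cycle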
--------------------------------------------------------------------------------
/tests/mocks/test_data.py:
--------------------------------------------------------------------------------
1 | """Test data factories for consistent test data generation."""
2 |
3 | from typing import Any
4 |
5 | from its_hub.types import ChatMessage, ChatMessages
6 |
7 |
8 | class TestDataFactory:
9 | """Factory for creating consistent test data."""
10 |
11 | @staticmethod
12 | def create_chat_messages(
13 | user_content: str = "Hello", system_content: str | None = None
14 | ) -> ChatMessages:
15 | """Create ChatMessages object for testing."""
16 | if system_content:
17 | messages = [
18 | ChatMessage(role="system", content=system_content),
19 | ChatMessage(role="user", content=user_content),
20 | ]
21 | return ChatMessages(messages)
22 | else:
23 | return ChatMessages(user_content) # Simple string case
24 |
25 | @staticmethod
26 | def create_chat_completion_request(
27 | model: str = "test-model",
28 | user_content: str = "Hello",
29 | budget: int = 4,
30 | system_content: str | None = None,
31 | **kwargs,
32 | ) -> dict[str, Any]:
33 | """Create a standard chat completion request."""
34 | chat_messages = TestDataFactory.create_chat_messages(
35 | user_content, system_content
36 | ).to_chat_messages()
37 | request = {
38 | "model": model,
39 | "messages": [
40 | msg.__dict__ for msg in chat_messages
41 | ], # Convert to dict for JSON serialization
42 | "budget": budget,
43 | }
44 | request.update(kwargs)
45 | return request
46 |
47 | @staticmethod
48 | def create_config_request(
49 | endpoint: str = "http://localhost:8000",
50 | api_key: str = "test-key",
51 | model: str = "test-model",
52 | alg: str = "best-of-n",
53 | **kwargs,
54 | ) -> dict[str, Any]:
55 | """Create a standard configuration request."""
56 | config = {
57 | "endpoint": endpoint,
58 | "api_key": api_key,
59 | "model": model,
60 | "alg": alg,
61 | "rm_name": "test-rm",
62 | "rm_device": "cpu",
63 | }
64 | config.update(kwargs)
65 | return config
66 |
67 | @staticmethod
68 | def create_error_trigger_request(trigger: str = "trigger_error") -> dict[str, Any]:
69 | """Create a request that triggers errors for testing."""
70 | return TestDataFactory.create_chat_completion_request(user_content=trigger)
71 |
72 | @staticmethod
73 | def create_multiple_responses(base: str = "response", count: int = 3) -> list[str]:
74 | """Create multiple response strings for testing."""
75 | return [f"{base}{i + 1}" for i in range(count)]
76 |
77 | @staticmethod
78 | def create_score_sequence(
79 | base_score: float = 0.5, count: int = 3, increment: float = 0.1
80 | ) -> list[float]:
81 | """Create a sequence of scores for testing."""
82 | return [base_score + (i * increment) for i in range(count)]
83 |
84 |
85 | # Common test scenarios
86 | TEST_SCENARIOS = {
87 | "simple_chat": {
88 | "user_content": "Hello, world!",
89 | "expected_response": {
90 | "role": "assistant",
91 | "content": "Response to: Hello, world!",
92 | },
93 | },
94 | "math_problem": {
95 | "user_content": "Solve 2+2",
96 | "expected_response": {"role": "assistant", "content": "Response to: Solve 2+2"},
97 | },
98 | "error_trigger": {"user_content": "trigger_error", "should_error": True},
99 | "vllm_error_trigger": {"user_content": "error", "should_error": True},
100 | "with_system_prompt": {
101 | "system_content": "You are a helpful assistant",
102 | "user_content": "How can I help you?",
103 | "expected_response": {
104 | "role": "assistant",
105 | "content": "Response to: How can I help you?",
106 | },
107 | },
108 | }
109 |
110 | # Algorithm test configurations
111 | ALGORITHM_CONFIGS = {
112 | "best_of_n": {
113 | "alg": "best-of-n",
114 | "requires_outcome_rm": True,
115 | "supports_batching": True,
116 | },
117 | "beam_search": {
118 | "alg": "beam-search",
119 | "requires_process_rm": True,
120 | "requires_step_generation": True,
121 | "budget_constraints": "divisible_by_beam_width",
122 | },
123 | "particle_filtering": {
124 | "alg": "particle-filtering",
125 | "requires_process_rm": True,
126 | "requires_step_generation": True,
127 | "supports_selection_methods": ["argmax", "sample"],
128 | },
129 | "particle_gibbs": {
130 | "alg": "particle-gibbs",
131 | "requires_process_rm": True,
132 | "requires_step_generation": True,
133 | "supports_selection_methods": ["argmax", "sample"],
134 | "budget_constraints": "divisible_by_iterations",
135 | },
136 | }
137 |
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | ## Prerequisites
4 |
5 | - Python 3.10+ (3.11+ recommended)
6 | - pip or uv package manager
7 | - GPU with CUDA 11.8+ (only for `[prm]` installation)
8 |
9 | ## Installation Options
10 |
11 | | Option | Command | Use Case |
12 | |--------|---------|----------|
13 | | **Core** | `pip install its_hub` | Best-of-N, Self-Consistency, cloud APIs |
14 | | **PRM** | `pip install its_hub[prm]` | Particle Filtering, Beam Search, local reward models |
15 | | **Cloud** | `pip install its_hub[cloud]` | AWS Bedrock, Google Vertex AI |
16 | | **Research** | `pip install its_hub[research]` | Benchmarks, evaluation tools |
17 | | **Dev** | `pip install -e ".[dev]"` | Contributing, testing |
18 |
19 | ---
20 |
21 | ## Core Installation
22 |
23 | ```bash
24 | pip install its_hub
25 | ```
26 |
27 | ### What's Included
28 |
29 | **Algorithms**: Best-of-N, Self-Consistency, LLM Judge
30 | **Language Models**: OpenAI-compatible, LiteLLM (100+ providers)
31 | **Key Dependencies**: `openai`, `litellm`, `reward-hub`, `transformers`, `fastapi`
32 |
33 | ### When to Use
34 |
35 | **Use if**: Working with cloud APIs (OpenAI, Anthropic, etc.), no GPU needed
36 | **Skip if**: Need Particle Filtering/Beam Search or local process reward models
37 |
38 | ### Under the Hood
39 |
40 | - **Size**: ~50MB (no vLLM or CUDA dependencies)
41 | - **Installation time**: 1-2 minutes
42 | - **GPU required**: No
43 | - **What's excluded**: vLLM, local reward model inference
44 |
45 | ```python
46 | # Verify installation
47 | from its_hub.algorithms import BestOfN, SelfConsistency
48 | from its_hub.integration.reward_hub import LLMJudgeRewardModel
49 | ```
50 |
51 | ---
52 |
53 | ## Process Reward Model (PRM) Installation
54 |
55 | ```bash
56 | pip install its_hub[prm]
57 | ```
58 |
59 | ### What's Added
60 |
61 | **Algorithms**: Particle Filtering, Beam Search (+ all core algorithms)
62 | **Reward Models**: `LocalVllmProcessRewardModel` for step-by-step scoring
63 | **Additional Dependencies**: `reward-hub[prm]` (includes vLLM with pinned versions)
64 |
65 | ### When to Use
66 |
67 | **Use if**: Need step-by-step reasoning with local reward models, have GPU
68 | **Skip if**: Only using cloud APIs or outcome-based scoring
69 |
70 | ### Under the Hood
71 |
72 | - **Size**: ~2-3GB (includes vLLM + CUDA dependencies)
73 | - **Installation time**: 5-10 minutes
74 | - **GPU required**: Yes (10-20GB VRAM for typical 7B reward models)
75 | - **Version pinning**: `reward-hub[prm]` pins compatible vLLM + transformers + PyTorch versions
76 |
77 | ```python
78 | # Verify installation
79 | from its_hub.algorithms import ParticleFiltering, BeamSearch
80 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel
81 |
82 | # Check GPU
83 | import torch
84 | print(f'CUDA available: {torch.cuda.is_available()}')
85 | ```
86 |
87 | ---
88 |
89 | ## Cloud Installation
90 |
91 | ```bash
92 | pip install its_hub[cloud]
93 | ```
94 |
95 | **Adds**: AWS Bedrock (`boto3`) and Google Vertex AI (`google-cloud-aiplatform`) SDKs
96 | **Use if**: Need direct SDK access to Bedrock or Vertex AI (most cloud providers work with core via LiteLLM)
97 |
98 | ---
99 |
100 | ## Research Installation
101 |
102 | ```bash
103 | pip install its_hub[research]
104 | ```
105 |
106 | **Adds**: `math-verify`, `datasets`, `matplotlib`
107 | **Use if**: Running benchmarks on MATH500/AIME or evaluating algorithm performance
108 | **Includes**: The benchmark script `scripts/benchmark.py`
109 |
110 | ---
111 |
112 | ## Development Installation
113 |
114 | ```bash
115 | git clone https://github.com/Red-Hat-AI-Innovation-Team/its_hub.git
116 | cd its_hub
117 |
118 | # Recommended: uv
119 | uv sync --extra dev
120 |
121 | # Alternative: pip
122 | pip install -e ".[dev]"
123 | ```
124 |
125 | **Includes**: All core + PRM + `pytest`, `ruff`, `jupyter`, notebooks
126 | **Use if**: Contributing, testing, or developing new features
127 |
128 | ```bash
129 | # Run tests
130 | uv run pytest tests/
131 | uv run pytest tests/ --cov=its_hub
132 |
133 | # Code quality
134 | uv run ruff check its_hub/ --fix
135 | uv run ruff format its_hub/
136 | ```
137 |
138 | ---
139 |
140 | ## Combining Extras
141 |
142 | ```bash
143 | pip install its_hub[prm,research] # PRM + benchmarking
144 | pip install its_hub[cloud,research] # Cloud + benchmarking
145 | pip install -e ".[dev,research,cloud]" # Everything
146 | ```
147 |
148 | ---
149 |
150 | ## Verification
151 |
152 | ```bash
153 | # Core
154 | python -c "from its_hub.algorithms import BestOfN; print('✅ Core OK')"
155 |
156 | # PRM
157 | python -c "from its_hub.algorithms import ParticleFiltering; print('✅ PRM OK')"
158 | python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}')"
159 | ```
160 |
161 | ---
162 |
163 | ## Next Steps
164 |
165 | - [Quick Start Guide](quick-start.md) - Best-of-N and Particle Filtering examples
166 | - [IaaS Service Guide](iaas-service.md) - Deploy as OpenAI-compatible API
167 | - [Development Guide](development.md) - Contributing guidelines
168 |
169 | For runtime issues (CUDA OOM, server errors, etc.), see the troubleshooting sections in the Quick Start or IaaS Service guides.
--------------------------------------------------------------------------------
/tests/mocks/language_models.py:
--------------------------------------------------------------------------------
1 | """Mock language models for testing."""
2 |
3 | from its_hub.base import AbstractLanguageModel
4 |
5 |
6 | class SimpleMockLanguageModel:
7 | """Simple mock language model for basic testing."""
8 |
9 | def __init__(self, responses: list[str]):
10 | self.responses = responses
11 | self.call_count = 0
12 |
13 | async def agenerate(self, messages, **kwargs):
14 | return self.generate(messages, **kwargs)
15 |
16 | def generate(self, messages, **kwargs):
17 | if isinstance(messages[0], list):
18 | # Multiple message lists
19 | content_responses = self.responses[
20 | self.call_count : self.call_count + len(messages)
21 | ]
22 | self.call_count += len(messages)
23 | return [
24 | {"role": "assistant", "content": content}
25 | for content in content_responses
26 | ]
27 | else:
28 | # Single message list
29 | content = self.responses[self.call_count]
30 | self.call_count += 1
31 | return {"role": "assistant", "content": content}
32 |
33 |
34 | class StepMockLanguageModel(AbstractLanguageModel):
35 | """Mock language model for step-by-step generation testing."""
36 |
37 | def __init__(self, step_responses: list[str]):
38 | self.step_responses = step_responses
39 | self.call_count = 0
40 |
41 | async def agenerate(self, messages, stop=None, max_tokens=None, temperature=None, include_stop_str_in_output=None, tools=None, tool_choice=None):
42 | return self.generate(messages, stop, max_tokens, temperature, include_stop_str_in_output, tools, tool_choice)
43 |
44 | def generate(self, messages, stop=None, max_tokens=None, temperature=None, include_stop_str_in_output=None, tools=None, tool_choice=None):
45 | if isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], list):
46 | # Batched generation
47 | num_requests = len(messages)
48 | responses = []
49 | for i in range(num_requests):
50 | response_idx = (self.call_count + i) % len(self.step_responses)
51 | content = self.step_responses[response_idx]
52 | responses.append({"role": "assistant", "content": content})
53 | self.call_count += num_requests
54 | return responses
55 | else:
56 | # Single generation
57 | content = self.step_responses[self.call_count % len(self.step_responses)]
58 | self.call_count += 1
59 | return {"role": "assistant", "content": content}
60 |
61 | async def aevaluate(self, prompt: str, generation: str) -> list[float]:
62 | return self.evaluate(prompt, generation)
63 |
64 | def evaluate(self, prompt: str, generation: str) -> list[float]:
65 | """Return mock evaluation scores."""
66 | return [0.1] * len(generation.split())
67 |
68 |
69 | class ErrorMockLanguageModel(AbstractLanguageModel):
70 | """Mock language model that can simulate errors."""
71 |
72 | def __init__(self, responses: list[str], error_on_calls: list[int] | None = None):
73 | self.responses = responses
74 | self.error_on_calls = error_on_calls or []
75 | self.call_count = 0
76 |
77 | async def agenerate(self, messages, stop=None, max_tokens=None, temperature=None, include_stop_str_in_output=None, tools=None, tool_choice=None):
78 | return self.generate(messages, stop, max_tokens, temperature, include_stop_str_in_output, tools, tool_choice)
79 |
80 | def generate(self, messages, stop=None, max_tokens=None, temperature=None, include_stop_str_in_output=None, tools=None, tool_choice=None):
81 | if self.call_count in self.error_on_calls:
82 | self.call_count += 1
83 | raise Exception("Simulated LM error")
84 |
85 | if (
86 | isinstance(messages, list)
87 | and len(messages) > 0
88 | and isinstance(messages[0], list)
89 | ):
90 | # Batched generation
91 | num_requests = len(messages)
92 | responses = []
93 | for i in range(num_requests):
94 | if (self.call_count + i) in self.error_on_calls:
95 | raise Exception("Simulated LM error in batch")
96 | response_idx = (self.call_count + i) % len(self.responses)
97 | content = self.responses[response_idx]
98 | responses.append({"role": "assistant", "content": content})
99 | self.call_count += num_requests
100 | return responses
101 | else:
102 | # Single generation
103 | content = self.responses[self.call_count % len(self.responses)]
104 | self.call_count += 1
105 | return {"role": "assistant", "content": content}
106 |
107 | async def aevaluate(self, prompt: str, generation: str) -> list[float]:
108 | return self.evaluate(prompt, generation)
109 |
110 | def evaluate(self, prompt: str, generation: str) -> list[float]:
111 | return [0.1] * len(generation.split())
112 |
--------------------------------------------------------------------------------
/notebooks/self-consistency.py:
--------------------------------------------------------------------------------
1 | # ---
2 | # jupyter:
3 | # jupytext:
4 | # cell_metadata_filter: -all
5 | # notebook_metadata_filter: all
6 | # text_representation:
7 | # extension: .py
8 | # format_name: percent
9 | # format_version: '1.3'
10 | # jupytext_version: 1.18.1
11 | # kernelspec:
12 | # display_name: inference_time_scaling-dev
13 | # language: python
14 | # name: python3
15 | # language_info:
16 | # codemirror_mode:
17 | # name: ipython
18 | # version: 3
19 | # file_extension: .py
20 | # mimetype: text/x-python
21 | # name: python
22 | # nbconvert_exporter: python
23 | # pygments_lexer: ipython3
24 | # version: 3.11.11
25 | # ---
26 |
27 | # %% [markdown]
28 | # # Self-Consistency Algorithm Demo
29 | # This notebook demonstrates the Self-Consistency algorithm for mathematical reasoning.
30 |
31 | # %%
32 | # %load_ext autoreload
33 | # %autoreload 2
34 |
35 | # %%
36 | import os
37 |
38 | import nest_asyncio
39 | from dotenv import load_dotenv
40 |
41 | from its_hub.lms import OpenAICompatibleLanguageModel
42 | from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT
43 |
44 | nest_asyncio.apply()
45 |
46 | # Load environment variables from .env file
47 | load_dotenv()
48 |
49 | # Main example: OpenAI API endpoint with gpt-4o-mini
50 | lm = OpenAICompatibleLanguageModel(
51 | endpoint="https://api.openai.com/v1",
52 | api_key=os.getenv("OPENAI_API_KEY"), # Load API key from environment
53 | model_name="gpt-4o-mini",
54 | system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,
55 | is_async=True,
56 | )
57 | # %%
58 | # Alternative: vLLM local endpoint (commented out)
59 | # lm = OpenAICompatibleLanguageModel(
60 | # endpoint="http://localhost:8000/v1",
61 | # api_key="NO_API_KEY",
62 | # model_name="qwen2-math-1.5b-instruct",
63 | # system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,
64 | # is_async=True,
65 | # )
66 |
67 | # %%
68 | # Mathematical problem to solve
69 | prompt = r"Let $a$ be a positive real number such that all the roots of \[x^3 + ax^2 + ax + 1 = 0\]are real. Find the smallest possible value of $a.$"
70 |
71 | # Generate response using the proper format
72 | from its_hub.types import ChatMessages
73 |
74 | chat_messages = ChatMessages.from_prompt_or_messages(prompt)
75 | response = lm.generate(chat_messages.to_batch(1))[0]
76 |
77 | print(response)
78 |
79 |
80 | # %%
81 | def extract_boxed(s: str) -> str:
82 | import re
83 | # find all occurrences of \boxed{...}
84 | boxed_matches = re.findall(r'\\boxed\{([^{}]+(?:\{[^{}]*\}[^{}]*)*)\}', s)
85 | # return the last match if any were found
86 | return boxed_matches[-1] if boxed_matches else ""
87 |
88 | print(extract_boxed(response['content']))
89 |
90 | # %% [markdown]
91 | # ## Self-Consistency Algorithm
92 | # Now we'll use the Self-Consistency algorithm to improve the answer quality.
93 |
94 | # %%
95 | from its_hub.algorithms import SelfConsistency
96 |
97 | # Set computational budget for scaling
98 | budget = 4
99 |
100 | scaling_alg = SelfConsistency(extract_boxed)
101 |
102 | scaling_result = scaling_alg.infer(
103 | lm, prompt, budget, return_response_only=False
104 | )
105 |
106 | print("######## Self-Consistency Result ########")
107 | print(scaling_result.the_one)
108 |
109 | # %%
110 | print("######## Extracted Response Counts ########")
111 | print(scaling_result.response_counts)
112 |
113 | # %%
114 |
115 |
116 | # %% [markdown]
117 | # ## Self-Consistency Algorithm for Tool Calls
118 | # The Self-Consistency algorithm supports hierarchical tool voting:
119 | # it first votes on tool names, then on tool arguments.
120 |
121 | # %%
122 | from its_hub.types import ChatMessage, ChatMessages
123 |
124 | # Tool schema (OpenAI-style dicts)
125 | tools = [
126 | {
127 | "type": "function",
128 | "function": {
129 | "name": "calculator",
130 | "description": "Perform arithmetic calculations",
131 | "parameters": {
132 | "type": "object",
133 | "properties": {
134 | "expression": {
135 | "type": "string",
136 | "description": "Mathematical expression to evaluate"
137 | }
138 | },
139 | "required": ["expression"]
140 | }
141 | }
142 | }
143 | ]
144 |
145 | # ChatMessages instance with system + user
146 | tool_call_messages = ChatMessages([
147 | ChatMessage(
148 | role="system",
149 | content="You are a precise calculator. Always use the calculator tool for arithmetic and format your final answer as \\boxed{result}."
150 | ),
151 | ChatMessage(
152 | role="user",
153 | content="What is 847 * 293 + 156?"
154 | ),
155 | ])
156 |
157 | # %%
158 | # Use hierarchical tool voting
159 | scaling_alg_tool = SelfConsistency(tool_vote="tool_hierarchical")
160 |
161 | budget = 5
162 | scaling_result = scaling_alg_tool.infer(
163 | lm, tool_call_messages, budget, return_response_only=False, tools=tools, tool_choice="auto"
164 | )
165 |
166 | # %%
167 | print("######## Self-Consistency Result ########")
168 | print(scaling_result.the_one)
169 |
170 | print("######## Tool Call Response Counts ########")
171 | print(scaling_result.response_counts)
172 |
173 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | name: Build, test, and upload PyPI package
4 |
5 | on:
6 | push:
7 | branches:
8 | - "main"
9 | - "release-**"
10 | tags:
11 | - "v*"
12 | pull_request:
13 | branches:
14 | - "main"
15 | - "release-**"
16 | release:
17 | types:
18 | - published
19 |
20 | env:
21 | LC_ALL: en_US.UTF-8
22 |
23 | defaults:
24 | run:
25 | shell: bash
26 |
27 | permissions:
28 | contents: read
29 |
30 | jobs:
31 | # Create and verify release artifacts
32 | # - build source dist (tar ball) and wheel
33 | # - validate artifacts with various tools
34 | # - upload artifacts to GHA
35 | build-package:
36 | name: Build and check packages
37 | runs-on: ubuntu-latest
38 | steps:
39 | - name: "Harden Runner"
40 | uses: step-security/harden-runner@c6295a65d1254861815972266d5933fd6e532bdf # v2.11.1
41 | with:
42 | egress-policy: audit # TODO: change to 'egress-policy: block' after couple of runs
43 |
44 |
45 | - name: "Checkout"
46 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
47 | with:
48 | # for setuptools-scm
49 | fetch-depth: 0
50 |
51 | - name: "Build and Inspect"
52 | uses: hynek/build-and-inspect-python-package@efb823f52190ad02594531168b7a2d5790e66516 # v2.14.0
53 |
54 | # push to Test PyPI on
55 | # - a new GitHub release is published
56 | # - a PR is merged into main branch
57 | publish-test-pypi:
58 | name: Publish packages to test.pypi.org
59 | # environment: publish-test-pypi
60 | if: ${{ (github.repository_owner == 'Red-Hat-AI-Innovation-Team') && ((github.event.action == 'published') || ((github.event_name == 'push') && (github.ref == 'refs/heads/main'))) }}
61 | permissions:
62 | contents: read
63 | # see https://docs.pypi.org/trusted-publishers/
64 | id-token: write
65 | runs-on: ubuntu-latest
66 | needs: build-package
67 |
68 | environment:
69 | name: testpypi
70 | url: https://test.pypi.org/p/its-hub
71 |
72 | steps:
73 | - name: "Harden Runner"
74 | uses: step-security/harden-runner@c6295a65d1254861815972266d5933fd6e532bdf # v2.11.1
75 | with:
76 | egress-policy: audit # TODO: change to 'egress-policy: block' after couple of runs
77 |
78 | - name: "Download build artifacts"
79 | uses: actions/download-artifact@cc203385981b70ca67e1cc392babf9cc229d5806 # v4.1.9
80 | with:
81 | name: Packages
82 | path: dist
83 |
84 | - name: "Upload to Test PyPI"
85 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4
86 | with:
87 | repository-url: https://test.pypi.org/legacy/
88 |
89 | # push to Production PyPI on
90 | # - a new GitHub release is published
91 | publish-pypi:
92 | name: Publish release to pypi.org
93 | # environment: publish-pypi
94 | if: ${{ (github.repository_owner == 'Red-Hat-AI-Innovation-Team') && (github.event.action == 'published') }}
95 | permissions:
96 | # see https://docs.pypi.org/trusted-publishers/
97 | id-token: write
98 | # allow gh release upload
99 | contents: write
100 |
101 | environment:
102 | name: pypi
103 | url: https://pypi.org/p/its-hub
104 |
105 | runs-on: ubuntu-latest
106 | needs: build-package
107 |
108 | steps:
109 | - name: "Harden Runner"
110 | uses: step-security/harden-runner@c6295a65d1254861815972266d5933fd6e532bdf # v2.11.1
111 | with:
112 | egress-policy: audit # TODO: change to 'egress-policy: block' after couple of runs
113 |
114 | - name: "Download build artifacts"
115 | uses: actions/download-artifact@cc203385981b70ca67e1cc392babf9cc229d5806 # v4.1.9
116 | with:
117 | name: Packages
118 | path: dist
119 |
120 | - name: "Sigstore sign package"
121 | uses: sigstore/gh-action-sigstore-python@f514d46b907ebcd5bedc05145c03b69c1edd8b46 # v3.0.0
122 | with:
123 | inputs: |
124 | ./dist/*.tar.gz
125 | ./dist/*.whl
126 | release-signing-artifacts: false
127 |
128 | - name: "Upload artifacts and signatures to GitHub release"
129 | run: |
130 | gh release upload '${{ github.ref_name }}' dist/* --repo '${{ github.repository }}'
131 | env:
132 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
133 |
134 | # PyPI does not accept .sigstore artifacts and
135 | # gh-action-pypi-publish has no option to ignore them.
136 | - name: "Remove sigstore signatures before uploading to PyPI"
137 | run: |
138 | rm ./dist/*.sigstore.json
139 |
140 | - name: "Upload to PyPI"
141 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4
142 |
--------------------------------------------------------------------------------
/docs/PLANNING_WRAPPER.md:
--------------------------------------------------------------------------------
1 | # Planning Wrapper for ITS Algorithms
2 |
3 | ## Overview
4 |
5 | The `PlanningWrapper` is a generic enhancement that adds a planning phase to any inference-time scaling (ITS) algorithm. It allows the model to generate multiple solution approaches before execution, potentially improving performance through diverse strategy exploration.
6 |
7 | ## Key Features
8 |
9 | - **Universal Compatibility**: Works with any ITS algorithm (Self-Consistency, Best-of-N, Particle Filtering, Beam Search)
10 | - **Unified Interface**: Maintains the same `infer()` method signature across all enhanced algorithms
11 | - **Smart Budget Allocation**: Automatically divides computational budget across planned approaches
12 | - **Robust Plan Parsing**: Handles various plan formats with intelligent fallbacks
13 |
14 | ## Architecture
15 |
16 | ### Core Components
17 |
18 | 1. **PlanningWrapper**: Main class that wraps any base ITS algorithm
19 | 2. **PlanningPromptTemplate**: Generates prompts that encourage diverse approach planning
20 | 3. **PlanParser**: Extracts structured approaches from natural language plans
21 | 4. **ApproachPromptTemplate**: Creates approach-specific prompts for execution
22 |
23 | ### Process Flow
24 |
25 | 1. **Planning Phase**: Generate a plan with 3 distinct approaches (costs 1 generation from the budget)
26 | 2. **Approach Parsing**: Extract approaches using regex patterns with fallbacks
27 | 3. **Budget Allocation**: Divide remaining budget equally across approaches
28 | 4. **Execution**: Run base algorithm for each approach with approach-specific prompts
29 | 5. **Selection**: Choose best result based on algorithm-specific scoring
30 |
31 | ## Usage
32 |
33 | ### Manual Wrapping
34 | ```python
35 | from its_hub.algorithms.planning_wrapper import PlanningWrapper
36 | from its_hub.algorithms import SelfConsistency
37 |
38 | base_algorithm = SelfConsistency(extract_fn)
39 | planning_algorithm = PlanningWrapper(base_algorithm)
40 |
41 | result = planning_algorithm.infer(lm, prompt, budget=16, return_response_only=False)
42 | ```
43 |
44 | ### Convenience Functions
45 | ```python
46 | from its_hub.algorithms.planning_wrapper import (
47 | create_planning_self_consistency,
48 | create_planning_particle_filtering,
49 | create_planning_best_of_n,
50 | create_planning_beam_search
51 | )
52 |
53 | # Enhanced algorithms
54 | planning_sc = create_planning_self_consistency(extract_fn)
55 | planning_pf = create_planning_particle_filtering(sg, prm)
56 | planning_bon = create_planning_best_of_n(orm)
57 | planning_bs = create_planning_beam_search(sg, prm, beam_width=4)
58 |
59 | # Same interface for all
60 | result = planning_sc.infer(lm, prompt, budget=16, return_response_only=False)
61 | ```
62 |
63 | ### Result Object
64 | ```python
65 | # Planning-enhanced results include additional information
66 | result = planning_algorithm.infer(lm, prompt, budget=16, return_response_only=False)
67 |
68 | print(f"Best answer: {result.the_one}")
69 | print(f"Generated plan: {result.plan}")
70 | print(f"Approaches used: {result.approaches}")
71 | print(f"Best approach: {result.best_approach}")
72 | print(f"Budget allocation: {result.approach_budgets}")
73 | ```
74 |
75 | ## Supported Algorithms
76 |
77 | - ✅ **Self-Consistency**: Enhanced with planning via `create_planning_self_consistency()`
78 | - ✅ **Best-of-N**: Enhanced with planning via `create_planning_best_of_n()`
79 | - ✅ **Particle Filtering**: Enhanced with planning via `create_planning_particle_filtering()`
80 | - ✅ **Beam Search**: Enhanced with planning via `create_planning_beam_search()`
81 |
82 | ## Testing
83 |
84 | Run the comprehensive test suite:
85 | ```bash
86 | python test_planning_wrapper.py
87 | ```
88 |
89 | This test validates:
90 | - Planning-enhanced versions of all supported algorithms
91 | - Proper budget allocation across approaches
92 | - Result aggregation and best approach selection
93 | - Fallback handling for plan parsing failures
94 |
95 | ## Implementation Details
96 |
97 | ### Plan Generation
98 | The wrapper generates plans using a structured prompt that encourages the model to think of 3 distinct mathematical approaches:
99 |
100 | ```
101 | APPROACH 1: [Brief description of first method/strategy]
102 | APPROACH 2: [Brief description of second method/strategy]
103 | APPROACH 3: [Brief description of third method/strategy]
104 | ```
105 |
106 | ### Budget Allocation
107 | - Planning phase uses 1 generation from the total budget
108 | - Remaining budget is divided equally across parsed approaches
109 | - Any remainder is distributed to the first few approaches (see the sketch below)
110 |
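111 | A minimal sketch of this split (illustrative only, not the wrapper's internal code):
112 | 
113 | ```python
114 | def split_budget(total_budget: int, num_approaches: int) -> list[int]:
115 |     remaining = total_budget - 1  # the planning generation comes off the top
116 |     base, remainder = divmod(remaining, num_approaches)
117 |     # any remainder goes to the first few approaches
118 |     return [base + (1 if i < remainder else 0) for i in range(num_approaches)]
119 | 
120 | print(split_budget(16, 3))  # [5, 5, 5]
121 | print(split_budget(12, 3))  # [4, 4, 3]
122 | ```
123 | 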
111 | ### Approach Selection
112 | The wrapper selects the best approach based on algorithm-specific scoring:
113 | - Tries various score attributes (`best_score`, `confidence`, `scores`, etc.)
114 | - Falls back to response length as a proxy for quality
115 | - Returns the approach with the highest score (a sketch of this fallback follows)
116 |
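117 | A hedged sketch of that fallback order (attribute names are taken from the list above; the wrapper's actual code may differ):
118 | 
119 | ```python
120 | def score_of(result) -> float:
121 |     # prefer explicit score attributes if the base algorithm's result exposes them
122 |     for attr in ("best_score", "confidence"):
123 |         value = getattr(result, attr, None)
124 |         if value is not None:
125 |             return float(value)
126 |     scores = getattr(result, "scores", None)
127 |     if scores:
128 |         return float(max(scores))
129 |     # last resort: response length as a rough proxy for quality
130 |     return float(len(str(getattr(result, "the_one", ""))))
131 | ```
132 | 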
117 | ### Error Handling
118 | - Robust plan parsing with regex patterns and fallbacks
119 | - Generic fallback approaches if parsing fails completely
120 | - Graceful handling of missing score attributes
121 |
122 | ## Performance Considerations
123 |
124 | - **Overhead**: 1 additional generation for planning
125 | - **Benefits**: Potentially better results through diverse approaches
126 | - **Trade-offs**: Lower budgets may suffer from planning overhead, higher budgets benefit more
127 |
128 | ## Future Enhancements
129 |
130 | - Adaptive planning based on problem complexity
131 | - Dynamic budget allocation based on approach confidence
132 | - Cross-approach result fusion techniques
133 | - Problem-specific approach templates
--------------------------------------------------------------------------------
/its_hub/types.py:
--------------------------------------------------------------------------------
1 | """Type definitions for its_hub."""
2 |
3 | from __future__ import annotations
4 | 
5 | import logging
6 | from typing import Literal
7 | from pydantic.dataclasses import dataclass
8 |
9 |
10 | @dataclass
11 | class Function:
12 | """Function definition for tool calls."""
13 |
14 | name: str
15 | description: str | None = None
16 | parameters: dict | None = None
17 |
18 |
19 | @dataclass
20 | class ToolCall:
21 | """A tool call made by the assistant."""
22 |
23 | id: str
24 | type: Literal["function"] = "function"
25 | function: Function | None = None
26 |
27 |
28 | @dataclass
29 | class ChatMessage:
30 | """A chat message with role and content.
31 | Content can be:
32 | - str: Simple text content
33 | - list[dict]: Multi-modal content (text, images, etc.)
34 | - None: No content (e.g., when using tool_calls)
35 | """
36 |
37 | role: Literal["system", "user", "assistant", "tool"]
38 | content: str | list[dict] | None
39 | tool_calls: list[dict] | None = None # Store as plain dicts, not Pydantic objects
40 | tool_call_id: str | None = None
41 |
42 | def extract_text_content(self) -> str:
43 | """Extract text content from message, handling both string and list formats.
44 | For list content (multi-modal), extracts all text parts and warns about non-text content.
45 | Returns empty string if no text content is found.
46 | """
47 | if self.content is None:
48 | return ""
49 |
50 | if isinstance(self.content, str):
51 | return self.content
52 |
53 | # Must be list[dict] at this point
54 | text_parts = []
55 | has_image = False
56 |
57 | for item in self.content:
58 | content_type = item.get("type", "")
59 |
60 | if content_type == "text":
61 | text_parts.append(item.get("text", ""))
62 | elif content_type == "image_url":
63 | has_image = True
64 | elif content_type:
65 | raise ValueError(
66 | f"Unsupported content type '{content_type}' in messages content dict."
67 | )
68 |
69 | if has_image:
70 | logging.warning(
71 | "Image content detected in message but is not supported. "
72 | "Image content will be ignored. Only text content is processed."
73 | )
74 |
75 | return " ".join(text_parts)
76 |
77 | def to_dict(self) -> dict:
78 | """Convert ChatMessage to dictionary, excluding None values."""
79 | result = {"role": self.role}
80 | if self.content is not None:
81 | result["content"] = self.content
82 | if self.tool_calls is not None:
83 | result["tool_calls"] = self.tool_calls
84 | if self.tool_call_id is not None:
85 | result["tool_call_id"] = self.tool_call_id
86 | return result
87 |
88 |
89 | class ChatMessages:
90 | """Unified wrapper for handling both string prompts and conversation history."""
91 |
92 | def __init__(self, str_or_messages: str | list[ChatMessage]):
93 | self._str_or_messages = str_or_messages
94 | self._is_string = isinstance(str_or_messages, str)
95 |
96 | @classmethod
97 | def from_prompt_or_messages(
98 | cls, prompt_or_messages: str | list[ChatMessage] | ChatMessages
99 | ) -> ChatMessages:
100 | """Create ChatMessages from various input formats."""
101 | if isinstance(prompt_or_messages, ChatMessages):
102 | return prompt_or_messages
103 | return cls(prompt_or_messages)
104 |
105 | def to_prompt(self) -> str:
106 | # TODO: chatMessage to string conversion will be deprecated in the future.
107 | """Convert to prompt string representation."""
108 | if self._is_string:
109 | return self._str_or_messages
110 |
111 | lines = []
112 | for msg in self._str_or_messages:
113 | text_content = msg.extract_text_content()
114 |
115 | if msg.role == "tool":
116 | # Tool messages: include tool_call_id context
117 | lines.append(f"tool[{msg.tool_call_id}]: {text_content}")
118 | elif msg.role == "assistant" and msg.tool_calls:
119 | # Assistant with tool calls: show tool calls + content if any
120 | tool_call_strs = []
121 |                 for tc in msg.tool_calls:
122 |                     if isinstance(tc, dict) and tc.get("function"):  # stored as plain dicts
123 |                         tool_call_strs.append(f"{tc['function'].get('name', '')}()")
124 | tool_calls_text = ", ".join(tool_call_strs)
125 | if text_content:
126 | lines.append(
127 | f"assistant: {text_content} [calls: {tool_calls_text}]"
128 | )
129 | else:
130 | lines.append(f"assistant: [calls: {tool_calls_text}]")
131 | else:
132 | # Regular messages
133 | lines.append(f"{msg.role}: {text_content}")
134 |
135 | return "\n".join(lines)
136 |
137 | def to_chat_messages(self) -> list[ChatMessage]:
138 | """Convert to list of ChatMessage objects."""
139 | if self._is_string:
140 | return [ChatMessage(role="user", content=self._str_or_messages)]
141 | return self._str_or_messages
142 |
143 | def to_batch(self, size: int) -> list[list[ChatMessage]]:
144 | """Create a batch of identical chat message lists for parallel generation."""
145 | chat_messages = self.to_chat_messages()
146 | return [chat_messages for _ in range(size)]
147 |
148 | @property
149 | def is_string(self) -> bool:
150 | """Check if the original input was a string."""
151 | return self._is_string
152 |
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
1 | # CLAUDE.md
2 |
3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4 |
5 | ## Development Commands
6 |
7 | ### Installation and Setup
8 | ```bash
9 | # Development installation with uv (recommended)
10 | uv sync --extra dev
11 |
12 | # Alternative: pip installation
13 | pip install -e ".[dev]"
14 |
15 | # Production installation
16 | pip install its_hub
17 | ```
18 |
19 | ### Contribution
20 | When committing or raising a PR, never mention that it was created by Claude Code.
21 | Never include "🤖 Generated with [Claude Code](https://claude.ai/code)" in the commit message, and do not mention Claude at all.
22 |
23 | ### Testing
24 | ```bash
25 | # Run all tests
26 | uv run pytest tests/
27 |
28 | # Run specific test file
29 | uv run pytest tests/test_algorithms.py
30 |
31 | # Run tests with coverage
32 | uv run pytest tests/ --cov=its_hub
33 |
34 | # Run tests with verbose output
35 | uv run pytest tests/ -v
36 | ```
37 |
38 | ### Code Quality
39 | ```bash
40 | # Run linter checks
41 | uv run ruff check its_hub/
42 |
43 | # Fix auto-fixable linting issues
44 | uv run ruff check its_hub/ --fix
45 |
46 | # Format code with ruff
47 | uv run ruff format its_hub/
48 | ```
49 |
50 | ### Git Workflow
51 | ```bash
52 | # Create commits with sign-off
53 | git commit -s -m "commit message"
54 |
55 | # For any git commits, always use the sign-off flag (-s)
56 | ```
57 |
58 | ### Running Examples
59 | ```bash
60 | # Test basic functionality
61 | python scripts/test_math_example.py
62 |
63 | # Benchmark algorithms (see script help for full options)
64 | python scripts/benchmark.py --help
65 | ```
66 |
67 | ### IaaS Service (Inference-as-a-Service)
68 | ```bash
69 | # Start IaaS service
70 | uv run its-iaas --host 0.0.0.0 --port 8108
71 |
72 | # Or using justfile (if available)
73 | just iaas-start
74 |
75 | # Check service health
76 | curl -s http://localhost:8108/v1/models | jq .
77 |
78 | # Configure the service (example: self-consistency algorithm)
79 | curl -X POST http://localhost:8108/configure \
80 | -H "Content-Type: application/json" \
81 | -d '{"endpoint": "http://localhost:8100/v1", "api_key": "NO_API_KEY", "model": "your-model-name", "alg": "self-consistency"}'
82 |
83 | # For comprehensive IaaS setup (multi-GPU, reward models, etc.), see docs/iaas-service.md
84 | ```
85 |
86 | ## Additional Tips
87 | - Use `rg` in favor of `grep` whenever it's available
88 | - Use `uv` for Python environment management: always start with `uv sync --extra dev` to init the env and run stuff with `uv run`
89 | - In case of dependency issues during testing, try commenting out `reward_hub` and `vllm` temporarily in @pyproject.toml and retry.
90 |
91 | ## Architecture Overview
92 |
93 | **its_hub** is a library for inference-time scaling of LLMs, focusing on mathematical reasoning tasks. The core architecture uses abstract base classes to define clean interfaces between components.
94 |
95 | ### Key Base Classes (`its_hub/base.py`)
96 | - `AbstractLanguageModel`: Interface for LM generation and evaluation
97 | - `AbstractScalingAlgorithm`: Base for all scaling algorithms with unified `infer()` method
98 | - `AbstractScalingResult`: Base for algorithm results with `the_one` property
99 | - `AbstractOutcomeRewardModel`: Interface for outcome-based reward models
100 | - `AbstractProcessRewardModel`: Interface for process-based reward models (step-by-step scoring)
101 |
102 | ### Main Components
103 |
104 | #### Language Models (`its_hub/lms.py`)
105 | - `OpenAICompatibleLanguageModel`: Primary LM implementation supporting vLLM and OpenAI APIs
106 | - `StepGeneration`: Handles incremental generation with configurable step tokens and stop conditions
107 | - Supports async generation with concurrency limits and backoff strategies
108 |
109 | #### Algorithms (`its_hub/algorithms/`)
110 | All algorithms follow the same interface: `infer(lm, prompt, budget, return_response_only=True)`
111 |
112 | - **Self-Consistency**: Generate multiple responses, select most common answer
113 | - **Best-of-N**: Generate N responses, select highest scoring via outcome reward model
114 | - **Beam Search**: Step-by-step generation with beam width, uses process reward models
115 | - **Particle Filtering/Gibbs**: Probabilistic resampling with process reward models
116 |
117 | #### Integration (`its_hub/integration/`)
118 | - `LocalVllmProcessRewardModel`: Integrates with reward_hub library for process-based scoring
119 | - `iaas.py`: Inference-as-a-Service FastAPI server providing OpenAI-compatible chat completions API with budget parameter for inference-time scaling
120 |
121 | ### Budget Interpretation
122 | The budget parameter controls computational resources allocated to each algorithm. Different algorithms interpret budget as follows:
123 | - **Self-Consistency/Best-of-N**: Number of parallel generations to create
124 | - **Beam Search**: Total generations divided by beam width (controls search depth)
125 | - **Particle Filtering**: Number of particles maintained during sampling
126 |
127 | ### Step Generation Pattern
128 | The `StepGeneration` class enables incremental text generation:
129 | - Configure step tokens (e.g., "\n\n" for reasoning steps)
130 | - Set max steps and stop conditions
131 | - Post-processing for clean output formatting
132 |
133 | ### Typical Workflow
134 | 1. Start vLLM server with instruction model
135 | 2. Initialize `OpenAICompatibleLanguageModel` pointing to server
136 | 3. Create `StepGeneration` with step/stop tokens appropriate for the task
137 | 4. Initialize reward model (e.g., `LocalVllmProcessRewardModel`)
138 | 5. Create scaling algorithm with step generation and reward model
139 | 6. Call `infer()` with prompt and budget
140 |
141 | ### Mathematical Focus
142 | The library is optimized for mathematical reasoning:
143 | - Predefined system prompts in `its_hub/utils.py` (SAL_STEP_BY_STEP_SYSTEM_PROMPT, QWEN_SYSTEM_PROMPT)
144 | - Regex patterns for mathematical notation (e.g., `r"\boxed"` for final answers)
145 | - Integration with math_verify for evaluation
146 | - Benchmarking on MATH500 and AIME-2024 datasets
147 |
148 | ## Inference-as-a-Service (IaaS)
149 |
150 | The its_hub library includes an IaaS service that provides OpenAI-compatible API with inference-time scaling capabilities. For comprehensive setup instructions, usage examples, and troubleshooting, see [docs/iaas-service.md](./docs/iaas-service.md).
--------------------------------------------------------------------------------
/its_hub/algorithms/beam_search.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 | from pydantic.dataclasses import dataclass
5 |
6 | from its_hub.base import (
7 | AbstractLanguageModel,
8 | AbstractProcessRewardModel,
9 | AbstractScalingAlgorithm,
10 | AbstractScalingResult,
11 | )
12 | from its_hub.lms import StepGeneration
13 | from its_hub.types import ChatMessage, ChatMessages
14 |
15 |
16 | @dataclass
17 | class BeamSearchResult(AbstractScalingResult):
18 | responses: list[dict] # Keep original message format with tool calls
19 | scores: list[float]
20 | selected_index: int
21 | steps_used: list[int]
22 |
23 | @property
24 | def the_one(self) -> dict:
25 | return self.responses[self.selected_index]
26 |
27 |
28 | @dataclass
29 | class Path:
30 | steps: list[str]
31 | is_stopped: bool
32 | score: float
33 |
34 | def deepcopy(self):
35 | # create a deep copy of the path object
36 | return Path(
37 | steps=copy.deepcopy(self.steps),
38 | is_stopped=self.is_stopped,
39 | score=self.score,
40 | )
41 |
42 |
43 | class BeamSearch(AbstractScalingAlgorithm):
44 | def __init__(
45 | self,
46 | sg: StepGeneration,
47 | prm: AbstractProcessRewardModel,
48 | beam_width: int,
49 | ):
50 | self.sg = sg
51 | self.prm = prm
52 | self.beam_width = beam_width
53 |
54 | async def _asearch_one_level(
55 | self,
56 | lm: AbstractLanguageModel,
57 | candidates: list[Path],
58 | prompt: str,
59 | tools: list[dict] | None = None,
60 | tool_choice: str | dict | None = None,
61 | ) -> list[Path]:
62 | """search one level asynchronously"""
63 | is_stopped_in_the_beginning = [c.is_stopped for c in candidates]
64 |
65 | # collect batch inputs
66 | prompts, steps_so_far = [], []
67 | for c, is_stopped in zip(candidates, is_stopped_in_the_beginning):
68 | if is_stopped:
69 | continue
70 | prompts.append(prompt)
71 | steps_so_far.append(c.steps)
72 |
73 | # collect batch outputs
74 | sg_forward_results = await self.sg.aforward(
75 | lm, prompts, steps_so_far, tools=tools, tool_choice=tool_choice
76 | )
77 |
78 | # update candidates
79 | i = 0
80 | for c, is_stopped in zip(candidates, is_stopped_in_the_beginning):
81 | if is_stopped:
82 | continue
83 | next_step, is_stopped = sg_forward_results[i]
84 | c.steps.append(next_step)
85 | c.is_stopped = is_stopped
86 | i += 1
87 |
88 | # collect batch inputs for scoring
89 | steps_so_far = []
90 | for c, is_stopped in zip(candidates, is_stopped_in_the_beginning):
91 | if is_stopped:
92 | continue
93 | steps_so_far.append(c.steps)
94 |
95 | # collect batch outputs for scoring
96 | scores = await self.prm.ascore(
97 | prompt,
98 | [
99 | self.sg._post_process(steps_so_far_per_prompt, stopped=True)
100 | for steps_so_far_per_prompt in steps_so_far
101 | ],
102 | )
103 |
104 | # update candidates
105 | i = 0
106 | for c, is_stopped in zip(candidates, is_stopped_in_the_beginning):
107 | if is_stopped:
108 | continue
109 | c.score = scores[i]
110 | i += 1
111 |
112 | return candidates
113 |
114 | def _search_one_level(
115 | self,
116 | lm: AbstractLanguageModel,
117 | candidates: list[Path],
118 | prompt: str,
119 | tools: list[dict] | None = None,
120 | tool_choice: str | dict | None = None,
121 | ) -> list[Path]:
122 | """search one level synchronously"""
123 | import asyncio
124 |
125 | return asyncio.run(
126 | self._asearch_one_level(lm, candidates, prompt, tools, tool_choice)
127 | )
128 |
129 | async def ainfer(
130 | self,
131 | lm: AbstractLanguageModel,
132 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
133 | budget: int,
134 | return_response_only: bool = True,
135 | tools: list[dict] | None = None,
136 | tool_choice: str | dict | None = None,
137 | ) -> dict | BeamSearchResult:
138 | """run inference asynchronously with beam search"""
139 | chat_messages = ChatMessages.from_prompt_or_messages(prompt_or_messages)
140 | assert budget % self.beam_width == 0, "budget must be divisible by beam_width"
141 | assert budget >= self.beam_width, (
142 | "budget must be greater than or equal to beam_width"
143 | )
144 |
145 | num_beams = budget // self.beam_width
146 |
147 | candidates = [
148 | Path(steps=[], is_stopped=False, score=0) for _ in range(num_beams)
149 | ]
150 |
151 | while not all(c.is_stopped for c in candidates):
152 | # TODO: Update _asearch_one_level to support native ChatMessages format instead of string conversion
153 | candidates = await self._asearch_one_level(
154 | lm,
155 | candidates,
156 | chat_messages.to_prompt(),
157 | tools=tools,
158 | tool_choice=tool_choice,
159 | )
160 |
161 | # get the top beam_width candidates
162 | candidates.sort(key=lambda x: x.score, reverse=True)
163 | candidates = candidates[: self.beam_width]
164 |
165 |             # refill the pool: duplicate each surviving top-scoring candidate num_beams times
166 | new_candidates = []
167 | for _ in range(num_beams):
168 | for c in candidates:
169 | new_candidates.append(c.deepcopy())
170 | candidates = new_candidates
171 |
172 | scores = [c.score for c in candidates]
173 | steps_used = [len(c.steps) for c in candidates]
174 | result = BeamSearchResult(
175 | responses=[
176 | {
177 | "role": "assistant",
178 | "content": self.sg._post_process(c.steps, stopped=True),
179 | }
180 | for c in candidates
181 | ],
182 | scores=scores,
183 | selected_index=int(np.argmax(scores)),
184 | steps_used=steps_used,
185 | )
186 | return result.the_one if return_response_only else result
187 |
--------------------------------------------------------------------------------
/its_hub/error_handling.py:
--------------------------------------------------------------------------------
1 | """
2 | Error handling utilities for OpenAI-compatible API calls.
3 |
4 | This module provides specific exception classes and utilities to handle
5 | different types of API failures with appropriate retry logic and informative
6 | error messages.
7 | """
8 |
9 | import json
10 | from typing import Any
11 |
12 |
13 | class APIError(Exception):
14 | """Base class for API-related errors."""
15 |
16 | def __init__(
17 | self,
18 | message: str,
19 | status_code: int | None = None,
20 | error_details: dict[str, Any] | None = None,
21 | ):
22 | self.message = message
23 | self.status_code = status_code
24 | self.error_details = error_details or {}
25 | super().__init__(message)
26 |
27 |
28 | class RateLimitError(APIError):
29 | """Rate limit exceeded - retryable."""
30 |
31 | pass
32 |
33 |
34 | class ContextLengthError(APIError):
35 | """Context length exceeded - not retryable."""
36 |
37 | pass
38 |
39 |
40 | class AuthenticationError(APIError):
41 | """Authentication failed - not retryable."""
42 |
43 | pass
44 |
45 |
46 | class APIConnectionError(APIError):
47 | """Network/connection issues - retryable."""
48 |
49 | pass
50 |
51 |
52 | class BadRequestError(APIError):
53 | """Bad request (invalid parameters) - not retryable."""
54 |
55 | pass
56 |
57 |
58 | class InternalServerError(APIError):
59 | """Server error - retryable."""
60 |
61 | pass
62 |
63 |
64 | # Retryable error types
65 | RETRYABLE_ERRORS = (RateLimitError, APIConnectionError, InternalServerError)
66 |
67 |
68 | def parse_api_error(status_code: int, error_text: str) -> APIError:
69 | """
70 | Parse API error response and return appropriate exception.
71 |
72 | Args:
73 | status_code: HTTP status code
74 | error_text: Raw error response text
75 |
76 | Returns:
77 | Appropriate APIError subclass instance
78 | """
79 | error_details = {}
80 | error_message = error_text
81 |
82 | # Try to parse JSON error response
83 | try:
84 | error_json = json.loads(error_text)
85 | if "error" in error_json:
86 | error_info = error_json["error"]
87 | if isinstance(error_info, dict):
88 | error_message = error_info.get("message", error_text)
89 | error_details = error_info
90 | else:
91 | error_message = str(error_info)
92 | except (json.JSONDecodeError, KeyError):
93 | # Use raw text if JSON parsing fails
94 | pass
95 |
96 | # Classify error based on status code and message content
97 | error_message_lower = error_message.lower()
98 |
99 | if status_code == 429 or "rate limit" in error_message_lower:
100 | return RateLimitError(
101 | f"Rate limit exceeded: {error_message}",
102 | status_code=status_code,
103 | error_details=error_details,
104 | )
105 |
106 | if status_code == 400 and any(
107 | phrase in error_message_lower
108 | for phrase in [
109 | "context length",
110 | "maximum context",
111 | "too long",
112 | "token limit",
113 | "context_length_exceeded",
114 | "max_tokens",
115 | ]
116 | ):
117 | return ContextLengthError(
118 | f"Context length exceeded: {error_message}",
119 | status_code=status_code,
120 | error_details=error_details,
121 | )
122 |
123 | if (
124 | status_code in [401, 403]
125 | or "authentication" in error_message_lower
126 | or "unauthorized" in error_message_lower
127 | ):
128 | return AuthenticationError(
129 | f"Authentication failed: {error_message}",
130 | status_code=status_code,
131 | error_details=error_details,
132 | )
133 |
134 | if status_code == 400:
135 | return BadRequestError(
136 | f"Bad request: {error_message}",
137 | status_code=status_code,
138 | error_details=error_details,
139 | )
140 |
141 | if status_code >= 500:
142 | return InternalServerError(
143 | f"Server error: {error_message}",
144 | status_code=status_code,
145 | error_details=error_details,
146 | )
147 |
148 | # Default to APIConnectionError for other cases (network issues, etc.)
149 | return APIConnectionError(
150 | f"API request failed: {error_message}",
151 | status_code=status_code,
152 | error_details=error_details,
153 | )
154 |
155 |
156 | def enhanced_on_backoff(details):
157 | """
158 | Enhanced backoff callback that shows specific error information.
159 |
160 | Args:
161 | details: Backoff details dictionary containing exception info
162 | """
163 | exception = details.get("exception")
164 | if isinstance(exception, APIError):
165 | error_type = type(exception).__name__
166 | print(
167 | f"Retrying after {details['wait']:.1f}s (attempt {details['tries']}) - "
168 | f"{error_type}: {exception.message}"
169 | )
170 | else:
171 | # Fallback for non-APIError exceptions
172 | print(
173 | f"Retrying after {details['wait']:.1f}s (attempt {details['tries']}) - "
174 | f"Error: {exception!s}"
175 | )
176 |
177 |
178 | def should_retry(exception: Exception) -> bool:
179 | """
180 | Determine if an exception should be retried.
181 |
182 | Args:
183 | exception: The exception to check
184 |
185 | Returns:
186 | True if the exception should be retried, False otherwise
187 | """
188 | return isinstance(exception, RETRYABLE_ERRORS)
189 |
190 |
191 | def format_non_retryable_error(exception: APIError) -> str:
192 | """
193 | Format a helpful message for non-retryable errors.
194 |
195 | Args:
196 | exception: The non-retryable APIError
197 |
198 | Returns:
199 | Formatted error message with suggestions
200 | """
201 | if isinstance(exception, ContextLengthError):
202 | return (
203 | f"❌ {exception.message}\n"
204 | f"💡 Suggestion: Reduce input length, increase max_tokens, or use a model with larger context window"
205 | )
206 |
207 | if isinstance(exception, AuthenticationError):
208 | return (
209 | f"❌ {exception.message}\n"
210 | f"💡 Suggestion: Check your API key and endpoint configuration"
211 | )
212 |
213 | if isinstance(exception, BadRequestError):
214 | return (
215 | f"❌ {exception.message}\n"
216 | f"💡 Suggestion: Check your request parameters (temperature, max_tokens, etc.)"
217 | )
218 |
219 | return f"❌ {exception.message}"
220 |
--------------------------------------------------------------------------------
/docs/quick-start.md:
--------------------------------------------------------------------------------
1 | # Quick Start Guide
2 |
3 | This guide walks through three examples of inference-time scaling, from cloud-only APIs to GPU-backed process reward models. **Tool calling is the primary use case** for production applications.
4 |
5 | ## Example 1: Self-Consistency with Tool Calling (Recommended)
6 |
7 | **Installation required:** `pip install its_hub`
8 |
9 | This example shows how to use Self-Consistency for reliable tool calling in agent applications.
10 |
11 | ```python
12 | from its_hub.lms import OpenAICompatibleLanguageModel
13 | from its_hub.algorithms import SelfConsistency
14 | from its_hub.types import ChatMessage, ChatMessages
15 |
16 | # Initialize language model
17 | lm = OpenAICompatibleLanguageModel(
18 | endpoint="https://api.openai.com/v1",
19 | api_key="your-api-key",
20 | model_name="gpt-4o-mini"
21 | )
22 |
23 | # Define tools (OpenAI format)
24 | tools = [
25 | {
26 | "type": "function",
27 | "function": {
28 | "name": "calculator",
29 | "description": "Perform arithmetic calculations",
30 | "parameters": {
31 | "type": "object",
32 | "properties": {
33 | "expression": {
34 | "type": "string",
35 | "description": "Mathematical expression to evaluate"
36 | }
37 | },
38 | "required": ["expression"]
39 | }
40 | }
41 | }
42 | ]
43 |
44 | # Create messages
45 | messages = ChatMessages([
46 | ChatMessage(
47 | role="system",
48 | content="You are a precise calculator. Always use the calculator tool for arithmetic."
49 | ),
50 | ChatMessage(
51 | role="user",
52 | content="What is 847 * 293 + 156?"
53 | )
54 | ])
55 |
56 | # Use hierarchical tool voting
57 | sc = SelfConsistency(tool_vote="tool_hierarchical")
58 | result = sc.infer(
59 | lm,
60 | messages,
61 | budget=5,
62 | tools=tools,
63 | tool_choice="auto"
64 | )
65 | print(result)
66 | ```
67 |
68 | **What happens:**
69 | 1. Generates 5 different responses with tool calls
70 | 2. Votes on tool names first (which tool to use)
71 | 3. Then votes on tool arguments (what parameters to pass)
72 | 4. Returns the most consistent tool call (the voting idea is sketched below)
73 |
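74 | Conceptually, the hierarchical vote works like the sketch below (illustrative only, not the library's internal implementation; tool calls are simplified here to flat `{"name", "arguments"}` dicts rather than the full OpenAI format):
75 | 
76 | ```python
77 | import json
78 | from collections import Counter
79 | 
80 | def hierarchical_vote(tool_calls: list[dict]) -> dict:
81 |     """Majority-vote on the tool name first, then on arguments among calls to that tool."""
82 |     top_name, _ = Counter(tc["name"] for tc in tool_calls).most_common(1)[0]
83 |     args_of_winner = [
84 |         json.dumps(tc["arguments"], sort_keys=True)
85 |         for tc in tool_calls
86 |         if tc["name"] == top_name
87 |     ]
88 |     top_args, _ = Counter(args_of_winner).most_common(1)[0]
89 |     return {"name": top_name, "arguments": json.loads(top_args)}
90 | ```
91 | 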
74 | ---
75 |
76 | ## Example 2: Best-of-N with LLM Judge (Core Installation)
77 |
78 | **Installation required:** `pip install its_hub`
79 |
80 | This example uses Best-of-N algorithm with an LLM judge for response selection. Works with any OpenAI-compatible API and requires no GPU.
81 |
82 | ```python
83 | from its_hub.lms import OpenAICompatibleLanguageModel
84 | from its_hub.algorithms import BestOfN
85 | from its_hub.integration.reward_hub import LLMJudgeRewardModel
86 |
87 | # Initialize language model
88 | lm = OpenAICompatibleLanguageModel(
89 | endpoint="https://api.openai.com/v1",
90 | api_key="your-api-key",
91 | model_name="gpt-4o-mini",
92 | )
93 |
94 | # Set up LLM judge for scoring
95 | judge = LLMJudgeRewardModel(
96 | model="gpt-4o-mini",
97 | criterion="overall_quality",
98 | judge_type="groupwise",
99 | api_key="your-api-key",
100 | )
101 | scaling_alg = BestOfN(judge)
102 |
103 | # Generate multiple responses and select the best
104 | result = scaling_alg.infer(
105 | lm,
106 | "Explain quantum entanglement in simple terms",
107 | budget=4
108 | )
109 | print(result)
110 | ```
111 |
112 | **What happens:**
113 | 1. Generates 4 different responses to the prompt
114 | 2. LLM judge scores all responses
115 | 3. Returns the highest-scoring response
116 |
117 | ---
118 |
119 | ## Example 3: Particle Filtering with Process Reward Model
120 |
121 | **Installation required:** `pip install its_hub[prm]`
122 |
123 | This example uses Particle Filtering for step-by-step mathematical reasoning with a local process reward model. Requires GPU.
124 |
125 | ### Prerequisites
126 |
127 | - GPU with CUDA 11.8+
128 | - 20GB+ GPU memory recommended (for 7B reward model)
129 |
130 | ### Step 1: Start vLLM Server
131 |
132 | Start a vLLM server with your instruction model:
133 |
134 | ```bash
135 | CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-Math-1.5B-Instruct \
136 | --dtype float16 \
137 | --port 8100 \
138 | --max-model-len 4096 \
139 | --gpu-memory-utilization 0.7
140 | ```
141 |
142 | ### Step 2: Run Particle Filtering
143 |
144 | ```python
145 | from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT
146 | from its_hub.lms import OpenAICompatibleLanguageModel, StepGeneration
147 | from its_hub.algorithms import ParticleFiltering
148 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel
149 |
150 | # Initialize language model (points to vLLM server)
151 | lm = OpenAICompatibleLanguageModel(
152 | endpoint="http://localhost:8100/v1",
153 | api_key="NO_API_KEY",
154 | model_name="Qwen/Qwen2.5-Math-1.5B-Instruct",
155 | system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,
156 | )
157 |
158 | # Set up step generation and process reward model
159 | sg = StepGeneration(step_token="\n\n", max_steps=32, stop_token=r"\boxed")
160 | prm = LocalVllmProcessRewardModel(
161 | model_name="Qwen/Qwen2.5-Math-PRM-7B",
162 | device="cuda:0",
163 | aggregation_method="prod"
164 | )
165 | scaling_alg = ParticleFiltering(sg, prm)
166 |
167 | # Solve with step-by-step reasoning
168 | result = scaling_alg.infer(lm, "Solve x^2 + 5x + 6 = 0", budget=8)
169 | print(result)
170 | ```
171 |
172 | **What happens:**
173 | 1. Generates reasoning steps incrementally (separated by `\n\n`)
174 | 2. Process reward model scores each step
175 | 3. Maintains 8 particles, resampling them in proportion to their scores (illustrated below)
176 | 4. Returns best solution when `\boxed` pattern is found
177 |
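178 | The resampling step in (3) can be pictured like this (a conceptual illustration only, not the library's implementation):
179 | 
180 | ```python
181 | import numpy as np
182 | 
183 | step_scores = np.array([0.2, 0.9, 0.5, 0.7])   # one process-reward score per particle
184 | weights = step_scores / step_scores.sum()       # normalize into resampling probabilities
185 | survivors = np.random.choice(len(step_scores), size=len(step_scores), p=weights)
186 | print(survivors)  # indices of particles carried forward; high-scoring ones appear more often
187 | ```
188 | 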
178 | ### Memory Requirements
179 |
180 | - **Instruction Model** (Qwen2.5-Math-1.5B): ~3GB GPU memory
181 | - **Reward Model** (Qwen2.5-Math-PRM-7B): ~14GB GPU memory
182 | - **Total**: 20GB+ recommended (H100, A100, or similar)
183 |
184 | ---
185 |
186 | ## Troubleshooting
187 |
188 | ### Examples 1 & 2 (Cloud APIs)
189 |
190 | **API errors**: Verify API key and endpoint are correct
191 | **Slow responses**: Reduce `budget` parameter (e.g., from 5 to 2)
192 |
193 | ### Example 3 (Particle Filtering)
194 |
195 | **CUDA Out of Memory**:
196 | - Reduce `--gpu-memory-utilization` to 0.6 or lower
197 | - Reduce `--max-num-seqs` to 64
198 | - Ensure no other processes are using GPU
199 | - Check memory: `nvidia-smi`
200 |
201 | **Server Connection Issues**:
202 | ```bash
203 | # Verify vLLM server is running
204 | curl http://localhost:8100/v1/models
205 | ```
206 |
207 | **Model Loading Issues**:
208 | - Ensure sufficient disk space (~20GB for both models)
209 | - Check internet connection for model downloads
210 | - Verify model names are correct
211 |
212 | ---
213 |
214 | ## Next Steps
215 |
216 | - **Explore algorithms**: See [Algorithms](algorithms.md) for Beam Search, Self-Consistency, and other approaches
217 | - **Deploy as API**: See [IaaS Service Guide](iaas-service.md) to deploy as OpenAI-compatible service
218 | - **Contribute**: See [Development](development.md) for contribution guidelines
--------------------------------------------------------------------------------
/docs/benchmarking.md:
--------------------------------------------------------------------------------
1 | # Benchmarking
2 |
3 | its-hub includes comprehensive benchmarking tools to evaluate inference-time scaling algorithms on standard mathematical reasoning datasets.
4 |
5 | ## Quick Start
6 |
7 | ```bash
8 | python scripts/benchmark.py --help
9 | ```
10 |
11 | Example benchmark command:
12 | ```bash
13 | python scripts/benchmark.py \
14 | --benchmark aime-2024 \
15 | --model_name Qwen/Qwen2.5-Math-1.5B-Instruct \
16 | --alg particle-filtering \
17 | --rm_device cuda:1 \
18 | --endpoint http://0.0.0.0:8000/v1 \
19 | --shuffle_seed 1110 \
20 | --does_eval \
21 | --budgets 1,2,4,8,16,32,64 \
22 | --rm_agg_method model
23 | ```
24 |
25 | ## Supported Datasets
26 |
27 | ### MATH500
28 | A subset of 500 problems from the MATH dataset, covering various mathematical topics.
29 |
30 | ```bash
31 | python scripts/benchmark.py --benchmark math-500 --model_name Qwen/Qwen2.5-Math-1.5B-Instruct
32 | ```
33 |
34 | ### AIME-2024
35 | American Invitational Mathematics Examination problems from 2024.
36 |
37 | ```bash
38 | python scripts/benchmark.py --benchmark aime-2024 --model_name Qwen/Qwen2.5-Math-1.5B-Instruct
39 | ```
40 |
41 | ## Algorithm Comparison
42 |
43 | ### Benchmarking Multiple Algorithms
44 |
45 | ```bash
46 | # Compare all algorithms on MATH500
47 | for alg in self-consistency best-of-n beam-search particle-filtering; do
48 | python scripts/benchmark.py \
49 | --benchmark math-500 \
50 | --model_name Qwen/Qwen2.5-Math-1.5B-Instruct \
51 | --alg $alg \
52 | --budgets 1,2,4,8,16 \
53 | --does_eval
54 | done
55 | ```
56 |
57 | ### Budget Scaling Analysis
58 |
59 | ```bash
60 | # Analyze performance vs computational budget
61 | python scripts/benchmark.py \
62 | --benchmark math-500 \
63 | --model_name Qwen/Qwen2.5-Math-1.5B-Instruct \
64 | --alg particle-filtering \
65 | --budgets 1,2,4,8,16,32,64,128 \
66 | --does_eval
67 | ```
68 |
69 | ## Configuration Options
70 |
71 | ### Basic Parameters
72 |
73 | - `--benchmark`: Dataset to use (`math-500`, `aime-2024`)
74 | - `--model_name`: Model identifier (e.g., `Qwen/Qwen2.5-Math-1.5B-Instruct`)
75 | - `--alg`: Algorithm to benchmark (`self-consistency`, `best-of-n`, `beam-search`, `particle-filtering`)
76 | - `--budgets`: Comma-separated list of budget values
77 | - `--endpoint`: API endpoint for model inference
78 | - `--does_eval`: Enable automatic evaluation of results
79 |
80 | ### Advanced Parameters
81 |
82 | - `--shuffle_seed`: Seed for shuffling problems (reproducibility)
83 | - `--rm_device`: GPU device for reward model (e.g., `cuda:0`)
84 | - `--rm_agg_method`: Reward aggregation method (`prod`, `mean`, `model`)
85 | - `--beam_width`: Beam width for beam search (default: 4)
86 | - `--max_steps`: Maximum steps for step-by-step algorithms
87 | - `--step_token`: Token for step boundaries (default: `\\n\\n`)
88 | - `--stop_pattern`: Regex pattern for stopping generation
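89 | 
90 | For example, a beam-search run that sets the step-level options explicitly might look like this (a sketch that combines the flags listed above; adjust values for your hardware):
91 | 
92 | ```bash
93 | python scripts/benchmark.py \
94 |   --benchmark math-500 \
95 |   --model_name Qwen/Qwen2.5-Math-1.5B-Instruct \
96 |   --alg beam-search \
97 |   --beam_width 4 \
98 |   --max_steps 32 \
99 |   --budgets 8,16 \
100 |   --does_eval
101 | ```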
89 |
90 | ## Output Format
91 |
92 | ### Results Structure
93 |
94 | The benchmark script generates detailed results including:
95 |
96 | ```json
97 | {
98 | "algorithm": "particle-filtering",
99 | "dataset": "math-500",
100 | "model": "Qwen/Qwen2.5-Math-1.5B-Instruct",
101 | "budget": 8,
102 | "accuracy": 0.756,
103 | "total_problems": 500,
104 | "correct_answers": 378,
105 | "average_response_time": 12.34,
106 | "detailed_results": [
107 | {
108 | "problem_id": "001",
109 | "problem": "Solve x^2 + 5x + 6 = 0",
110 | "correct_answer": "x = -2, -3",
111 | "model_response": "...",
112 | "is_correct": true,
113 | "response_time": 8.21
114 | }
115 | ]
116 | }
117 | ```
118 |
119 | ### Evaluation Metrics
120 |
121 | - **Accuracy**: Percentage of correctly solved problems
122 | - **Response Time**: Average time per problem (seconds)
123 | - **Budget Efficiency**: Accuracy improvement per unit budget
124 | - **Error Analysis**: Breakdown of error types and frequencies
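125 | 
126 | As a sketch, budget efficiency can be read off the results JSON above by comparing accuracy between consecutive budgets (field names follow the example structure; the benchmark script may define the metric differently):
127 | 
128 | ```python
129 | def budget_efficiency(runs: list[dict]) -> list[tuple[int, float]]:
130 |     """Accuracy gain per additional unit of budget between consecutive runs."""
131 |     ordered = sorted(runs, key=lambda r: r["budget"])
132 |     return [
133 |         (hi["budget"], (hi["accuracy"] - lo["accuracy"]) / (hi["budget"] - lo["budget"]))
134 |         for lo, hi in zip(ordered, ordered[1:])
135 |     ]
136 | ```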
125 |
126 | ## Performance Analysis
127 |
128 | ### Plotting Results
129 |
130 | ```python
131 | import matplotlib.pyplot as plt
132 | import json
133 |
134 | # Load benchmark results
135 | with open('benchmark_results.json', 'r') as f:
136 | results = json.load(f)
137 |
138 | # Plot accuracy vs budget
139 | budgets = [r['budget'] for r in results]
140 | accuracies = [r['accuracy'] for r in results]
141 |
142 | plt.figure(figsize=(10, 6))
143 | plt.plot(budgets, accuracies, 'o-')
144 | plt.xlabel('Budget')
145 | plt.ylabel('Accuracy')
146 | plt.title('Accuracy vs Computational Budget')
147 | plt.grid(True)
148 | plt.show()
149 | ```
150 |
151 | ### Statistical Analysis
152 |
153 | ```python
154 | import numpy as np
155 | from scipy import stats
156 |
157 | # Compare two algorithms
158 | results_a = [r for r in results if r['algorithm'] == 'self-consistency']
159 | results_b = [r for r in results if r['algorithm'] == 'particle-filtering']
160 |
161 | accuracies_a = [r['accuracy'] for r in results_a]
162 | accuracies_b = [r['accuracy'] for r in results_b]
163 |
164 | # Perform t-test
165 | t_stat, p_value = stats.ttest_ind(accuracies_a, accuracies_b)
166 | print(f"T-test p-value: {p_value}")
167 | ```
168 |
169 | ## Custom Benchmarks
170 |
171 | ### Adding New Datasets
172 |
173 | ```python
174 | # Create custom dataset
175 | custom_problems = [
176 | {
177 | "id": "custom_001",
178 | "problem": "Your math problem here",
179 | "answer": "Expected answer",
180 | "category": "algebra"
181 | }
182 | ]
183 |
184 | # Save as JSON
185 | import json
186 | with open('custom_benchmark.json', 'w') as f:
187 | json.dump(custom_problems, f)
188 | ```
189 |
190 | ### Custom Evaluation Metrics
191 |
192 | ```python
193 | def custom_evaluator(predicted_answer, correct_answer):
194 | """Custom evaluation function"""
195 | # Implement your evaluation logic
196 | return predicted_answer.strip().lower() == correct_answer.strip().lower()
197 | ```
198 | 
199 | Then point the benchmark script at it:
200 | 
201 | ```bash
202 | python scripts/benchmark.py \
203 |   --benchmark custom_benchmark.json \
204 |   --custom_evaluator custom_evaluator
205 | ```
203 |
204 | ## Best Practices
205 |
206 | ### Reproducibility
207 |
208 | 1. **Set Random Seeds**: Use `--shuffle_seed` for consistent problem ordering
209 | 2. **Fixed Hyperparameters**: Document all configuration options
210 | 3. **Environment Tracking**: Record GPU type, driver versions, and dependencies
211 |
212 | ### Performance Optimization
213 |
214 | 1. **GPU Memory Management**: Monitor memory usage during benchmarks
215 | 2. **Batch Processing**: Use appropriate batch sizes for your hardware
216 | 3. **Caching**: Enable model caching for faster repeated evaluations
217 |
218 | ### Result Validation
219 |
220 | 1. **Cross-Validation**: Run multiple seeds and average results
221 | 2. **Significance Testing**: Use statistical tests to validate improvements
222 | 3. **Human Evaluation**: Manually verify a sample of results
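223 | 
224 | Multi-seed runs can be scripted with the flags shown earlier (a sketch; average the reported accuracies across seeds afterwards):
225 | 
226 | ```bash
227 | for seed in 0 1 2; do
228 |   python scripts/benchmark.py \
229 |     --benchmark math-500 \
230 |     --model_name Qwen/Qwen2.5-Math-1.5B-Instruct \
231 |     --alg particle-filtering \
232 |     --budgets 8 \
233 |     --shuffle_seed $seed \
234 |     --does_eval
235 | done
236 | ```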
223 |
224 | ## Troubleshooting
225 |
226 | ### Common Issues
227 |
228 | **Out of Memory Errors:**
229 | ```bash
230 | # Reduce batch size or budget
231 | python scripts/benchmark.py --budgets 1,2,4 --rm_device cuda:0
232 | ```
233 |
234 | **Slow Evaluation:**
235 | ```bash
236 | # Disable evaluation for faster benchmarking
237 | python scripts/benchmark.py --no_eval
238 | ```
239 |
240 | **Model Loading Issues:**
241 | ```bash
242 | # Verify model availability
243 | curl http://localhost:8000/v1/models
244 | ```
245 |
246 | ### Performance Monitoring
247 |
248 | ```bash
249 | # Monitor GPU usage during benchmarking
250 | watch -n 1 nvidia-smi
251 |
252 | # Monitor system resources
253 | htop
254 | ```
--------------------------------------------------------------------------------
/tests/test_particle_gibbs_resampling.py:
--------------------------------------------------------------------------------
1 | """Test for the particle Gibbs resampling weight calculation fix (issue #54)."""
2 |
3 | import random
4 |
5 | import numpy as np
6 |
7 | from its_hub.algorithms.particle_gibbs import (
8 | Particle,
9 | ParticleGibbs,
10 | ParticleGibbsResult,
11 | SelectionMethod,
12 | )
13 | from its_hub.base import AbstractLanguageModel, AbstractProcessRewardModel
14 | from its_hub.lms import StepGeneration
15 |
16 |
17 | class MockLanguageModelForResampling(AbstractLanguageModel):
18 | """Mock LM that generates predictable steps for testing resampling."""
19 |
20 | def __init__(self):
21 | self.step_counter = 0
22 |
23 | async def agenerate(self, messages, **kwargs):
24 | return self.generate(messages, **kwargs)
25 |
26 | def generate(self, messages, max_tokens=100, **kwargs):
27 | # Handle both single and batch calls like OpenAICompatibleLanguageModel
28 | if (
29 | isinstance(messages, list)
30 | and len(messages) > 0
31 | and isinstance(messages[0], list)
32 | ):
33 | # Batch generation
34 | results = []
35 | for _ in messages:
36 | step = f"step{self.step_counter}"
37 | self.step_counter += 1
38 | results.append({"role": "assistant", "content": step})
39 | return results
40 | else:
41 | # Single generation
42 | step = f"step{self.step_counter}"
43 | self.step_counter += 1
44 | return {"role": "assistant", "content": step}
45 |
46 | async def aevaluate(self, prompt, response):
47 | return self.evaluate(prompt, response)
48 |
49 | def evaluate(self, prompt, response):
50 | # Not used in these tests
51 | return 0.5
52 |
53 |
54 | class MockProcessRewardModelForResampling(AbstractProcessRewardModel):
55 | """Mock PRM that gives higher scores to longer sequences."""
56 |
57 | async def ascore(self, prompt, response):
58 | return self.score(prompt, response)
59 |
60 | def score(self, prompt, response):
61 | if isinstance(response, list):
62 | # Batch scoring
63 | return [self._score_single(r) for r in response]
64 | else:
65 | # Single scoring
66 | return self._score_single(response)
67 |
68 | def _score_single(self, response):
69 | # Give higher scores to longer responses
70 | # This simulates a scenario where reference particles (being longer)
71 | # would have unfairly high scores if we don't use partial weights
72 | num_steps = response.count("step")
73 | # Return a score between 0.5 and 0.9 based on length
74 | return min(0.5 + 0.1 * num_steps, 0.9)
75 |
76 |
77 | class TestParticleGibbsResampling:
78 | """Test the fix for issue #54 - proper calculation of resampling weights."""
79 |
80 | def test_reference_trajectory_partial_weights(self):
81 | """Test that reference trajectories use partial weights during resampling."""
82 | # Create mock models
83 | mock_lm = MockLanguageModelForResampling()
84 | mock_prm = MockProcessRewardModelForResampling()
85 |
86 | # Create step generation with 3 max steps
87 | sg = StepGeneration(step_token="\n", max_steps=3)
88 |
89 | # Create ParticleGibbs with 2 iterations and 1 reference particle
90 | pg = ParticleGibbs(
91 | sg=sg,
92 | prm=mock_prm,
93 | num_iterations=2,
94 | selection_method=SelectionMethod.ARGMAX,
95 | num_ref_particles=1,
96 | )
97 |
98 | # Run inference with budget=4 (2 particles per iteration)
99 | result = pg.infer(mock_lm, "Test prompt", budget=4, return_response_only=False)
100 |
101 | # Verify the result structure
102 | assert isinstance(result, ParticleGibbsResult)
103 | assert len(result.responses_lst) == 2 # 2 iterations
104 | assert len(result.log_weights_lst) == 2
105 | assert len(result.ref_indices_lst) == 2
106 |
107 | # In the second iteration, check that particles have consistent trajectory lengths
108 | # after resampling (this would fail with the old implementation)
109 | second_iter_steps = result.steps_used_lst[1]
110 |
111 | # All particles should have progressed similarly
112 | # With the bug, reference particle would keep its full trajectory
113 | # making it have more steps than others
114 | assert len(set(second_iter_steps)) <= 2, (
115 | f"Particles have very different step counts: {second_iter_steps}. "
116 | "This suggests reference particles weren't truncated properly."
117 | )
118 |
119 | def test_resampling_weights_consistency(self):
120 | """Test that resampling uses consistent weights across particles."""
121 | # Create a custom mock that allows us to inspect the resampling process
122 | mock_lm = MockLanguageModelForResampling()
123 | mock_prm = MockProcessRewardModelForResampling()
124 |
125 | sg = StepGeneration(step_token="\n", max_steps=5)
126 |
127 | # Create a test scenario with manual particle manipulation
128 | # Initialize particles with different trajectories
129 | ref_particle = Particle(
130 | steps=["ref_step1", "ref_step2", "ref_step3"],
131 | is_stopped=True,
132 | partial_log_weights=[0.6, 0.7, 0.8],
133 | )
134 |
135 | new_particle = Particle(
136 | steps=["new_step1"], is_stopped=False, partial_log_weights=[0.65]
137 | )
138 |
139 | # Test the _propagate method
140 | pg = ParticleGibbs(
141 | sg=sg,
142 | prm=mock_prm,
143 | num_iterations=1,
144 | selection_method=SelectionMethod.ARGMAX,
145 | num_ref_particles=1,
146 | )
147 |
148 | # Propagate one step
149 | particles = [new_particle, ref_particle]
150 | propagated = pg._propagate(mock_lm, particles, "Test prompt")
151 |
152 | # After propagation, only non-stopped particles get extended
153 | assert (
154 | len(propagated[0].partial_log_weights) == 2
155 | ) # new particle now has 2 steps
156 | assert (
157 | len(propagated[1].partial_log_weights) == 3
158 | ) # ref particle stays at 3 steps (stopped)
159 |
160 | # The key insight: during resampling in the main loop,
161 | # we should compare partial_log_weights[1] for both particles,
162 | # not the full log_weight of the reference particle
163 |
164 | def test_reference_particle_truncation(self):
165 | """Test that reference particles are properly truncated when resampled."""
166 | mock_lm = MockLanguageModelForResampling()
167 | mock_prm = MockProcessRewardModelForResampling()
168 |
169 | sg = StepGeneration(step_token="\n", max_steps=4)
170 |
171 | # Use a controlled random seed for reproducibility
172 | random.seed(42)
173 | np.random.seed(42)
174 |
175 | pg = ParticleGibbs(
176 | sg=sg,
177 | prm=mock_prm,
178 | num_iterations=2,
179 | selection_method=SelectionMethod.ARGMAX,
180 | num_ref_particles=1,
181 | )
182 |
183 | result = pg.infer(mock_lm, "Test prompt", budget=6, return_response_only=False)
184 |
185 | # Check that in the second iteration, particles don't have
186 | # drastically different numbers of steps
187 | second_iter_steps = result.steps_used_lst[1]
188 |
189 | # The maximum difference in steps should be reasonable
190 | # (not the full trajectory length difference)
191 | max_diff = max(second_iter_steps) - min(second_iter_steps)
192 | assert max_diff <= 2, (
193 | f"Large step difference in second iteration: {second_iter_steps}. "
194 | "Reference particle may not have been truncated."
195 | )
196 |
--------------------------------------------------------------------------------
/notebooks/self-consistency.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "98e1f27d",
6 | "metadata": {},
7 | "source": [
8 | "# Self-Consistency Algorithm Demo\n",
9 | "This notebook demonstrates the Self-Consistency algorithm for mathematical reasoning."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "id": "be6eaec0",
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "%load_ext autoreload\n",
20 | "%autoreload 2"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "id": "b6e4fabc",
27 | "metadata": {
28 | "lines_to_next_cell": 0
29 | },
30 | "outputs": [],
31 | "source": [
32 | "import os\n",
33 | "\n",
34 | "import nest_asyncio\n",
35 | "from dotenv import load_dotenv\n",
36 | "\n",
37 | "from its_hub.lms import OpenAICompatibleLanguageModel\n",
38 | "from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT\n",
39 | "\n",
40 | "nest_asyncio.apply()\n",
41 | "\n",
42 | "# Load environment variables from .env file\n",
43 | "load_dotenv()\n",
44 | "\n",
45 | "# Main example: OpenAI API endpoint with gpt-4o-mini\n",
46 | "lm = OpenAICompatibleLanguageModel(\n",
47 | " endpoint=\"https://api.openai.com/v1\",\n",
48 | " api_key=os.getenv(\"OPENAI_API_KEY\"), # Load API key from environment\n",
49 | " model_name=\"gpt-4o-mini\",\n",
50 | " system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,\n",
51 | " is_async=True,\n",
52 | ")"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "id": "ed10ced9",
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "# Alternative: vLLM local endpoint (commented out)\n",
63 | "# lm = OpenAICompatibleLanguageModel(\n",
64 | "# endpoint=\"http://localhost:8000/v1\",\n",
65 | "# api_key=\"NO_API_KEY\",\n",
66 | "# model_name=\"qwen2-math-1.5b-instruct\",\n",
67 | "# system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,\n",
68 | "# is_async=True,\n",
69 | "# )"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "id": "35d3f4a5",
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "# Mathematical problem to solve\n",
80 | "prompt = r\"Let $a$ be a positive real number such that all the roots of \\[x^3 + ax^2 + ax + 1 = 0\\]are real. Find the smallest possible value of $a.$\"\n",
81 | "\n",
82 | "# Generate response using the proper format\n",
83 | "from its_hub.types import ChatMessages\n",
84 | "\n",
85 | "chat_messages = ChatMessages.from_prompt_or_messages(prompt)\n",
86 | "response = lm.generate(chat_messages.to_batch(1))[0]\n",
87 | "\n",
88 | "print(response)"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "id": "53386b1e",
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "def extract_boxed(s: str) -> str:\n",
99 | " import re\n",
100 | " # find all occurrences of \\boxed{...}\n",
101 | " boxed_matches = re.findall(r'\\\\boxed\\{([^{}]+(?:\\{[^{}]*\\}[^{}]*)*)\\}', s)\n",
102 | " # return the last match if any were found\n",
103 | " return boxed_matches[-1] if boxed_matches else \"\"\n",
104 | "\n",
105 | "print(extract_boxed(response['content']))"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "id": "7bd4a72e",
111 | "metadata": {},
112 | "source": [
113 | "## Self-Consistency Algorithm\n",
114 | "Now we'll use the Self-Consistency algorithm to improve the answer quality."
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "id": "1a3ab056",
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "from its_hub.algorithms import SelfConsistency\n",
125 | "\n",
126 | "# Set computational budget for scaling\n",
127 | "budget = 4\n",
128 | "\n",
129 | "scaling_alg = SelfConsistency(extract_boxed)\n",
130 | "\n",
131 | "scaling_result = scaling_alg.infer(\n",
132 | " lm, prompt, budget, return_response_only=False\n",
133 | ")\n",
134 | "\n",
135 | "print(\"######## Self-Consistency Result ########\")\n",
136 | "print(scaling_result.the_one)"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "id": "a12470e4",
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "print(\"######## Extracted Response Counts ########\")\n",
147 | "print(scaling_result.response_counts)"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "id": "1b960ab8",
154 | "metadata": {
155 | "lines_to_next_cell": 2
156 | },
157 | "outputs": [],
158 | "source": []
159 | },
160 | {
161 | "cell_type": "markdown",
162 | "id": "97d69191",
163 | "metadata": {},
164 | "source": [
165 | "## Self-Consistency Algorithm for Tool Calls\n",
166 |     "The Self-Consistency algorithm includes hierarchical tool-voting support.\n",
167 | "It first votes on tool names, and then on tool arguments."
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "id": "b02c7de2",
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "from its_hub.types import ChatMessage, ChatMessages\n",
178 | "\n",
179 | "# Tool schema (OpenAI-style dicts)\n",
180 | "tools = [\n",
181 | " {\n",
182 | " \"type\": \"function\",\n",
183 | " \"function\": {\n",
184 | " \"name\": \"calculator\",\n",
185 | " \"description\": \"Perform arithmetic calculations\",\n",
186 | " \"parameters\": {\n",
187 | " \"type\": \"object\",\n",
188 | " \"properties\": {\n",
189 | " \"expression\": {\n",
190 | " \"type\": \"string\",\n",
191 | " \"description\": \"Mathematical expression to evaluate\"\n",
192 | " }\n",
193 | " },\n",
194 | " \"required\": [\"expression\"]\n",
195 | " }\n",
196 | " }\n",
197 | " }\n",
198 | "]\n",
199 | "\n",
200 | "# ChatMessages instance with system + user\n",
201 | "tool_call_messages = ChatMessages([\n",
202 | " ChatMessage(\n",
203 | " role=\"system\",\n",
204 | " content=\"You are a precise calculator. Always use the calculator tool for arithmetic and format your final answer as \\\\boxed{result}.\"\n",
205 | " ),\n",
206 | " ChatMessage(\n",
207 | " role=\"user\",\n",
208 | " content=\"What is 847 * 293 + 156?\"\n",
209 | " ),\n",
210 | "])"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": null,
216 | "id": "5be77c52",
217 | "metadata": {},
218 | "outputs": [],
219 | "source": [
220 | "# Use hierarchical tool voting\n",
221 | "scaling_alg_tool = SelfConsistency(tool_vote=\"tool_hierarchical\")\n",
222 | "\n",
223 | "budget = 5\n",
224 | "scaling_result = scaling_alg_tool.infer(\n",
225 | " lm, tool_call_messages, budget, return_response_only=False, tools=tools, tool_choice=\"auto\"\n",
226 | ")"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": null,
232 | "id": "ac526d23",
233 | "metadata": {
234 | "lines_to_next_cell": 2
235 | },
236 | "outputs": [],
237 | "source": [
238 | "print(\"######## Self-Consistency Result ########\")\n",
239 | "print(scaling_result.the_one)\n",
240 | "\n",
241 | "print(\"######## Tool Call Response Counts ########\")\n",
242 | "print(scaling_result.response_counts)"
243 | ]
244 | }
245 | ],
246 | "metadata": {
247 | "jupytext": {
248 | "cell_metadata_filter": "-all",
249 | "notebook_metadata_filter": "all"
250 | },
251 | "kernelspec": {
252 | "display_name": "inference_time_scaling-dev",
253 | "language": "python",
254 | "name": "python3"
255 | },
256 | "language_info": {
257 | "codemirror_mode": {
258 | "name": "ipython",
259 | "version": 3
260 | },
261 | "file_extension": ".py",
262 | "mimetype": "text/x-python",
263 | "name": "python",
264 | "nbconvert_exporter": "python",
265 | "pygments_lexer": "ipython3",
266 | "version": "3.11.11"
267 | }
268 | },
269 | "nbformat": 4,
270 | "nbformat_minor": 5
271 | }
272 |
--------------------------------------------------------------------------------
/docs/algorithms.md:
--------------------------------------------------------------------------------
1 | # Algorithms
2 |
3 | its-hub provides several inference-time scaling algorithms, each optimized for different use cases and computational budgets.
4 |
5 | ## Overview
6 |
7 | All algorithms follow the same interface: `infer(lm, prompt, budget, return_response_only=True)`
8 |
9 | The `budget` parameter controls computational resources allocated to each algorithm, with different interpretations:
10 |
11 | | Algorithm | Budget Interpretation | Snippet |
12 | |-----------|----------------------|---------|
13 | | Self-Consistency | Number of parallel generations | `SelfConsistency()` |
14 | | Best-of-N | Number of candidate responses | `BestOfN(rm)` |
15 | | Beam Search | Total generations ÷ beam width | `BeamSearch(sg, prm, beam_width=4)` |
16 | | Particle Filtering | Number of particles | `ParticleFiltering(sg, prm)` |
17 | | Entropic Particle Filtering | Number of particles | `EntropicParticleFiltering(sg, prm)` |
18 |
19 | ## Self-Consistency
20 |
21 | Generates multiple responses and selects the most common answer through voting. **Especially powerful for tool-calling** where you want consistent tool usage patterns.
22 |
23 | ### Tool Calling Example (Recommended)
24 |
25 | ```python
26 | from its_hub.algorithms import SelfConsistency
27 | from its_hub.types import ChatMessage, ChatMessages
28 | from its_hub.lms import OpenAICompatibleLanguageModel
29 |
30 | # Initialize language model
31 | lm = OpenAICompatibleLanguageModel(
32 | endpoint="https://api.openai.com/v1",
33 | api_key="your-api-key",
34 | model_name="gpt-4o-mini"
35 | )
36 |
37 | # Define tools (OpenAI format)
38 | tools = [
39 | {
40 | "type": "function",
41 | "function": {
42 | "name": "calculator",
43 | "description": "Perform arithmetic calculations",
44 | "parameters": {
45 | "type": "object",
46 | "properties": {
47 | "expression": {
48 | "type": "string",
49 | "description": "Mathematical expression to evaluate"
50 | }
51 | },
52 | "required": ["expression"]
53 | }
54 | }
55 | }
56 | ]
57 |
58 | # Create messages
59 | messages = ChatMessages([
60 | ChatMessage(
61 | role="system",
62 | content="You are a precise calculator. Always use the calculator tool for arithmetic."
63 | ),
64 | ChatMessage(
65 | role="user",
66 | content="What is 847 * 293 + 156?"
67 | )
68 | ])
69 |
70 | # Use hierarchical tool voting
71 | sc = SelfConsistency(tool_vote="tool_hierarchical")
72 | result = sc.infer(
73 | lm,
74 | messages,
75 | budget=5,
76 | tools=tools,
77 | tool_choice="auto"
78 | )
79 | print(result)
80 | ```
81 |
82 | **Tool voting modes:**
83 | - `"tool_name"`: Vote on which tool to call
84 | - `"tool_args"`: Vote on tool arguments
85 | - `"tool_hierarchical"` (recommended): First vote on tool name, then on arguments
86 | - `exclude_args=["timestamp", "id"]`: Exclude non-semantic arguments from voting
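87 | 
88 | For instance, combining hierarchical voting with argument exclusion might look like this (a sketch; it assumes `exclude_args` is accepted by the `SelfConsistency` constructor as the bullet above suggests):
89 | 
90 | ```python
91 | # vote on the tool name first, then on arguments, ignoring non-semantic fields
92 | sc = SelfConsistency(tool_vote="tool_hierarchical", exclude_args=["timestamp", "id"])
93 | result = sc.infer(lm, messages, budget=5, tools=tools, tool_choice="auto")
94 | ```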
87 |
88 | ### Text-Based Example
89 |
90 | ```python
91 | # For mathematical problems with regex extraction
92 | def extract_boxed(text):
93 | import re
94 | matches = re.findall(r'\\boxed\{([^{}]+)\}', text)
95 | return matches[-1] if matches else ""
96 |
97 | sc = SelfConsistency(projection_function=extract_boxed)
98 | result = sc.infer(lm, "Solve x^2 + 5x + 6 = 0", budget=4)
99 | ```
100 |
101 | **When to use:**
102 | - Tool-calling applications (agents, function calling)
103 | - Mathematical problems with clear final answers
104 | - Tasks where multiple reasoning approaches are valid
105 | - When you need fast inference with improved accuracy
106 |
107 | ## Best-of-N
108 |
109 | Generates N candidate responses and selects the highest-scoring one using a reward model. **Works with both text and tool-calling responses.**
110 |
111 | ### With LLM Judge (Cloud APIs)
112 |
113 | ```python
114 | from its_hub.algorithms import BestOfN
115 | from its_hub.integration.reward_hub import LLMJudgeRewardModel
116 | from its_hub.lms import OpenAICompatibleLanguageModel
117 |
118 | # Initialize language model
119 | lm = OpenAICompatibleLanguageModel(
120 | endpoint="https://api.openai.com/v1",
121 | api_key="your-api-key",
122 | model_name="gpt-4o-mini"
123 | )
124 |
125 | # Set up LLM judge for scoring
126 | judge = LLMJudgeRewardModel(
127 | model="gpt-4o-mini",
128 | criterion="multi_step_tool_judge", # For tool-calling tasks
129 | judge_type="groupwise",
130 | api_key="your-api-key"
131 | )
132 |
133 | # Best-of-N with LLM judge
134 | bon = BestOfN(judge)
135 |
136 | # Works with tool calls
137 | tools = [{"type": "function", "function": {...}}]
138 | result = bon.infer(
139 | lm,
140 | messages,
141 | budget=4,
142 | tools=tools,
143 | tool_choice="auto"
144 | )
145 | ```
146 |
147 | ### With Local Process Reward Model
148 |
149 | ```python
150 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel
151 |
152 | # Initialize reward model (requires GPU)
153 | prm = LocalVllmProcessRewardModel(
154 | model_name="Qwen/Qwen2.5-Math-PRM-7B",
155 | device="cuda:0",
156 | aggregation_method="prod"
157 | )
158 |
159 | bon = BestOfN(prm)
160 | result = bon.infer(lm, prompt, budget=16)
161 | ```
162 |
163 | **When to use:**
164 | - Tool-calling applications where quality matters most
165 | - When you have a reliable reward model
166 | - Quality is more important than speed
167 | - Tasks where ranking responses is straightforward
168 |
169 | ## Beam Search
170 |
171 | Performs step-by-step generation with beam width control, using process reward models to guide the search.
172 |
173 | ```python
174 | from its_hub.algorithms import BeamSearch
175 | from its_hub.lms import StepGeneration
176 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel
177 |
178 | # Initialize components
179 | sg = StepGeneration("\n\n", max_steps=32, stop_pattern=r"\boxed")
180 | prm = LocalVllmProcessRewardModel(
181 | model_name="Qwen/Qwen2.5-Math-PRM-7B",
182 | device="cuda:0",
183 | aggregation_method="prod"
184 | )
185 |
186 | # Beam search with beam width of 4
187 | beam_search = BeamSearch(sg, prm, beam_width=4)
188 | result = beam_search.infer(lm, prompt, budget=32) # 32 total generations
189 | ```
190 |
191 | **Budget calculation:** `budget = beam_width × number_of_steps` (e.g., `budget=32` with `beam_width=4` corresponds to 8 generation steps)
192 |
193 | **When to use:**
194 | - Step-by-step reasoning problems
195 | - When you can evaluate partial solutions
196 | - Mathematical proofs or derivations
197 |
198 | ## Particle Filtering
199 |
200 | Uses probabilistic resampling to maintain diverse reasoning paths while focusing on promising directions.
201 |
202 | ```python
203 | from its_hub.algorithms import ParticleFiltering
204 | from its_hub.lms import StepGeneration
205 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel
206 |
207 | # Initialize components
208 | sg = StepGeneration("\n\n", max_steps=32, stop_pattern=r"\boxed")
209 | prm = LocalVllmProcessRewardModel(
210 | model_name="Qwen/Qwen2.5-Math-PRM-7B",
211 | device="cuda:0",
212 | aggregation_method="prod"
213 | )
214 |
215 | # Particle filtering with 8 particles
216 | pf = ParticleFiltering(sg, prm)
217 | result = pf.infer(lm, prompt, budget=8)
218 | ```
219 |
220 | **When to use:**
221 | - Complex reasoning tasks with multiple valid approaches
222 | - When exploration vs exploitation balance is important
223 | - Mathematical problem solving with uncertainty
224 |
225 |
226 | ## Entropic Particle Filtering
227 |
228 | Entropic Particle Filtering (ePF) is an advanced sampling algorithm that mitigates common failure modes of standard particle filtering, such as particle degeneracy and impoverishment.
229 | By using Entropic Annealing (EA) to control the variance of the resampling distribution, ePF ensures more robust and thorough exploration during the early phase of sampling, especially for long sequences and multi-step tasks.
230 |
231 | ```python
232 | from its_hub.algorithms import EntropicParticleFiltering
233 | from its_hub.lms import StepGeneration
234 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel
235 |
236 | # Initialize components
237 | sg = StepGeneration("\n\n", max_steps=32, stop_pattern=r"\boxed")
238 | prm = LocalVllmProcessRewardModel(
239 | model_name="Qwen/Qwen2.5-Math-PRM-7B",
240 | device="cuda:0",
241 | aggregation_method="prod"
242 | )
243 |
244 | # Entropic particle filtering with 8 particles
245 | epf = EntropicParticleFiltering(sg, prm)
246 | result = epf.infer(lm, prompt, budget=8)
247 | ```
248 |
249 | **When to use:**
250 | - **Hard-to-calibrate reward models**: if your Process Reward Model (PRM) tends to be overconfident early in the sampling process, ePF keeps a wider range of options open for longer.
251 | - **Complex, long multi-step tasks**: when a problem requires many sequential steps to solve (> 20 steps), standard particle filters can lose diversity and generate greedy-like solutions; ePF is designed to handle these long-horizon tasks more effectively.
252 | - **Avoiding early convergence**: if a standard filter produces short, incomplete responses or underperforms, it is likely converging prematurely; ePF directly counteracts this by promoting particle diversity.
258 |
259 |
260 | ## Advanced Configuration
261 |
262 | ### Step Generation
263 |
264 | The `StepGeneration` class enables incremental text generation:
265 |
266 | ```python
267 | from its_hub.lms import StepGeneration
268 |
269 | # For math problems with boxed answers
270 | sg = StepGeneration(
271 | step_token="\n\n", # Split reasoning into steps
272 | max_steps=32, # Maximum number of steps
273 | stop_pattern=r"\boxed", # Stop when final answer is found
274 | post_process=True # Clean up output formatting
275 | )
276 | ```
277 |
278 | ### Reward Models
279 |
280 | #### Process Reward Models
281 | Evaluate reasoning steps incrementally:
282 |
283 | ```python
284 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel
285 |
286 | prm = LocalVllmProcessRewardModel(
287 | model_name="Qwen/Qwen2.5-Math-PRM-7B",
288 | device="cuda:0",
289 | aggregation_method="prod" # or "mean", "min", "max"
290 | )
291 | ```
292 |
293 | #### Outcome Reward Models
294 | Evaluate final answers only:
295 |
296 | ```python
297 | # Custom outcome reward model (illustrative sketch: exact-match against a known answer)
298 | import re
299 | 
300 | class MathOutcomeRewardModel:
301 |     def __init__(self, expected_answer: str):
302 |         self.expected_answer = expected_answer
303 | 
304 |     def score(self, prompt: str, response: str) -> float:
305 |         # extract the final \boxed{...} answer and reward exact matches
306 |         matches = re.findall(r"\\boxed\{([^{}]+)\}", response)
307 |         return 1.0 if matches and matches[-1].strip() == self.expected_answer else 0.0
308 | ```
303 |
304 | ## Performance Tips
305 |
306 | 1. **Start with Self-Consistency** for quick improvements
307 | 2. **Use Best-of-N** when you have a good reward model
308 | 3. **Try Beam Search** for step-by-step reasoning
309 | 4. **Use Particle Filtering** for the most complex problems
310 | 5. **Use Entropic Particle Filtering** to mitigate early exploitation
311 | 6. **Adjust budget** based on problem complexity and time constraints
312 | 7. **Monitor GPU memory** when using large reward models
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/its_hub/integration/reward_hub.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 |
4 | from reward_hub.base import AggregationMethod
5 | from reward_hub.llm_judge import create_groupwise_judge, create_pointwise_judge
6 |
7 | from its_hub.base import AbstractOutcomeRewardModel, AbstractProcessRewardModel
8 | from its_hub.types import ChatMessage, ChatMessages
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class LocalVllmProcessRewardModel(AbstractProcessRewardModel):
14 | def __init__(
15 | self, model_name: str, device: str, aggregation_method: AggregationMethod
16 | ):
17 | from reward_hub.vllm.reward import VllmProcessRewardModel
18 |
19 | self.model = VllmProcessRewardModel(model_name=model_name, device=device)
20 | self.aggregation_method = aggregation_method
21 |
22 | async def ascore(
23 | self,
24 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
25 | response_or_responses: str | list[str],
26 | ) -> float | list[float]:
27 | """score response(s) asynchronously"""
28 | import asyncio
29 |
30 | chat_messages = ChatMessages.from_prompt_or_messages(prompt_or_messages)
31 |
32 | is_single_response = isinstance(response_or_responses, str)
33 | responses = (
34 | [response_or_responses] if is_single_response else response_or_responses
35 | )
36 |
37 | # Build conversation messages with responses
38 | base_msgs = [
39 | ChatMessage(role="user", content=f"System: {msg.extract_text_content()}")
40 | if msg.role == "system"
41 | else msg
42 | for msg in chat_messages.to_chat_messages()
43 | ]
44 | messages = [
45 | [
46 | *[
47 | {"role": msg.role, "content": msg.extract_text_content()}
48 | for msg in base_msgs
49 | ],
50 | {"role": "assistant", "content": response},
51 | ]
52 | for response in responses
53 | ]
54 |
55 | # Run in thread to avoid blocking event loop
56 | res = await asyncio.to_thread(
57 | self.model.score,
58 | messages=messages,
59 | aggregation_method=self.aggregation_method,
60 | return_full_prm_result=False,
61 | )
62 | return res[0] if is_single_response else res
63 |
64 | def score(
65 | self,
66 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
67 | response_or_responses: str | list[str],
68 | ) -> float | list[float]:
69 | """score response(s) synchronously"""
70 | import asyncio
71 |
72 | return asyncio.run(self.ascore(prompt_or_messages, response_or_responses))
73 |
74 |
75 | class LLMJudgeRewardModel(AbstractOutcomeRewardModel):
76 | """
77 | Adapter for reward_hub's LLM Judge models to work with its_hub's AbstractOutcomeRewardModel interface.
78 |
79 | This class wraps reward_hub's PointwiseJudgeModel to make it compatible with its_hub's
80 | prompt/response format and can be used with algorithms like Best-of-N.
81 | """
82 |
83 | def __init__(
84 | self,
85 | model: str,
86 | criterion: str,
87 | judge_type: str = "groupwise",
88 | api_key: str | None = None,
89 | base_url: str | None = None,
90 | temperature: float = 0.7,
91 | max_tokens: int = 4096,
92 | enable_judge_logging: bool = True,
93 | top_n: int = 1,
94 | **litellm_kwargs,
95 | ):
96 | """
97 | Initialize LLM Judge reward model.
98 |
99 | Args:
100 | model: LiteLLM model name (e.g., "gpt-4o-mini", "claude-3-sonnet-20240229")
101 |             criterion: Evaluation criterion from CriterionRegistry.
102 | Built-in options: overall_quality, writing_quality, technical_quality,
103 | relevance_quality, tool-judge
104 | judge_type: Type of judge - "pointwise" or "groupwise" (default: "groupwise")
105 | api_key: API key for the model provider
106 | base_url: Base URL for custom endpoints
107 | temperature: Temperature for judge generation (0.0 for deterministic)
108 | max_tokens: Maximum tokens for judge response
109 | enable_judge_logging: If True, log judge scores and reasoning (default: True)
110 | top_n: For groupwise judges, number of top responses to select (default: 1)
111 | **litellm_kwargs: Additional arguments passed to LiteLLM
112 | """
113 |
114 | if judge_type == "pointwise":
115 | self.judge = create_pointwise_judge(
116 | model=model,
117 | criterion=criterion,
118 | api_key=api_key,
119 | base_url=base_url,
120 | temperature=temperature,
121 | max_tokens=max_tokens,
122 | **litellm_kwargs,
123 | )
124 | elif judge_type == "groupwise":
125 | self.judge = create_groupwise_judge(
126 | model=model,
127 | criterion=criterion,
128 | api_key=api_key,
129 | base_url=base_url,
130 | temperature=temperature,
131 | max_tokens=max_tokens,
132 | **litellm_kwargs,
133 | )
134 | else:
135 | raise ValueError(
136 | f"Invalid judge type: {judge_type}. Must be 'pointwise' or 'groupwise'."
137 | )
138 |
139 | self.judge_type = judge_type
140 | self.criterion = criterion
141 | self.model = model
142 | self.top_n = top_n
143 | self.enable_judge_logging = enable_judge_logging
144 |
145 | def score(
146 | self,
147 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
148 | response: str | list[str],
149 | ) -> float | list[float]:
150 | """
151 | Score response(s) using the LLM judge.
152 |
153 | Args:
154 | prompt_or_messages: The prompt or conversation context
155 | response: The response(s) to evaluate (single string or list of strings)
156 |
157 | Returns:
158 | - For single response: float score
159 | - For multiple responses: list[float] scores
160 |
161 | Note:
162 | - If enable_judge_logging=True, the judge's reasoning is logged internally
163 | along with the scores from the JudgeResult object.
164 | - For pointwise judges: logs individual scores and reasoning per response
165 | - For groupwise judges: logs ranking reasoning and top-N selection with binary scores
166 | """
167 | # Use async version
168 | import asyncio
169 |
170 | return asyncio.run(self.ascore(prompt_or_messages, response))
171 |
172 | async def ascore(
173 | self,
174 | prompt_or_messages: str | list[ChatMessage] | ChatMessages,
175 | response: str | list[str],
176 | ) -> float | list[float]:
177 | """
178 | Score response(s) asynchronously using the LLM judge.
179 |
180 | Args:
181 | prompt_or_messages: The prompt or conversation context
182 | response: The response(s) to evaluate (single string or list of strings)
183 |
184 | Returns:
185 | - For single response: float score
186 | - For multiple responses: list[float] scores
187 |
188 | Note:
189 | - If enable_judge_logging=True, the judge's reasoning is logged internally
190 | along with the scores from the JudgeResult object.
191 | - For pointwise judges: logs individual scores and reasoning per response
192 | - For groupwise judges: logs ranking reasoning and top-N selection with binary scores
193 | """
194 | # Convert to ChatMessages format
195 | chat_messages = ChatMessages.from_prompt_or_messages(prompt_or_messages)
196 |
197 | # Build base conversation in OpenAI format
198 | base_messages = [
199 | {
200 | "role": msg.role,
201 | "content": msg.extract_text_content(),
202 | } # Reward Hub expects content to be a string
203 | for msg in chat_messages.to_chat_messages()
204 | ]
205 |
206 | # Handle both single response and batch of responses
207 | is_single_response = isinstance(response, str)
208 | responses = [response] if is_single_response else response
209 |
210 | # Build complete conversations (base + each response)
211 | conversations = [
212 | base_messages + [{"role": "assistant", "content": resp}]
213 | for resp in responses
214 | ]
215 |
216 | # Call judge with multiple conversations
217 | # Judge expects List[List[dict]] for multiple conversations
218 |
219 | if self.judge_type == "groupwise":
220 | judge_result = await self.judge.ascore(
221 | conversations,
222 | return_judge_reasoning=self.enable_judge_logging,
223 | top_n=self.top_n,
224 | )
225 | else:
226 | judge_result = await self.judge.ascore(
227 | conversations, return_judge_reasoning=self.enable_judge_logging
228 | )
229 |
230 | # Log judge results if enabled
231 | if self.enable_judge_logging and judge_result.reasonings:
232 | if self.judge_type == "pointwise":
233 | # Pointwise: log each response's individual score and reasoning
234 | for i, (score, reasoning, response) in enumerate(
235 | zip(judge_result.scores, judge_result.reasonings, responses)
236 | ):
237 | extra_data = {
238 | "judge_type": "pointwise",
239 | "response_index": i,
240 | "score": score,
241 | "reasoning": reasoning,
242 | "response_preview": response[:300] + "..."
243 | if len(response) > 300
244 | else response,
245 | "criterion": self.criterion,
246 | "model": self.model,
247 | }
248 | logger.info(
249 | f"Pointwise Judge Result for response {i}:\n{json.dumps(extra_data, indent=2)}"
250 | )
251 | else:
252 | # Groupwise: log ranking reasoning and which responses were selected as top-N
253 | # Binary scores: 1.0 for top-N, 0.0 for others
254 | top_indices = [
255 | i for i, score in enumerate(judge_result.scores) if score == 1.0
256 | ]
257 | response_previews = [
258 | {
259 | "index": i,
260 | "score": score,
261 | "preview": resp[:300] + "..." if len(resp) > 300 else resp,
262 | }
263 | for i, (score, resp) in enumerate(
264 | zip(judge_result.scores, responses)
265 | )
266 | ]
267 | extra_data = {
268 | "judge_type": "groupwise",
269 | "top_n": self.top_n,
270 | "top_indices": top_indices,
271 | "scores": judge_result.scores,
272 | "response_previews": response_previews,
273 | "ranking_reasoning": judge_result.reasonings[0]
274 | if judge_result.reasonings
275 | else None,
276 | "criterion": self.criterion,
277 | "model": self.model,
278 | }
279 | logger.info(
280 | f"Groupwise Judge Result: selected {len(top_indices)} of {len(responses)} responses\n{json.dumps(extra_data, indent=2, default=str)}"
281 | )
282 |
283 | # Return only scores (single float if single response, list otherwise)
284 | if is_single_response:
285 | return judge_result.scores[0]
286 | return judge_result.scores
287 |
--------------------------------------------------------------------------------
/docs/development.md:
--------------------------------------------------------------------------------
1 | # Development
2 |
3 | ## Getting Started
4 |
5 | ### Development Installation
6 |
7 | ```bash
8 | git clone https://github.com/Red-Hat-AI-Innovation-Team/its_hub.git
9 | cd its_hub
10 | pip install -e ".[dev]"
11 | ```
12 |
13 | The development installation includes:
14 | - All core dependencies
15 | - Testing frameworks (pytest, coverage)
16 | - Code formatting and linting tool (ruff)
17 | - Development tools and scripts
18 |
19 | ### Running Tests
20 |
21 | ```bash
22 | # Run all tests
23 | pytest tests
24 |
25 | # Run with coverage
26 | pytest tests --cov=its_hub
27 |
28 | # Run specific test modules
29 | pytest tests/test_algorithms.py
30 | pytest tests/test_lms.py
31 | pytest tests/test_iaas.py
32 | ```
33 |
34 | ### Code Quality
35 |
36 | ```bash
37 | # Run linter checks (Ruff configuration in pyproject.toml)
38 | ruff check its_hub/
39 |
40 | # Fix auto-fixable linting issues
41 | ruff check its_hub/ --fix
42 |
43 | # Format code
44 | ruff format its_hub/
45 | ```
46 |
47 | ## Architecture
48 |
49 | ### Core Design Principles
50 |
51 | **its-hub** follows a clean architecture with abstract base classes defining interfaces between components:
52 |
53 | 1. **Separation of Concerns**: Language models, algorithms, and reward models are independent
54 | 2. **Extensibility**: Easy to add new algorithms and models via abstract interfaces
55 | 3. **Async-First**: Built for high-performance concurrent inference
56 | 4. **Mathematical Focus**: Optimized for reasoning tasks with specialized prompts and evaluation
57 |
58 | ### Key Base Classes
59 |
60 | Located in `its_hub/base.py`:
61 |
62 | ```python
63 | # Language model interface
64 | class AbstractLanguageModel:
65 | def generate(self, prompt: str) -> str: ...
66 | def generate_batch(self, prompts: list[str]) -> list[str]: ...
67 |
68 | # Algorithm interface
69 | class AbstractScalingAlgorithm:
70 | def infer(self, lm, prompt, budget, return_response_only=True): ...
71 |
72 | # Result interface
73 | class AbstractScalingResult:
74 | @property
75 | def the_one(self) -> str: ... # Best response
76 |
77 | # Reward model interfaces
78 | class AbstractOutcomeRewardModel:
79 | def score(self, prompt: str, response: str) -> float: ...
80 |
81 | class AbstractProcessRewardModel:
82 | def score_steps(self, prompt: str, steps: list[str]) -> list[float]: ...
83 | ```
84 |
85 | ### Component Overview
86 |
87 | ```
88 | its_hub/
89 | ├── base.py # Abstract interfaces
90 | ├── lms.py # Language model implementations
91 | ├── algorithms/ # Scaling algorithms
92 | │ ├── self_consistency.py
93 | │ ├── bon.py
94 | │ ├── beam_search.py
95 | │ └── particle_gibbs.py
96 | ├── integration/ # External integrations
97 | │ ├── reward_hub.py # Reward model integration
98 | │ └── iaas.py # API server
99 | └── utils.py # Utilities and prompts
100 | ```
101 |
102 | ## Adding New Algorithms
103 |
104 | ### 1. Implement Abstract Interface
105 |
106 | ```python
107 | from its_hub.base import AbstractScalingAlgorithm, AbstractScalingResult
108 |
109 | class MyAlgorithmResult(AbstractScalingResult):
110 | def __init__(self, responses: list[str], scores: list[float]):
111 | self.responses = responses
112 | self.scores = scores
113 |
114 | @property
115 | def the_one(self) -> str:
116 | # Return best response based on your criteria
117 | best_idx = max(range(len(self.scores)), key=lambda i: self.scores[i])
118 | return self.responses[best_idx]
119 |
120 | class MyAlgorithm(AbstractScalingAlgorithm):
121 | def __init__(self, custom_param: float = 1.0):
122 | self.custom_param = custom_param
123 |
124 | def infer(self, lm, prompt: str, budget: int, return_response_only: bool = True):
125 | # Implement your algorithm logic here
126 | responses = []
127 | scores = []
128 |
129 | for i in range(budget):
130 | response = lm.generate(prompt)
131 | score = self._score_response(response)
132 | responses.append(response)
133 | scores.append(score)
134 |
135 | result = MyAlgorithmResult(responses, scores)
136 | return result.the_one if return_response_only else result
137 |
138 | def _score_response(self, response: str) -> float:
139 | # Implement your scoring logic
140 | return len(response) # Example: prefer longer responses
141 | ```
142 |
143 | ### 2. Add to Algorithms Module
144 |
145 | ```python
146 | # its_hub/algorithms/__init__.py
147 | from .my_algorithm import MyAlgorithm
148 |
149 | __all__ = ['SelfConsistency', 'BestOfN', 'BeamSearch', 'ParticleFiltering', 'MyAlgorithm']
150 | ```
151 |
152 | ### 3. Write Tests
153 |
154 | ```python
155 | # tests/test_my_algorithm.py
156 | import pytest
157 | from its_hub.algorithms import MyAlgorithm
158 | from its_hub.lms import OpenAICompatibleLanguageModel
159 |
160 | def test_my_algorithm():
161 | # Mock language model for testing
162 | class MockLM:
163 | def generate(self, prompt):
164 | return f"Response to: {prompt}"
165 |
166 | lm = MockLM()
167 | algorithm = MyAlgorithm(custom_param=2.0)
168 |
169 | result = algorithm.infer(lm, "test prompt", budget=3)
170 | assert isinstance(result, str)
171 | assert "Response to: test prompt" in result
172 | ```
173 |
174 | ## Adding New Language Models
175 |
176 | ### 1. Implement Abstract Interface
177 |
178 | ```python
179 | from its_hub.base import AbstractLanguageModel
180 |
181 | class MyLanguageModel(AbstractLanguageModel):
182 | def __init__(self, model_path: str):
183 | self.model_path = model_path
184 | # Initialize your model here
185 |
186 | def generate(self, prompt: str) -> str:
187 | # Implement single generation
188 |         raise NotImplementedError
189 |
190 | def generate_batch(self, prompts: list[str]) -> list[str]:
191 | # Implement batch generation
192 | return [self.generate(p) for p in prompts]
193 |
194 | def score(self, prompt: str, response: str) -> float:
195 | # Implement response scoring (optional)
196 | return 0.0
197 | ```
198 |
199 | ### 2. Add Async Support
200 |
201 | ```python
202 | import asyncio
203 | from typing import Optional
204 |
205 | class MyAsyncLanguageModel(AbstractLanguageModel):
206 | async def generate_async(self, prompt: str, **kwargs) -> str:
207 | # Implement async generation
208 |         raise NotImplementedError
209 |
210 | async def generate_batch_async(self, prompts: list[str], **kwargs) -> list[str]:
211 | # Implement async batch generation
212 | tasks = [self.generate_async(p, **kwargs) for p in prompts]
213 | return await asyncio.gather(*tasks)
214 | ```
215 |
216 | ## Adding New Reward Models
217 |
218 | ### Process Reward Model
219 |
220 | ```python
221 | from its_hub.base import AbstractProcessRewardModel
222 |
223 | class MyProcessRewardModel(AbstractProcessRewardModel):
224 | def __init__(self, model_path: str):
225 | self.model_path = model_path
226 | # Load your reward model
227 |
228 | def score_steps(self, prompt: str, steps: list[str]) -> list[float]:
229 | """Score each reasoning step"""
230 | scores = []
231 | context = prompt
232 |
233 | for step in steps:
234 | score = self._score_step(context, step)
235 | scores.append(score)
236 |             context += f"\n{step}"
237 |
238 | return scores
239 |
240 | def _score_step(self, context: str, step: str) -> float:
241 | # Implement step scoring logic
242 | return 1.0 # Placeholder
243 | ```
244 |
245 | ### Outcome Reward Model
246 |
247 | ```python
248 | from its_hub.base import AbstractOutcomeRewardModel
249 |
250 | class MyOutcomeRewardModel(AbstractOutcomeRewardModel):
251 | def score(self, prompt: str, response: str) -> float:
252 | """Score the final response"""
253 | # Implement outcome scoring logic
254 | return self._evaluate_correctness(prompt, response)
255 |
256 | def _evaluate_correctness(self, prompt: str, response: str) -> float:
257 | # Custom evaluation logic
258 | return 1.0 if "correct" in response.lower() else 0.0
259 | ```
260 |
261 | ## Testing Guidelines
262 |
263 | ### Unit Tests
264 |
265 | ```python
266 | # Test individual components
267 | def test_algorithm_basic():
268 | algorithm = MyAlgorithm()
269 | # Test basic functionality
270 |
271 | def test_algorithm_edge_cases():
272 | algorithm = MyAlgorithm()
273 | # Test edge cases and error conditions
274 |
275 | def test_algorithm_with_mock():
276 | # Use mocks to isolate component under test
277 | pass
278 | ```
279 |
280 | ### Integration Tests
281 |
282 | ```python
283 | # Test component interactions
284 | def test_algorithm_with_real_lm():
285 | lm = OpenAICompatibleLanguageModel(...)
286 | algorithm = MyAlgorithm()
287 | result = algorithm.infer(lm, "test", budget=2)
288 | # Verify end-to-end behavior
289 | ```
290 |
291 | ### Performance Tests
292 |
293 | ```python
294 | import time
295 |
296 | def test_algorithm_performance():
297 | start_time = time.time()
298 | # Run algorithm
299 | elapsed = time.time() - start_time
300 | assert elapsed < 10.0 # Performance requirement
301 | ```
302 |
303 | ## Git Workflow
304 |
305 | ### Commits
306 |
307 | Always use the sign-off flag for commits:
308 |
309 | ```bash
310 | git commit -s -m "feat: add new algorithm implementation"
311 | ```
312 |
313 | ### Branch Naming
314 |
315 | - `feat/algorithm-name` - New features
316 | - `fix/issue-description` - Bug fixes
317 | - `docs/section-name` - Documentation updates
318 | - `refactor/component-name` - Code refactoring
319 |
320 | ### Pull Request Process
321 |
322 | 1. Create feature branch from `main`
323 | 2. Make changes with signed commits
324 | 3. Add tests for new functionality
325 | 4. Update documentation as needed
326 | 5. Ensure all tests pass
327 | 6. Submit pull request with clear description
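
Put together, a typical contribution looks roughly like the following (the branch name and commit message are placeholders):

```bash
# Example flow only -- substitute your own branch name and commit message.
git checkout main && git pull
git checkout -b feat/my-algorithm

# ...implement the change and add tests...

ruff check its_hub/ --fix
ruff format its_hub/
pytest tests

git commit -s -m "feat: add MyAlgorithm"
git push -u origin feat/my-algorithm
# Then open a pull request against main.
```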
328 |
329 | ## Documentation
330 |
331 | ### Docstring Format
332 |
333 | Use Google-style docstrings:
334 |
335 | ```python
336 | def my_function(param1: str, param2: int = 10) -> bool:
337 | """Brief description of the function.
338 |
339 | Longer description if needed, explaining the purpose
340 | and any important details.
341 |
342 | Args:
343 | param1: Description of first parameter
344 | param2: Description of second parameter with default value
345 |
346 | Returns:
347 | Description of return value
348 |
349 | Raises:
350 | ValueError: Description of when this exception is raised
351 |
352 | Example:
353 | >>> result = my_function("hello", 5)
354 | >>> print(result)
355 | True
356 | """
357 | return len(param1) > param2
358 | ```
359 |
360 | ### Code Comments
361 |
362 | - Explain **why**, not **what**
363 | - Use comments for complex algorithms or non-obvious logic
364 | - Keep comments up-to-date with code changes
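
As a small illustration, the function below is hypothetical (modeled on the response cycling in `tests/conftest.py`) and shows the difference between a comment that restates the code and one that records intent:

```python
def pick_response(responses: list[str], call_count: int) -> str:
    # Bad ("what"): take call_count modulo the number of responses.
    # Good ("why"): cycle through responses so tests whose budget exceeds
    # the fixture's response list still get deterministic output.
    return responses[call_count % len(responses)]
```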
365 |
366 | ## Performance Optimization
367 |
368 | ### Profiling
369 |
370 | ```python
371 | import cProfile
372 | import pstats
373 |
374 | def profile_algorithm(algorithm, lm, prompt):
375 | pr = cProfile.Profile()
376 | pr.enable()
377 |
378 | # Run your algorithm here
379 | algorithm.infer(lm, prompt, budget=10)
380 |
381 | pr.disable()
382 | stats = pstats.Stats(pr)
383 | stats.sort_stats('cumulative').print_stats(10)
384 | ```
385 |
386 | ### Memory Optimization
387 |
388 | ```python
389 | import tracemalloc
390 |
391 | tracemalloc.start()
392 |
393 | # Run your code
394 | result = algorithm.infer(lm, prompt, budget=10)
395 |
396 | current, peak = tracemalloc.get_traced_memory()
397 | print(f"Current memory usage: {current / 1024 / 1024:.2f} MB")
398 | print(f"Peak memory usage: {peak / 1024 / 1024:.2f} MB")
399 | ```
400 |
401 | ### GPU Memory Management
402 |
403 | ```python
404 | import torch
405 |
406 | def optimize_gpu_memory():
407 | # Clear cache periodically
408 | torch.cuda.empty_cache()
409 |
410 | # Monitor memory usage
411 | allocated = torch.cuda.memory_allocated()
412 | cached = torch.cuda.memory_reserved()
413 | print(f"GPU Memory - Allocated: {allocated/1e9:.2f}GB, Cached: {cached/1e9:.2f}GB")
414 | ```
415 |
416 | ## Release Process
417 |
418 | ### Version Bumping
419 |
420 | Update version in `pyproject.toml`:
421 |
422 | ```toml
423 | [project]
424 | version = "0.2.0"
425 | ```
426 |
427 | ### Creating Releases
428 |
429 | 1. Update version number
430 | 2. Update CHANGELOG.md
431 | 3. Create git tag: `git tag -a v0.2.0 -m "Release v0.2.0"`
432 | 4. Push tag: `git push origin v0.2.0`
433 | 5. GitHub Actions will handle PyPI publishing
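
For example, steps 3-5 boil down to the following (the version number is illustrative, and the version bump plus CHANGELOG update are assumed to be committed already):

```bash
# Illustrative commands for a hypothetical v0.2.0 release.
git tag -a v0.2.0 -m "Release v0.2.0"
git push origin v0.2.0
# The release workflow (.github/workflows/release.yaml) then publishes to PyPI.
```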
434 |
435 | ## Contributing
436 |
437 | ### Issues
438 |
439 | - Use issue templates when available
440 | - Provide minimal reproducible examples
441 | - Include environment details (OS, Python version, GPU type)
442 |
443 | ### Feature Requests
444 |
445 | - Explain the use case and motivation
446 | - Provide examples of desired API
447 | - Consider backwards compatibility
448 |
449 | ### Code Review
450 |
451 | - Review for correctness, performance, and maintainability
452 | - Suggest improvements constructively
453 | - Test the changes locally when possible
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Shared test configuration and fixtures."""
2 |
3 | import json
4 | import socket
5 | import threading
6 | import time
7 | from http.server import BaseHTTPRequestHandler, HTTPServer
8 |
9 | import pytest
10 | from fastapi.testclient import TestClient
11 |
12 | from its_hub.base import AbstractLanguageModel, AbstractOutcomeRewardModel
13 | from its_hub.integration.iaas import app
14 |
15 |
16 | def find_free_port() -> int:
17 | """Find a free port to use for test servers."""
18 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
19 | s.bind(("", 0))
20 | return s.getsockname()[1]
21 |
22 |
23 | class DummyVLLMHandler(BaseHTTPRequestHandler):
24 | """A dummy HTTP handler that mimics a vLLM server."""
25 |
26 | def do_POST(self):
27 | """Handle POST requests to the /v1/chat/completions endpoint."""
28 | if self.path == "/v1/chat/completions":
29 | content_length = int(self.headers["Content-Length"])
30 | post_data = self.rfile.read(content_length)
31 | request_data = json.loads(post_data.decode("utf-8"))
32 |
33 | # Simulate some processing time
34 | time.sleep(0.01)
35 |
36 | # Extract the user message
37 | messages = request_data.get("messages", [])
38 | user_content = messages[-1]["content"] if messages else "unknown"
39 |
40 | # Check for error triggers
41 | if "error" in user_content.lower():
42 | self.send_response(500)
43 | self.send_header("Content-Type", "application/json")
44 | self.end_headers()
45 | error_response = {
46 | "error": {
47 | "message": "Simulated vLLM error",
48 | "type": "server_error",
49 | "code": 500,
50 | }
51 | }
52 | self.wfile.write(json.dumps(error_response).encode("utf-8"))
53 | return
54 |
55 | # Create a response that includes the request content for testing
56 | response_content = f"vLLM response to: {user_content}"
57 |
58 | # Check if we should include stop tokens
59 | stop = request_data.get("stop")
60 | include_stop = request_data.get("include_stop_str_in_output", False)
61 |
62 | if stop and include_stop:
63 | response_content += stop
64 |
65 | # Create vLLM-like response
66 | response = {
67 | "id": "vllm-test-id",
68 | "object": "chat.completion",
69 | "created": int(time.time()),
70 | "model": request_data.get("model", "test-model"),
71 | "choices": [
72 | {
73 | "index": 0,
74 | "message": {"role": "assistant", "content": response_content},
75 | "finish_reason": "stop",
76 | }
77 | ],
78 | "usage": {
79 | "prompt_tokens": 10,
80 | "completion_tokens": 15,
81 | "total_tokens": 25,
82 | },
83 | }
84 |
85 | self.send_response(200)
86 | self.send_header("Content-Type", "application/json")
87 | self.end_headers()
88 | self.wfile.write(json.dumps(response).encode("utf-8"))
89 | else:
90 | self.send_response(404)
91 | self.end_headers()
92 | self.wfile.write(b"Not Found")
93 |
94 | def log_message(self, format, *args):
95 | """Suppress log messages to keep test output clean."""
96 | pass
97 |
98 |
99 | class DummyOpenAIHandler(BaseHTTPRequestHandler):
100 | """A dummy HTTP handler that mimics the OpenAI API."""
101 |
102 | # Class-level variables to track concurrent requests
103 | active_requests = 0
104 | max_concurrent_requests = 0
105 | request_lock = threading.Lock()
106 |
107 | @classmethod
108 | def reset_stats(cls):
109 | """Reset the request statistics."""
110 | with cls.request_lock:
111 | cls.active_requests = 0
112 | cls.max_concurrent_requests = 0
113 |
114 | def do_POST(self):
115 | """Handle POST requests to the /chat/completions endpoint."""
116 | if self.path == "/chat/completions":
117 | content_length = int(self.headers["Content-Length"])
118 | post_data = self.rfile.read(content_length)
119 | request_data = json.loads(post_data.decode("utf-8"))
120 |
121 | # Track concurrent requests
122 | with self.__class__.request_lock:
123 | self.__class__.active_requests += 1
124 | self.__class__.max_concurrent_requests = max(
125 | self.__class__.max_concurrent_requests,
126 | self.__class__.active_requests,
127 | )
128 |
129 | # Simulate some processing time
130 | time.sleep(0.1)
131 |
132 | # Check if we should simulate an error
133 | if "trigger_error" in request_data.get("messages", [{}])[-1].get(
134 | "content", ""
135 | ):
136 | self.send_response(500)
137 | self.send_header("Content-Type", "application/json")
138 | self.end_headers()
139 | error_response = {
140 | "error": {
141 | "message": "Simulated API error",
142 | "type": "server_error",
143 | "code": 500,
144 | }
145 | }
146 | self.wfile.write(json.dumps(error_response).encode("utf-8"))
147 |
148 | # Decrement active requests
149 | with self.__class__.request_lock:
150 | self.__class__.active_requests -= 1
151 |
152 | return
153 |
154 | # Extract the messages from the request
155 | messages = request_data.get("messages", [])
156 |
157 | # Prepare a response based on the messages
158 | response_content = f"Response to: {messages[-1]['content']}"
159 |
160 | # Check if there's a stop sequence and we should include it
161 | stop = request_data.get("stop")
162 | include_stop = request_data.get("include_stop_str_in_output", False)
163 |
164 | if stop and include_stop:
165 | response_content += stop
166 |
167 | # Create an OpenAI-like response
168 | response = {
169 | "id": "dummy-id",
170 | "object": "chat.completion",
171 | "created": 1234567890,
172 | "model": request_data.get("model", "dummy-model"),
173 | "choices": [
174 | {
175 | "index": 0,
176 | "message": {"role": "assistant", "content": response_content},
177 | "finish_reason": "stop",
178 | }
179 | ],
180 | }
181 |
182 | # Send the response
183 | self.send_response(200)
184 | self.send_header("Content-Type", "application/json")
185 | self.end_headers()
186 | self.wfile.write(json.dumps(response).encode("utf-8"))
187 |
188 | # Decrement active requests
189 | with self.__class__.request_lock:
190 | self.__class__.active_requests -= 1
191 | else:
192 | self.send_response(404)
193 | self.end_headers()
194 | self.wfile.write(b"Not Found")
195 |
196 | def log_message(self, format, *args):
197 | """Suppress log messages to keep test output clean."""
198 | pass
199 |
200 |
201 | class MockLanguageModel(AbstractLanguageModel):
202 | """Mock language model for testing."""
203 |
204 | def __init__(self, responses: list[str]):
205 | self.responses = responses
206 | self.call_count = 0
207 |
208 | def generate(
209 | self, messages, stop=None, temperature=None, include_stop_str_in_output=None
210 | ):
211 | if (
212 | isinstance(messages, list)
213 | and len(messages) > 0
214 | and isinstance(messages[0], list)
215 | ):
216 | # Batched generation - messages is List[List[ChatMessage]]
217 | num_requests = len(messages)
218 | if self.call_count + num_requests > len(self.responses):
219 | # Cycle through responses if we run out
220 | responses = []
221 | for i in range(num_requests):
222 | responses.append(
223 | self.responses[(self.call_count + i) % len(self.responses)]
224 | )
225 | else:
226 | responses = self.responses[
227 | self.call_count : self.call_count + num_requests
228 | ]
229 | self.call_count += num_requests
230 | return responses
231 | else:
232 | # Single generation - messages is List[ChatMessage]
233 | if self.call_count >= len(self.responses):
234 | # Cycle through responses if we run out
235 | response = self.responses[self.call_count % len(self.responses)]
236 | else:
237 | response = self.responses[self.call_count]
238 | self.call_count += 1
239 | return response
240 |
241 | def evaluate(self, prompt: str, generation: str) -> list[float]:
242 | return [0.1] * len(generation.split())
243 |
244 |
245 | class MockOutcomeRewardModel(AbstractOutcomeRewardModel):
246 | """Mock outcome reward model for testing."""
247 |
248 | def __init__(self, scores: list[float]):
249 | if isinstance(scores, float):
250 | self.scores = [scores]
251 | else:
252 | self.scores = scores
253 | self.call_count = 0
254 |
255 | def score(self, prompt: str, response) -> float:
256 | if isinstance(response, list):
257 | scores = self.scores[self.call_count : self.call_count + len(response)]
258 | self.call_count += len(response)
259 | return scores
260 | else:
261 | score = self.scores[self.call_count % len(self.scores)]
262 | self.call_count += 1
263 | return score
264 |
265 |
266 | class MockProcessRewardModel:
267 | """Mock process reward model for testing."""
268 |
269 | def __init__(self, scores: list[float]):
270 | if isinstance(scores[0], list):
271 | self.scores = [score for sublist in scores for score in sublist]
272 | else:
273 | self.scores = scores
274 | self.call_count = 0
275 |
276 | def score(self, prompt: str, response) -> float:
277 | if isinstance(response, list):
278 | scores = []
279 | for i in range(len(response)):
280 | scores.append(self.scores[(self.call_count + i) % len(self.scores)])
281 | self.call_count += len(response)
282 | return scores
283 | else:
284 | score = self.scores[self.call_count % len(self.scores)]
285 | self.call_count += 1
286 | return score
287 |
288 |
289 | # Pytest fixtures
290 |
291 |
292 | @pytest.fixture(scope="session")
293 | def vllm_server():
294 | """Start a vLLM mock server for the test session."""
295 | port = find_free_port()
296 | server = HTTPServer(("localhost", port), DummyVLLMHandler)
297 | server_thread = threading.Thread(target=server.serve_forever)
298 | server_thread.daemon = True
299 | server_thread.start()
300 |
301 | # Give the server a moment to start
302 | time.sleep(0.1)
303 |
304 | yield f"http://localhost:{port}"
305 |
306 | server.shutdown()
307 | server_thread.join()
308 |
309 |
310 | @pytest.fixture(scope="session")
311 | def openai_server():
312 | """Start an OpenAI mock server for the test session."""
313 | port = find_free_port()
314 | server = HTTPServer(("localhost", port), DummyOpenAIHandler)
315 | server_thread = threading.Thread(target=server.serve_forever)
316 | server_thread.daemon = True
317 | server_thread.start()
318 |
319 | # Give the server a moment to start
320 | time.sleep(0.1)
321 |
322 | yield f"http://localhost:{port}"
323 |
324 | server.shutdown()
325 | server_thread.join()
326 |
327 |
328 | @pytest.fixture
329 | def iaas_client():
330 | """Create a test client for the IaaS API."""
331 | # Reset global state before each test
332 | import its_hub.integration.iaas as iaas_module
333 |
334 | iaas_module.LM_DICT.clear()
335 | iaas_module.SCALING_ALG = None
336 |
337 | return TestClient(app)
338 |
339 |
340 | @pytest.fixture
341 | def mock_language_model():
342 | """Create a mock language model with default responses."""
343 | return MockLanguageModel(["response1", "response2", "response3"])
344 |
345 |
346 | @pytest.fixture
347 | def mock_outcome_reward_model():
348 | """Create a mock outcome reward model with default scores."""
349 | return MockOutcomeRewardModel([0.8, 0.6, 0.9])
350 |
351 |
352 | @pytest.fixture
353 | def mock_process_reward_model():
354 | """Create a mock process reward model with default scores."""
355 | return MockProcessRewardModel([0.7, 0.6, 0.8, 0.5])
356 |
357 |
358 | # Test constants
359 | TEST_CONSTANTS = {
360 | "DEFAULT_BUDGET": 4,
361 | "DEFAULT_TEMPERATURE": 0.7,
362 | "DEFAULT_MODEL_NAME": "test-model",
363 | "DEFAULT_API_KEY": "test-key",
364 | "DEFAULT_TIMEOUT": 0.1,
365 | "ERROR_TRIGGER": "trigger_error",
366 | "VLLM_ERROR_TRIGGER": "error",
367 | }
368 |
--------------------------------------------------------------------------------
/tests/test_reward_hub_integration.py:
--------------------------------------------------------------------------------
1 | """Integration tests for reward_hub integration."""
2 |
3 | from unittest.mock import MagicMock, patch
4 |
5 | import pytest
6 | from reward_hub.base import AggregationMethod
7 |
8 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel
9 |
10 |
11 | class TestLocalVllmProcessRewardModelIntegration:
12 | """Test the integration between its_hub and reward_hub."""
13 |
14 | @pytest.fixture
15 | def mock_vllm_model(self):
16 | """Create a mock VllmProcessRewardModel."""
17 | with patch("reward_hub.vllm.reward.VllmProcessRewardModel") as mock_class:
18 | mock_instance = MagicMock()
19 | mock_class.return_value = mock_instance
20 | yield mock_instance
21 |
22 | def test_single_response_scoring(self, mock_vllm_model):
23 | """Test scoring a single response with proper message format."""
24 | # Setup mock to return a single score
25 | mock_vllm_model.score.return_value = [0.85]
26 |
27 | # Create the reward model
28 | model = LocalVllmProcessRewardModel(
29 | model_name="test-model",
30 | device="cpu",
31 | aggregation_method=AggregationMethod.PRODUCT,
32 | )
33 |
34 | # Score a single response
35 | prompt = "What is 2+2?"
36 | response = "2+2 = 4"
37 | score = model.score(prompt, response)
38 |
39 | # Verify the score is returned correctly
40 | assert score == 0.85
41 |
42 | # Verify the mock was called with correct message format
43 | mock_vllm_model.score.assert_called_once()
44 | call_args = mock_vllm_model.score.call_args
45 |
46 | # Check that messages are in dict format, not ChatMessage objects
47 | messages = call_args[1]["messages"]
48 | assert len(messages) == 1 # Single conversation
49 | assert len(messages[0]) == 2 # User + assistant messages
50 |
51 | # Verify message format - should be dicts, not ChatMessage objects
52 | user_msg = messages[0][0]
53 | assistant_msg = messages[0][1]
54 |
55 | assert isinstance(user_msg, dict)
56 | assert isinstance(assistant_msg, dict)
57 | assert user_msg == {"role": "user", "content": prompt}
58 | assert assistant_msg == {"role": "assistant", "content": response}
59 |
60 | # Verify other parameters
61 | assert call_args[1]["aggregation_method"] == AggregationMethod.PRODUCT
62 | assert call_args[1]["return_full_prm_result"] is False
63 |
64 | def test_multiple_responses_scoring(self, mock_vllm_model):
65 | """Test scoring multiple responses with proper message format."""
66 | # Setup mock to return multiple scores
67 | mock_vllm_model.score.return_value = [0.85, 0.72, 0.91]
68 |
69 | # Create the reward model
70 | model = LocalVllmProcessRewardModel(
71 | model_name="test-model",
72 | device="cuda:0",
73 | aggregation_method=AggregationMethod.MIN,
74 | )
75 |
76 | # Score multiple responses
77 | prompt = "Solve this math problem: 3x + 5 = 14"
78 | responses = [
79 | "3x + 5 = 14\n3x = 9\nx = 3",
80 | "Let me solve step by step:\n3x = 14 - 5 = 9\nx = 3",
81 | "x = (14-5)/3 = 3",
82 | ]
83 | scores = model.score(prompt, responses)
84 |
85 | # Verify scores are returned correctly
86 | assert scores == [0.85, 0.72, 0.91]
87 |
88 | # Verify the mock was called with correct message format
89 | mock_vllm_model.score.assert_called_once()
90 | call_args = mock_vllm_model.score.call_args
91 |
92 | # Check that messages are in dict format for all responses
93 | messages = call_args[1]["messages"]
94 | assert len(messages) == 3 # Three conversations
95 |
96 | for i, conversation in enumerate(messages):
97 | assert len(conversation) == 2 # User + assistant messages
98 |
99 | user_msg = conversation[0]
100 | assistant_msg = conversation[1]
101 |
102 | # Verify message format - should be dicts, not ChatMessage objects
103 | assert isinstance(user_msg, dict)
104 | assert isinstance(assistant_msg, dict)
105 | assert user_msg == {"role": "user", "content": prompt}
106 | assert assistant_msg == {"role": "assistant", "content": responses[i]}
107 |
108 | # Verify other parameters
109 | assert call_args[1]["aggregation_method"] == AggregationMethod.MIN
110 | assert call_args[1]["return_full_prm_result"] is False
111 |
112 | def test_different_aggregation_methods(self, mock_vllm_model):
113 | """Test that different aggregation methods are passed correctly."""
114 | mock_vllm_model.score.return_value = [0.5]
115 |
116 | for agg_method in [
117 | AggregationMethod.PRODUCT,
118 | AggregationMethod.MIN,
119 | AggregationMethod.LAST,
120 | ]:
121 | model = LocalVllmProcessRewardModel(
122 | model_name="test-model", device="cpu", aggregation_method=agg_method
123 | )
124 |
125 | model.score("test prompt", "test response")
126 |
127 | # Check that the aggregation method was passed correctly
128 | call_args = mock_vllm_model.score.call_args
129 | assert call_args[1]["aggregation_method"] == agg_method
130 |
131 | def test_message_format_compatibility(self, mock_vllm_model):
132 | """Test that the message format is compatible with reward_hub expectations.
133 |
134 | This test specifically addresses the bug from issue #73 where ChatMessage
135 | objects were used instead of dict format.
136 | """
137 | mock_vllm_model.score.return_value = [0.7]
138 |
139 | model = LocalVllmProcessRewardModel(
140 | model_name="test-model",
141 | device="cpu",
142 | aggregation_method=AggregationMethod.PRODUCT,
143 | )
144 |
145 | # Score a response
146 | model.score("Test prompt", "Test response")
147 |
148 | # Get the messages that were passed to the reward_hub model
149 | call_args = mock_vllm_model.score.call_args
150 | messages = call_args[1]["messages"]
151 |
152 | # Verify that each message is a plain dict (not a ChatMessage object)
153 | for conversation in messages:
154 | for message in conversation:
155 | # Should be a dict with 'role' and 'content' keys
156 | assert isinstance(message, dict)
157 | assert "role" in message
158 | assert "content" in message
159 | assert len(message) == 2 # Only 'role' and 'content'
160 |
161 | # Should not have any class-specific attributes
162 |                 assert message.__class__ is dict  # a plain dict, not a ChatMessage-like object
163 |
164 | # Role should be string
165 | assert isinstance(message["role"], str)
166 | assert message["role"] in ["user", "assistant"]
167 |
168 | # Content should be string
169 | assert isinstance(message["content"], str)
170 |
171 | def test_error_handling(self, mock_vllm_model):
172 | """Test that errors from reward_hub are properly propagated."""
173 | # Setup mock to raise an exception
174 | mock_vllm_model.score.side_effect = Exception("reward_hub error")
175 |
176 | model = LocalVllmProcessRewardModel(
177 | model_name="test-model",
178 | device="cpu",
179 | aggregation_method=AggregationMethod.PRODUCT,
180 | )
181 |
182 | # Verify that the exception is propagated
183 | with pytest.raises(Exception, match="reward_hub error"):
184 | model.score("test prompt", "test response")
185 |
186 | @pytest.mark.parametrize("device", ["cpu", "cuda:0", "cuda:1"])
187 | def test_device_parameter_passing(self, mock_vllm_model, device):
188 | """Test that device parameter is passed correctly to VllmProcessRewardModel."""
189 | with patch("reward_hub.vllm.reward.VllmProcessRewardModel") as mock_class:
190 | LocalVllmProcessRewardModel(
191 | model_name="test-model",
192 | device=device,
193 | aggregation_method=AggregationMethod.PRODUCT,
194 | )
195 |
196 | # Verify VllmProcessRewardModel was initialized with correct device
197 | mock_class.assert_called_once_with(model_name="test-model", device=device)
198 |
199 | def test_model_name_parameter_passing(self, mock_vllm_model):
200 | """Test that model_name parameter is passed correctly to VllmProcessRewardModel."""
201 | test_model_names = [
202 | "microsoft/DialoGPT-medium",
203 | "meta-llama/Llama-2-7b-chat-hf",
204 | "custom-model-name",
205 | ]
206 |
207 | for model_name in test_model_names:
208 | with patch("reward_hub.vllm.reward.VllmProcessRewardModel") as mock_class:
209 | LocalVllmProcessRewardModel(
210 | model_name=model_name,
211 | device="cpu",
212 | aggregation_method=AggregationMethod.PRODUCT,
213 | )
214 |
215 | # Verify VllmProcessRewardModel was initialized with correct model name
216 | mock_class.assert_called_once_with(model_name=model_name, device="cpu")
217 |
218 | def test_regression_chatmessage_format_bug(self, mock_vllm_model):
219 | """Regression test for issue #73: ChatMessage objects vs dict format.
220 |
221 | This test simulates what would happen if ChatMessage objects were used
222 | instead of dict format, which was the bug fixed in PR #73.
223 | """
224 |
225 | # Setup mock to be strict about message format
226 | def strict_score_check(messages, **kwargs):
227 | # This simulates reward_hub expecting dict format
228 | for conversation in messages:
229 | for message in conversation:
230 | # If this were a ChatMessage object, it would have additional attributes
231 | # and methods that are not expected by reward_hub
232 | if not isinstance(message, dict):
233 | raise TypeError(f"Expected dict, got {type(message)}")
234 |
235 | # Check that it only has the expected keys
236 | expected_keys = {"role", "content"}
237 | if set(message.keys()) != expected_keys:
238 | raise ValueError(
239 | f"Message has unexpected keys: {set(message.keys())}"
240 | )
241 |
242 | # Check that values are strings
243 | if not isinstance(message["role"], str):
244 | raise TypeError(
245 | f"Role should be string, got {type(message['role'])}"
246 | )
247 | if not isinstance(message["content"], str):
248 | raise TypeError(
249 | f"Content should be string, got {type(message['content'])}"
250 | )
251 |
252 | return [0.5]
253 |
254 | mock_vllm_model.score.side_effect = strict_score_check
255 |
256 | model = LocalVllmProcessRewardModel(
257 | model_name="test-model",
258 | device="cpu",
259 | aggregation_method=AggregationMethod.PRODUCT,
260 | )
261 |
262 | # This should work fine with the current implementation
263 | score = model.score("Test prompt", "Test response")
264 | assert score == 0.5
265 |
266 | # Verify that the format check passed (no exception was raised)
267 | mock_vllm_model.score.assert_called_once()
268 |
269 | def test_demonstrates_chatmessage_compatibility_issue(self):
270 | """Test that demonstrates what the issue #73 bug would look like.
271 |
272 | This is a demonstration test showing how ChatMessage objects would fail
273 | with reward_hub's expected dict format.
274 | """
275 | from dataclasses import dataclass
276 |
277 | # Simulate a ChatMessage-like object (what was causing the bug)
278 | @dataclass
279 | class ChatMessage:
280 | role: str
281 | content: str
282 |
283 | def to_dict(self):
284 | return {"role": self.role, "content": self.content}
285 |
286 | # Create messages in the old (broken) format
287 | user_msg = ChatMessage(role="user", content="What is 2+2?")
288 | assistant_msg = ChatMessage(role="assistant", content="2+2 = 4")
289 |
290 | # This would be what the old code might have done
291 | broken_messages = [[user_msg, assistant_msg]]
292 |
293 | # Simulate reward_hub's expectation (dict format)
294 | def check_dict_format(messages):
295 | for conversation in messages:
296 | for message in conversation:
297 | if not isinstance(message, dict):
298 | raise TypeError(f"Expected dict, got {type(message)}")
299 | if "role" not in message or "content" not in message:
300 | raise ValueError("Message missing required keys")
301 |
302 | # This would fail with the old implementation
303 | with pytest.raises(TypeError, match="Expected dict, got"):
304 | check_dict_format(broken_messages)
305 |
306 | # But this works with the current implementation (dict format)
307 | correct_messages = [
308 | [
309 | {"role": "user", "content": "What is 2+2?"},
310 | {"role": "assistant", "content": "2+2 = 4"},
311 | ]
312 | ]
313 |
314 | # This should pass
315 | check_dict_format(correct_messages) # No exception should be raised
316 |
--------------------------------------------------------------------------------