├── docs ├── .nojekyll ├── _navbar.md ├── _coverpage.md ├── _sidebar.md ├── index.html ├── README.md ├── installation.md ├── PLANNING_WRAPPER.md ├── quick-start.md ├── benchmarking.md ├── algorithms.md └── development.md ├── its_hub ├── integration │ ├── __init__.py │ └── reward_hub.py ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── bon.py │ └── beam_search.py ├── utils.py ├── base.py ├── types.py └── error_handling.py ├── tests ├── mocks │ ├── __init__.py │ ├── reward_models.py │ ├── test_data.py │ └── language_models.py ├── test_particle_gibbs_resampling.py ├── conftest.py └── test_reward_hub_integration.py ├── .claude └── settings.json ├── .jupytext.yml ├── ruff.toml ├── .github └── workflows │ ├── tests.yaml │ ├── sync-notebooks.yaml │ └── release.yaml ├── .devcontainer ├── devcontainer.json ├── Dockerfile └── init-firewall.sh ├── scripts └── test_math_example.py ├── pyproject.toml ├── README.md ├── .gitignore ├── notebooks ├── self-consistency.py └── self-consistency.ipynb ├── CLAUDE.md └── LICENSE /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /its_hub/integration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/mocks/__init__.py: -------------------------------------------------------------------------------- 1 | """Mock objects for testing.""" 2 | -------------------------------------------------------------------------------- /.claude/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "permissions": { 3 | "allow": [ 4 | "Bash(rg:*)" 5 | ], 6 | "deny": [] 7 | } 8 | } -------------------------------------------------------------------------------- /.jupytext.yml: -------------------------------------------------------------------------------- 1 | # Jupytext configuration 2 | formats: "py:percent,ipynb" 3 | notebook_metadata_filter: "all" 4 | cell_metadata_filter: "-all" -------------------------------------------------------------------------------- /its_hub/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Python library for inference-time scaling LLMs 3 | """ 4 | 5 | from importlib.metadata import version 6 | 7 | __version__ = version("its_hub") 8 | -------------------------------------------------------------------------------- /docs/_navbar.md: -------------------------------------------------------------------------------- 1 | - Links 2 | - [GitHub](https://github.com/Red-Hat-AI-Innovation-Team/its_hub) 3 | - [PyPI](https://pypi.org/project/its_hub/) 4 | - [Tests](https://github.com/Red-Hat-AI-Innovation-Team/its_hub/actions/workflows/tests.yml) 5 | - [Coverage](https://codecov.io/gh/Red-Hat-AI-Innovation-Team/its_hub) -------------------------------------------------------------------------------- /docs/_coverpage.md: -------------------------------------------------------------------------------- 1 | # its-hub 2 | 3 | > A Python library for inference-time scaling LLMs 4 | 5 | - 🔬 Multiple scaling algorithms (Particle Filtering, Best-of-N, Beam Search, Self-Consistency) 6 | - 🚀 OpenAI-compatible API with Inference-as-a-Service (IaaS) 7 | - ⚡ Async generation with concurrency limits and error handling 8 | - 📊 Comprehensive benchmarking tools 9 | 10 | 
[GitHub](https://github.com/Red-Hat-AI-Innovation-Team/its_hub) 11 | [Get Started](quick-start.md) 12 | -------------------------------------------------------------------------------- /docs/_sidebar.md: -------------------------------------------------------------------------------- 1 | - [Overview](README.md) 2 | - [Installation](installation.md) 3 | - [Quick Start Guide](quick-start.md) 4 | - [IaaS Service Guide](iaas-service.md) 5 | - [Algorithms](algorithms.md) 6 | - [Particle Filtering](algorithms.md#particle-filtering) 7 | - [Best-of-N](algorithms.md#best-of-n) 8 | - [Beam Search](algorithms.md#beam-search) 9 | - [Self-Consistency](algorithms.md#self-consistency) 10 | - [Benchmarking](benchmarking.md) 11 | - [Development](development.md) -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 88 2 | target-version = "py311" 3 | extend-exclude = ["_version.py"] 4 | 5 | [lint] 6 | select = [ 7 | "E", # pycodestyle errors 8 | "W", # pycodestyle warnings 9 | "F", # pyflakes 10 | "I", # isort 11 | "N", # pep8-naming 12 | "UP", # pyupgrade 13 | "B", # flake8-bugbear 14 | "C4", # flake8-comprehensions 15 | "SIM", # flake8-simplify 16 | "TID", # flake8-tidy-imports 17 | "RUF", # Ruff-specific rules 18 | ] 19 | ignore = [ 20 | "E501", # line too long (handled by formatter) 21 | "B008", # do not perform function calls in argument defaults 22 | "B905", # `zip()` without an explicit `strict=` parameter 23 | ] 24 | 25 | [lint.per-file-ignores] 26 | "__init__.py" = ["F401"] # unused imports in __init__ files 27 | "tests/*" = ["S101"] # use of assert in tests 28 | 29 | [lint.isort] 30 | known-first-party = ["its_hub"] 31 | 32 | [format] 33 | quote-style = "double" 34 | indent-style = "space" 35 | skip-magic-trailing-comma = false 36 | line-ending = "auto" -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | code-quality: 11 | runs-on: ubuntu-latest 12 | continue-on-error: true 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Install uv 17 | uses: astral-sh/setup-uv@v4 18 | with: 19 | version: "latest" 20 | 21 | - name: Set up Python 22 | run: uv python install 3.11 23 | 24 | - name: Install dependencies 25 | run: uv sync --extra dev 26 | 27 | - name: Run linting checks 28 | run: uv run ruff check its_hub/ 29 | 30 | - name: Run formatting checks 31 | run: uv run ruff format --check its_hub/ 32 | 33 | test: 34 | runs-on: ubuntu-latest 35 | strategy: 36 | matrix: 37 | python-version: ["3.10", "3.11", "3.12"] 38 | 39 | steps: 40 | - uses: actions/checkout@v4 41 | 42 | - name: Install uv 43 | uses: astral-sh/setup-uv@v4 44 | with: 45 | version: "latest" 46 | 47 | - name: Set up Python 48 | run: uv python install ${{ matrix.python-version }} 49 | 50 | - name: Install dependencies 51 | run: uv sync --extra dev 52 | 53 | - name: Run tests 54 | run: uv run pytest tests/ --cov=its_hub --cov-report=xml 55 | 56 | - name: Upload coverage to Codecov 57 | uses: codecov/codecov-action@v4 58 | with: 59 | token: ${{ secrets.CODECOV_TOKEN }} 60 | file: ./coverage.xml 61 | fail_ci_if_error: false 62 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "Claude Code Sandbox", 3 | "build": { 4 | "dockerfile": "Dockerfile", 5 | "args": { 6 | "TZ": "${localEnv:TZ:America/New_York}" 7 | } 8 | }, 9 | "runArgs": [ 10 | "--cap-add=NET_ADMIN", 11 | "--cap-add=NET_RAW" 12 | ], 13 | "customizations": { 14 | "vscode": { 15 | "extensions": [ 16 | "dbaeumer.vscode-eslint", 17 | "esbenp.prettier-vscode", 18 | "eamodio.gitlens", 19 | "ms-python.python" 20 | ], 21 | "settings": { 22 | "editor.formatOnSave": true, 23 | "editor.defaultFormatter": "esbenp.prettier-vscode", 24 | "editor.codeActionsOnSave": { 25 | "source.fixAll.eslint": "explicit" 26 | }, 27 | "terminal.integrated.defaultProfile.linux": "zsh", 28 | "terminal.integrated.profiles.linux": { 29 | "bash": { 30 | "path": "bash", 31 | "icon": "terminal-bash" 32 | }, 33 | "zsh": { 34 | "path": "zsh" 35 | } 36 | } 37 | } 38 | } 39 | }, 40 | "remoteUser": "node", 41 | "mounts": [ 42 | "source=claude-code-bashhistory,target=/commandhistory,type=volume", 43 | "source=claude-code-config,target=/home/node/.claude,type=volume" 44 | ], 45 | "remoteEnv": { 46 | "NODE_OPTIONS": "--max-old-space-size=4096", 47 | "CLAUDE_CONFIG_DIR": "/home/node/.claude", 48 | "POWERLEVEL9K_DISABLE_GITSTATUS": "true" 49 | }, 50 | "workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated", 51 | "workspaceFolder": "/workspace", 52 | "postCreateCommand": "sudo /usr/local/bin/init-firewall.sh" 53 | } 54 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | its_hub - Inference-Time Scaling for LLMs 6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /.github/workflows/sync-notebooks.yaml: -------------------------------------------------------------------------------- 1 | name: Sync Notebooks 2 | 3 | on: 4 | pull_request: 5 | branches: [ main ] 6 | paths: [ 'notebooks/**/*.py', 'notebooks/**/*.ipynb' ] 7 | 8 | jobs: 9 | sync-notebooks: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | contents: write 13 | 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 19 | token: ${{ secrets.GITHUB_TOKEN }} 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: '3.11' 25 | 26 | - name: Install dependencies 27 | run: | 28 | pip install jupytext jupyter 29 | 30 | - name: Sync Python files to notebooks 31 | run: | 32 | # Convert .py files to .ipynb 33 | find . -name "*.py" -path "*/notebooks/*" -exec jupytext --to ipynb {} \; 34 | 35 | - name: Sync notebooks to Python files 36 | run: | 37 | # Convert .ipynb files to .py (in case someone committed a notebook directly) 38 | find . -name "*.ipynb" -path "*/notebooks/*" -exec jupytext --to py:percent {} \; 39 | 40 | - name: Check for changes 41 | id: verify-changed-files 42 | run: | 43 | if [ -n "$(git status --porcelain)" ]; then 44 | echo "changed=true" >> $GITHUB_OUTPUT 45 | else 46 | echo "changed=false" >> $GITHUB_OUTPUT 47 | fi 48 | 49 | - name: Commit and push changes to PR branch 50 | if: steps.verify-changed-files.outputs.changed == 'true' 51 | run: | 52 | git config --local user.email "action@github.com" 53 | git config --local user.name "GitHub Action" 54 | git add . 55 | git commit -m "Auto-sync notebooks and Python files" || exit 0 56 | git push origin HEAD:${{ github.head_ref }} -------------------------------------------------------------------------------- /its_hub/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | from .beam_search import BeamSearch, BeamSearchResult 2 | from .bon import BestOfN, BestOfNResult 3 | from .particle_gibbs import ( 4 | EntropicParticleFiltering, 5 | ParticleFiltering, 6 | ParticleFilteringResult, 7 | ParticleGibbs, 8 | ParticleGibbsResult, 9 | ) 10 | from .self_consistency import SelfConsistency, SelfConsistencyResult 11 | 12 | __all__ = [ 13 | "BeamSearch", 14 | "BeamSearchResult", 15 | "BestOfN", 16 | "BestOfNResult", 17 | "EntropicParticleFiltering", 18 | "MetropolisHastings", 19 | "MetropolisHastingsResult", 20 | "ParticleFiltering", 21 | "ParticleFilteringResult", 22 | "ParticleGibbs", 23 | "ParticleGibbsResult", 24 | "SelfConsistency", 25 | "SelfConsistencyResult", 26 | ] 27 | 28 | ### 29 | 30 | from typing import Union 31 | 32 | from its_hub.base import ( 33 | AbstractLanguageModel, 34 | AbstractOutcomeRewardModel, 35 | AbstractScalingAlgorithm, 36 | AbstractScalingResult, 37 | ) 38 | from its_hub.lms import StepGeneration 39 | from its_hub.types import ChatMessage, ChatMessages 40 | 41 | 42 | class MetropolisHastingsResult(AbstractScalingResult): 43 | pass 44 | 45 | 46 | class MetropolisHastings(AbstractScalingAlgorithm): 47 | def __init__( 48 | self, step_generation: StepGeneration, orm: AbstractOutcomeRewardModel 49 | ): 50 | self.step_generation = step_generation 51 | self.orm = orm 52 | 53 | def infer( 54 | self, 55 | lm: AbstractLanguageModel, 56 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 57 | budget: int, 58 | show_progress: bool = 
False, 59 | return_response_only: bool = True, 60 | ) -> str | MetropolisHastingsResult: 61 | # TODO: Implement Metropolis-Hastings algorithm 62 | # Will need to convert prompt_or_messages to ChatMessages format when implemented 63 | raise NotImplementedError("Metropolis-Hastings algorithm not yet implemented") 64 | -------------------------------------------------------------------------------- /scripts/test_math_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Example script demonstrating the use of its_hub for math problem solving. 4 | This script tests the Qwen math model with various mathematical problems 5 | using particle filtering for improved solution quality. 6 | """ 7 | 8 | import os 9 | 10 | from its_hub.algorithms import ParticleFiltering 11 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel 12 | from its_hub.lms import OpenAICompatibleLanguageModel, StepGeneration 13 | from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT 14 | 15 | 16 | def main(): 17 | # Get GPU ID from environment variable or default to 0 18 | gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0") 19 | 20 | # Initialize the language model 21 | # Note: The endpoint port (8100) must match the port used when starting the vLLM server 22 | lm = OpenAICompatibleLanguageModel( 23 | endpoint="http://localhost:8100/v1", # Make sure this matches your vLLM server port 24 | api_key="NO_API_KEY", 25 | model_name="Qwen/Qwen2.5-Math-1.5B-Instruct", 26 | system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT, 27 | ) 28 | 29 | # Test prompts 30 | test_prompts = [ 31 | "What is 2+2? Show your steps.", 32 | "Solve the quadratic equation x^2 + 5x + 6 = 0. Show your steps.", 33 | "Find the derivative of f(x) = x^2 + 3x + 2. Show your steps.", 34 | "Let a be a positive real number such that all the roots of x^3 + ax^2 + ax + 1 = 0 are real. Find the smallest possible value of a.", 35 | ] 36 | 37 | # Initialize step generation and reward model 38 | sg = StepGeneration(step_token="\n\n", max_steps=32, stop_token=r"\boxed") 39 | prm = LocalVllmProcessRewardModel( 40 | model_name="Qwen/Qwen2.5-Math-PRM-7B", 41 | device=f"cuda:{gpu_id}", # Use the same GPU as the vLLM server 42 | aggregation_method="prod", 43 | ) 44 | scaling_alg = ParticleFiltering(sg, prm) 45 | 46 | # Run tests 47 | print("Testing Qwen Math Model with different approaches...") 48 | print(f"Using GPU {gpu_id} with memory optimization settings\n") 49 | 50 | for prompt in test_prompts: 51 | print(f"\nTesting: {prompt}") 52 | print("Response:", scaling_alg.infer(lm, prompt, budget=8)) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # its-hub 2 | 3 | [![Tests](https://github.com/Red-Hat-AI-Innovation-Team/its_hub/actions/workflows/tests.yml/badge.svg)](https://github.com/Red-Hat-AI-Innovation-Team/its_hub/actions/workflows/tests.yml) 4 | [![codecov](https://codecov.io/gh/Red-Hat-AI-Innovation-Team/its_hub/graph/badge.svg?token=6WD8NB9YPN)](https://codecov.io/gh/Red-Hat-AI-Innovation-Team/its_hub) 5 | 6 | **its-hub** provides inference-time scaling for LLMs through multiple approaches: 7 | 8 | 1. **Direct Library Usage** - For Python integration 9 | 2. **Inference-as-a-Service (IaaS) API** - OpenAI-compatible HTTP API (⚠️ Alpha) 10 | 11 | ## What is Inference-Time Scaling? 
12 | 
13 | Inference-time scaling improves LLM performance by using computational resources during inference to generate better responses. Unlike training-time scaling which requires more parameters or training data, inference-time scaling algorithms can improve any pre-trained model's performance by:
14 | 
15 | - **Generating multiple candidate responses** and selecting the best one
16 | - **Using step-by-step reasoning** with reward models to guide generation
17 | - **Applying probabilistic methods** like particle filtering for better exploration
18 | 
19 | ## Key Features
20 | 
21 | - 🔬 **Multiple Algorithms**: Particle Filtering, Best-of-N, Beam Search, Self-Consistency
22 | - 🚀 **OpenAI-Compatible API**: Easy integration with existing applications
23 | - 🧮 **Math-Optimized**: Built for mathematical reasoning with specialized prompts and evaluation
24 | - 📊 **Benchmarking Tools**: Compare algorithms on standard datasets like MATH500 and AIME-2024
25 | - ⚡ **Async Support**: Concurrent generation with limits and error handling
26 | 
27 | ## Supported Algorithms
28 | 
29 | | Algorithm | Budget Interpretation | Snippet |
30 | |-----------|----------------------|---------|
31 | | **Self-Consistency** | Number of parallel generations | `SelfConsistency()` |
32 | | **Best-of-N** | Number of candidates to generate | `BestOfN(rm)` |
33 | | **Beam Search** | Total generations ÷ beam width | `BeamSearch(sg, prm, beam_width=4)` |
34 | | **Particle Filtering** | Number of particles to maintain | `ParticleFiltering(sg, prm)` |
35 | | **Planning Enhancement** | Enhances any algorithm with planning | `PlanningWrapper(base_algorithm)` |
36 | 
37 | ### Planning Enhancement
38 | 
39 | The **PlanningWrapper** can enhance any ITS algorithm with a planning phase that generates multiple solution approaches before execution. See [PLANNING_WRAPPER.md](PLANNING_WRAPPER.md) for detailed documentation.
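
### Common Usage Pattern

All of the algorithms above share one interface: construct the algorithm, then call `infer(lm, prompt, budget)`. The sketch below is a minimal illustration using Self-Consistency, which needs no reward model; the endpoint, API key, and model name are placeholders for whatever OpenAI-compatible deployment you run.

```python
from its_hub.algorithms import SelfConsistency
from its_hub.lms import OpenAICompatibleLanguageModel

# Placeholder endpoint/model — point this at any OpenAI-compatible server.
lm = OpenAICompatibleLanguageModel(
    endpoint="http://localhost:8100/v1",
    api_key="NO_API_KEY",
    model_name="Qwen/Qwen2.5-Math-1.5B-Instruct",
)

# budget = number of parallel generations whose final answers are aggregated
scaling_alg = SelfConsistency()
result = scaling_alg.infer(lm, "What is 2 + 2?", budget=8)
print(result)
```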
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | [build-system] 4 | requires = ["setuptools>=42", "wheel", "setuptools_scm>=8.0.0"] 5 | build-backend = "setuptools.build_meta" 6 | 7 | [project] 8 | name = "its_hub" 9 | description = "A Python library for inference-time scaling LLMs" 10 | authors = [ 11 | {name = "Kai Xu and the Red Hat AI Innovation Team", email = "xuk@redhat.com"} 12 | ] 13 | readme = "README.md" 14 | requires-python = ">=3.10" 15 | license = "Apache-2.0" 16 | classifiers = [ 17 | "Programming Language :: Python :: 3", 18 | "Programming Language :: Python :: 3.11", 19 | "Programming Language :: Python :: 3.12", 20 | "Operating System :: OS Independent", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | ] 23 | dynamic = ["version"] 24 | dependencies = [ 25 | "openai>=1.68.2", 26 | "tqdm>=4.65.0", 27 | "typing-extensions>=4.12.2", 28 | "reward-hub>=0.1.8", 29 | "transformers>=4.53.2", 30 | "backoff>=2.2.0", 31 | "click>=8.1.0", 32 | "fastapi>=0.115.5", 33 | "pydantic>=2.7.2", 34 | "numpy", 35 | "uvicorn", 36 | "requests", 37 | "aiohttp>=3.9.0", 38 | "litellm>=1.70.0,<1.75.0" 39 | ] 40 | 41 | [project.scripts] 42 | its-iaas = "its_hub.integration.iaas:main" 43 | 44 | [tool.setuptools_scm] 45 | version_file = "its_hub/_version.py" 46 | # do not include +gREV local version, required for Test PyPI upload 47 | local_scheme = "no-local-version" 48 | 49 | [project.urls] 50 | Homepage = "https://ai-innovation.team/its_hub" 51 | 52 | [project.optional-dependencies] 53 | 54 | vllm = [ 55 | "reward-hub[vllm]>=0.1.8", 56 | ] 57 | 58 | dev = [ 59 | "pytest>=7.0.0", 60 | "pytest-asyncio>=0.21.0", 61 | "pytest-cov>=4.1.0", 62 | "ruff>=0.10.0", 63 | "jupytext>=1.15.0", 64 | "jupyter>=1.0.0", 65 | "reward-hub[vllm]>=0.1.8", 66 | ] 67 | 68 | prm = [ 69 | "reward-hub[prm]>=0.1.8" 70 | ] 71 | 72 | research = [ 73 | "math-verify>=0.1.0", # For mathematical reasoning evaluation in benchmark scripts 74 | "datasets>=2.0.0", # For loading benchmark datasets (MATH500, AIME) 75 | "matplotlib>=3.5.0", # For visualization scripts 76 | ] 77 | cloud = [ 78 | "boto3>=1.28.0", # For AWS Bedrock support 79 | "google-cloud-aiplatform>=1.38.0", # For Vertex AI support 80 | ] 81 | 82 | [tool.setuptools] 83 | package-dir = {"" = "."} 84 | 85 | [tool.setuptools.packages.find] 86 | where = ["."] 87 | include = [ 88 | "its_hub", 89 | "its_hub.algorithms", 90 | "its_hub.integration", 91 | ] 92 | 93 | -------------------------------------------------------------------------------- /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM node:20 2 | 3 | ARG TZ 4 | ENV TZ="$TZ" 5 | 6 | # Install basic development tools and iptables/ipset 7 | RUN apt update && apt install -y less \ 8 | git \ 9 | procps \ 10 | sudo \ 11 | fzf \ 12 | zsh \ 13 | man-db \ 14 | unzip \ 15 | gnupg2 \ 16 | gh \ 17 | iptables \ 18 | ipset \ 19 | iproute2 \ 20 | dnsutils \ 21 | aggregate \ 22 | jq \ 23 | curl 24 | 25 | # Ensure default node user has access to /usr/local/share 26 | RUN mkdir -p /usr/local/share/npm-global && \ 27 | chown -R node:node /usr/local/share 28 | 29 | ARG USERNAME=node 30 | 31 | # Persist bash history. 
32 | RUN SNIPPET="export PROMPT_COMMAND='history -a' && export HISTFILE=/commandhistory/.bash_history" \ 33 | && mkdir /commandhistory \ 34 | && touch /commandhistory/.bash_history \ 35 | && chown -R $USERNAME /commandhistory 36 | 37 | # Set `DEVCONTAINER` environment variable to help with orientation 38 | ENV DEVCONTAINER=true 39 | 40 | # Create workspace and config directories and set permissions 41 | RUN mkdir -p /workspace /home/node/.claude && \ 42 | chown -R node:node /workspace /home/node/.claude 43 | 44 | WORKDIR /workspace 45 | 46 | RUN ARCH=$(dpkg --print-architecture) && \ 47 | wget "https://github.com/dandavison/delta/releases/download/0.18.2/git-delta_0.18.2_${ARCH}.deb" && \ 48 | sudo dpkg -i "git-delta_0.18.2_${ARCH}.deb" && \ 49 | rm "git-delta_0.18.2_${ARCH}.deb" 50 | 51 | # Set up non-root user 52 | USER node 53 | 54 | # Install global packages 55 | ENV NPM_CONFIG_PREFIX=/usr/local/share/npm-global 56 | ENV PATH=$PATH:/usr/local/share/npm-global/bin 57 | 58 | # Set the default shell to zsh rather than sh 59 | ENV SHELL=/bin/zsh 60 | 61 | # Default powerline10k theme 62 | RUN sh -c "$(wget -O- https://github.com/deluan/zsh-in-docker/releases/download/v1.2.0/zsh-in-docker.sh)" -- \ 63 | -p git \ 64 | -p fzf \ 65 | -a "source /usr/share/doc/fzf/examples/key-bindings.zsh" \ 66 | -a "source /usr/share/doc/fzf/examples/completion.zsh" \ 67 | -a "export PROMPT_COMMAND='history -a' && export HISTFILE=/commandhistory/.bash_history" \ 68 | -x 69 | 70 | # Install uv for Python package management 71 | RUN curl -LsSf https://astral.sh/uv/install.sh | sh 72 | ENV PATH="/home/node/.cargo/bin:$PATH" 73 | 74 | # Install Claude 75 | RUN npm install -g @anthropic-ai/claude-code 76 | 77 | # Copy and set up firewall script 78 | COPY init-firewall.sh /usr/local/bin/ 79 | USER root 80 | RUN chmod +x /usr/local/bin/init-firewall.sh && \ 81 | echo "node ALL=(root) NOPASSWD: /usr/local/bin/init-firewall.sh" > /etc/sudoers.d/node-firewall && \ 82 | chmod 0440 /etc/sudoers.d/node-firewall 83 | USER node 84 | -------------------------------------------------------------------------------- /its_hub/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | # the system prompt for step-by-step reasoning taken from https://github.com/huggingface/search-and-learn 4 | SAL_STEP_BY_STEP_SYSTEM_PROMPT = "Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n## Step 1: [Concise description]\n[Brief explanation and calculations]\n\n## Step 2: [Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem." 5 | 6 | QWEN_SYSTEM_PROMPT = ( 7 | "Please reason step by step, and put your final answer within \\boxed{}." 8 | ) 9 | 10 | 11 | def extract_content_from_lm_response(message: dict) -> str: 12 | """ 13 | Extract content from a single LM response message object. 14 | 15 | Args: 16 | message: A message dict returned by fetch_single_response. 17 | 18 | Returns: 19 | The content string. If the message contains tool calls, returns the content 20 | if available, otherwise returns an empty string. 
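
    Example (illustrative values):
        >>> extract_content_from_lm_response({"role": "assistant", "content": "4"})
        '4'
        >>> extract_content_from_lm_response({
        ...     "role": "assistant",
        ...     "content": None,
        ...     "tool_calls": [{"function": {"name": "add", "arguments": {"a": 2, "b": 2}}}],
        ... })
        '[Tool call: add Tool args: {"a": 2, "b": 2}]'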
21 | """ 22 | # TODO: This conversion to text is not ideal as it involves manually formatting 23 | # tool calls and neglects images in multi-modal content. Consider refactoring 24 | # to work with structured message objects instead of flattening to strings. 25 | 26 | # Extract text content (handle both string and list[dict] formats) 27 | raw_content = message.get("content") 28 | 29 | if raw_content is None: 30 | content = "" 31 | elif isinstance(raw_content, str): 32 | content = raw_content 33 | elif isinstance(raw_content, list): 34 | # Multi-modal content: extract text parts (images are ignored) 35 | text_parts = [ 36 | item.get("text", "") 37 | for item in raw_content 38 | if isinstance(item, dict) and item.get("type") == "text" 39 | ] 40 | content = " ".join(text_parts) 41 | else: 42 | raise ValueError( 43 | f"Invalid content type: {type(raw_content)}, expected str, list[dict], or None" 44 | ) 45 | 46 | # If there are tool calls, add tool-calls to the content 47 | if message.get("tool_calls"): 48 | tool_calls = message.get("tool_calls", []) 49 | tool_descriptions = [] 50 | for tc in tool_calls: 51 | if isinstance(tc, dict) and "function" in tc: 52 | func = tc["function"] 53 | func_name = func.get("name", "unknown") 54 | tool_descriptions.append( 55 | f"[Tool call: {func_name} Tool args: {json.dumps(func.get('arguments', {}))}]" 56 | ) 57 | else: 58 | raise ValueError( 59 | f"Invalid tool call: {tc}, expected a dict with a 'function' key" 60 | ) 61 | content += " ".join(tool_descriptions) 62 | 63 | return content 64 | -------------------------------------------------------------------------------- /its_hub/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from .types import ChatMessage, ChatMessages 4 | 5 | 6 | class AbstractLanguageModel(ABC): 7 | """abstract base class for (autoregressive) language models""" 8 | 9 | @abstractmethod 10 | async def agenerate( 11 | self, 12 | messages: list[ChatMessage] | list[list[ChatMessage]], 13 | stop: str | None = None, 14 | ) -> str | list[str]: 15 | """generate a response from the model asynchronously""" 16 | pass 17 | 18 | @abstractmethod 19 | def generate( 20 | self, 21 | messages: list[ChatMessage] | list[list[ChatMessage]], 22 | stop: str | None = None, 23 | ) -> str | list[str]: 24 | """generate a response from the model synchronously""" 25 | pass 26 | 27 | def evaluate(self, prompt: str, generation: str) -> list[float]: 28 | """evaluate the likelihoods of the generation synchronously""" 29 | raise NotImplementedError("evaluate method not implemented") 30 | 31 | 32 | class AbstractScalingResult(ABC): 33 | """abstract base class for scaling result""" 34 | 35 | @property 36 | @abstractmethod 37 | def the_one(self) -> str: 38 | """the selected response""" 39 | pass 40 | 41 | 42 | class AbstractScalingAlgorithm(ABC): 43 | """abstract base class for inference-time scaling algorithms""" 44 | 45 | @abstractmethod 46 | async def ainfer( 47 | self, 48 | lm: AbstractLanguageModel, 49 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 50 | budget: int, 51 | return_response_only: bool = True, 52 | tools: list[dict] | None = None, 53 | tool_choice: str | dict | None = None, 54 | ) -> str | AbstractScalingResult: 55 | """ 56 | Run inference asynchronously with the given language model and prompt. 57 | 58 | This is the primary method that subclasses must implement. 
59 | """ 60 | pass 61 | 62 | def infer( 63 | self, 64 | lm: AbstractLanguageModel, 65 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 66 | budget: int, 67 | return_response_only: bool = True, 68 | tools: list[dict] | None = None, 69 | tool_choice: str | dict | None = None, 70 | ) -> str | AbstractScalingResult: 71 | """ 72 | Run inference synchronously with the given language model and prompt. 73 | 74 | Default implementation wraps ainfer() using asyncio.run(). 75 | """ 76 | import asyncio 77 | 78 | return asyncio.run( 79 | self.ainfer( 80 | lm, prompt_or_messages, budget, return_response_only, tools, tool_choice 81 | ) 82 | ) 83 | 84 | 85 | class AbstractOutcomeRewardModel(ABC): 86 | """abstract base class for outcome reward models""" 87 | 88 | @abstractmethod 89 | async def ascore( 90 | self, prompt_or_messages: str | list[ChatMessage] | ChatMessages, response: str 91 | ) -> float: 92 | """score a response asynchronously""" 93 | pass 94 | 95 | @abstractmethod 96 | def score( 97 | self, prompt_or_messages: str | list[ChatMessage] | ChatMessages, response: str 98 | ) -> float: 99 | """score a response synchronously""" 100 | pass 101 | 102 | 103 | # TODO(GX) deal with aggregation of PRM scores somehow in a common place, e.g. here 104 | class AbstractProcessRewardModel(ABC): 105 | """abstract base class for process reward models""" 106 | 107 | @abstractmethod 108 | async def ascore( 109 | self, 110 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 111 | steps: list[str], 112 | ) -> list[float]: 113 | """score steps asynchronously""" 114 | pass 115 | 116 | @abstractmethod 117 | def score( 118 | self, 119 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 120 | steps: list[str], 121 | ) -> list[float]: 122 | """score steps synchronously""" 123 | pass 124 | -------------------------------------------------------------------------------- /its_hub/algorithms/bon.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from its_hub.base import ( 4 | AbstractLanguageModel, 5 | AbstractOutcomeRewardModel, 6 | AbstractScalingAlgorithm, 7 | AbstractScalingResult, 8 | ) 9 | from its_hub.types import ChatMessage, ChatMessages 10 | from its_hub.utils import extract_content_from_lm_response 11 | 12 | 13 | def _dedupe_with_inverse(seq: list[str]) -> tuple[list[str], list[int]]: 14 | """ 15 | Deduplicate a sequence while preserving order and tracking original indices. 
16 | 17 | Returns (uniques, inverse_idx) where: 18 | - uniques: list of unique items in order of first appearance 19 | - inverse_idx: for each item in seq, its index in the uniques list 20 | 21 | Example: 22 | seq = ["a", "b", "a", "c", "b"] 23 | returns (["a", "b", "c"], [0, 1, 0, 2, 1]) 24 | """ 25 | uniques: list[str] = [] 26 | index_of: dict[str, int] = {} 27 | inverse_idx: list[int] = [] 28 | 29 | for item in seq: 30 | j = index_of.get(item) 31 | if j is None: 32 | j = len(uniques) 33 | index_of[item] = j 34 | uniques.append(item) 35 | inverse_idx.append(j) 36 | 37 | return uniques, inverse_idx 38 | 39 | 40 | @dataclass 41 | class BestOfNResult(AbstractScalingResult): 42 | responses: list[dict] # Keep original message format with tool calls 43 | scores: list[float] 44 | selected_index: int 45 | 46 | @property 47 | def the_one(self) -> dict: 48 | return self.responses[self.selected_index] 49 | 50 | 51 | class BestOfN(AbstractScalingAlgorithm): 52 | def __init__(self, orm: AbstractOutcomeRewardModel): 53 | self.orm = orm 54 | 55 | async def ainfer( 56 | self, 57 | lm: AbstractLanguageModel, 58 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 59 | budget: int, 60 | return_response_only: bool = True, 61 | tools: list[dict] | None = None, 62 | tool_choice: str | dict | None = None, 63 | ) -> dict | BestOfNResult: 64 | """run inference asynchronously with best-of-n""" 65 | chat_messages = ChatMessages.from_prompt_or_messages(prompt_or_messages) 66 | 67 | # generate responses 68 | responses = await lm.agenerate( 69 | chat_messages.to_batch(budget), tools=tools, tool_choice=tool_choice 70 | ) 71 | 72 | # extract content from message dict responses 73 | response_contents = [extract_content_from_lm_response(r) for r in responses] 74 | 75 | # deduplicate responses to avoid redundant scoring 76 | unique_responses, inverse_idx = _dedupe_with_inverse(response_contents) 77 | 78 | # early return if all responses are identical - no need to score 79 | if len(unique_responses) == 1: 80 | scores = [1] * len(responses) 81 | result = BestOfNResult( 82 | responses=responses, 83 | scores=scores, 84 | selected_index=0, 85 | ) 86 | return result.the_one if return_response_only else result 87 | 88 | # score only unique responses 89 | # TODO: make batched a configurable parameter or remove non-batched branch 90 | # Currently hardcoded to True, will be addressed in future PR 91 | batched = True 92 | if batched: 93 | unique_scores = await self.orm.ascore(chat_messages, unique_responses) 94 | else: 95 | unique_scores = [] 96 | for r in unique_responses: 97 | unique_scores.append(await self.orm.ascore(chat_messages, r)) 98 | 99 | # map scores back to original response indices 100 | scores = [unique_scores[idx] for idx in inverse_idx] 101 | 102 | # select the best response 103 | selected_index = scores.index(max(scores)) 104 | 105 | # return the result - preserve original message format with tool calls 106 | result = BestOfNResult( 107 | responses=responses, # Keep original dict format with tool calls 108 | scores=scores, 109 | selected_index=selected_index, 110 | ) 111 | return result.the_one if return_response_only else result 112 | -------------------------------------------------------------------------------- /.devcontainer/init-firewall.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail # Exit on error, undefined vars, and pipeline failures 3 | IFS=$'\n\t' # Stricter word splitting 4 | 5 | # Flush existing rules and delete 
existing ipsets 6 | iptables -F 7 | iptables -X 8 | iptables -t nat -F 9 | iptables -t nat -X 10 | iptables -t mangle -F 11 | iptables -t mangle -X 12 | ipset destroy allowed-domains 2>/dev/null || true 13 | 14 | # First allow DNS and localhost before any restrictions 15 | # Allow outbound DNS 16 | iptables -A OUTPUT -p udp --dport 53 -j ACCEPT 17 | # Allow inbound DNS responses 18 | iptables -A INPUT -p udp --sport 53 -j ACCEPT 19 | # Allow outbound SSH 20 | iptables -A OUTPUT -p tcp --dport 22 -j ACCEPT 21 | # Allow inbound SSH responses 22 | iptables -A INPUT -p tcp --sport 22 -m state --state ESTABLISHED -j ACCEPT 23 | # Allow localhost 24 | iptables -A INPUT -i lo -j ACCEPT 25 | iptables -A OUTPUT -o lo -j ACCEPT 26 | 27 | # Create ipset with CIDR support 28 | ipset create allowed-domains hash:net 29 | 30 | # Fetch GitHub meta information and aggregate + add their IP ranges 31 | echo "Fetching GitHub IP ranges..." 32 | gh_ranges=$(curl -s https://api.github.com/meta) 33 | if [ -z "$gh_ranges" ]; then 34 | echo "ERROR: Failed to fetch GitHub IP ranges" 35 | exit 1 36 | fi 37 | 38 | if ! echo "$gh_ranges" | jq -e '.web and .api and .git' >/dev/null; then 39 | echo "ERROR: GitHub API response missing required fields" 40 | exit 1 41 | fi 42 | 43 | echo "Processing GitHub IPs..." 44 | while read -r cidr; do 45 | if [[ ! "$cidr" =~ ^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}/[0-9]{1,2}$ ]]; then 46 | echo "ERROR: Invalid CIDR range from GitHub meta: $cidr" 47 | exit 1 48 | fi 49 | echo "Adding GitHub range $cidr" 50 | ipset add allowed-domains "$cidr" 51 | done < <(echo "$gh_ranges" | jq -r '(.web + .api + .git)[]' | aggregate -q) 52 | 53 | # Resolve and add other allowed domains 54 | for domain in \ 55 | "registry.npmjs.org" \ 56 | "api.anthropic.com" \ 57 | "sentry.io" \ 58 | "statsig.anthropic.com" \ 59 | "statsig.com"; do 60 | echo "Resolving $domain..." 61 | ips=$(dig +short A "$domain") 62 | if [ -z "$ips" ]; then 63 | echo "ERROR: Failed to resolve $domain" 64 | exit 1 65 | fi 66 | 67 | while read -r ip; do 68 | if [[ ! "$ip" =~ ^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$ ]]; then 69 | echo "ERROR: Invalid IP from DNS for $domain: $ip" 70 | exit 1 71 | fi 72 | echo "Adding $ip for $domain" 73 | ipset add allowed-domains "$ip" 74 | done < <(echo "$ips") 75 | done 76 | 77 | # Get host IP from default route 78 | HOST_IP=$(ip route | grep default | cut -d" " -f3) 79 | if [ -z "$HOST_IP" ]; then 80 | echo "ERROR: Failed to detect host IP" 81 | exit 1 82 | fi 83 | 84 | HOST_NETWORK=$(echo "$HOST_IP" | sed "s/\.[0-9]*$/.0\/24/") 85 | echo "Host network detected as: $HOST_NETWORK" 86 | 87 | # Set up remaining iptables rules 88 | iptables -A INPUT -s "$HOST_NETWORK" -j ACCEPT 89 | iptables -A OUTPUT -d "$HOST_NETWORK" -j ACCEPT 90 | 91 | # Set default policies to DROP first 92 | iptables -P INPUT DROP 93 | iptables -P FORWARD DROP 94 | iptables -P OUTPUT DROP 95 | 96 | # First allow established connections for already approved traffic 97 | iptables -A INPUT -m state --state ESTABLISHED,RELATED -j ACCEPT 98 | iptables -A OUTPUT -m state --state ESTABLISHED,RELATED -j ACCEPT 99 | 100 | # Then allow only specific outbound traffic to allowed domains 101 | iptables -A OUTPUT -m set --match-set allowed-domains dst -j ACCEPT 102 | 103 | echo "Firewall configuration complete" 104 | echo "Verifying firewall rules..." 
105 | if curl --connect-timeout 5 https://example.com >/dev/null 2>&1; then 106 | echo "ERROR: Firewall verification failed - was able to reach https://example.com" 107 | exit 1 108 | else 109 | echo "Firewall verification passed - unable to reach https://example.com as expected" 110 | fi 111 | 112 | # Verify GitHub API access 113 | if ! curl --connect-timeout 5 https://api.github.com/zen >/dev/null 2>&1; then 114 | echo "ERROR: Firewall verification failed - unable to reach https://api.github.com" 115 | exit 1 116 | else 117 | echo "Firewall verification passed - able to reach https://api.github.com as expected" 118 | fi 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `its-hub`: A Python library for inference-time scaling 2 | 3 | [![Tests](https://github.com/Red-Hat-AI-Innovation-Team/its_hub/actions/workflows/tests.yaml/badge.svg)](https://github.com/Red-Hat-AI-Innovation-Team/its_hub/actions/workflows/tests.yaml) 4 | [![codecov](https://codecov.io/gh/Red-Hat-AI-Innovation-Team/its_hub/graph/badge.svg?token=6WD8NB9YPN)](https://codecov.io/gh/Red-Hat-AI-Innovation-Team/its_hub) 5 | [![PyPI version](https://badge.fury.io/py/its-hub.svg)](https://badge.fury.io/py/its-hub) 6 | 7 | **its_hub** is a Python library for inference-time scaling of LLMs, focusing on mathematical reasoning tasks. 8 | 9 | ## 📚 Documentation 10 | 11 | For comprehensive documentation, including installation guides, tutorials, and API reference, visit: 12 | 13 | **[https://ai-innovation.team/its_hub](https://ai-innovation.team/its_hub)** 14 | 15 | ## Installation 16 | 17 | Choose the installation option based on which algorithms you need: 18 | 19 | ```bash 20 | # Core installation - includes: 21 | # - Best-of-N with LLM Judge 22 | # - Self-Consistency 23 | # - OpenAI-compatible language models 24 | pip install its_hub 25 | 26 | # Process Reward Model installation - adds: 27 | # - Particle Filtering Algorithms 28 | # - Beam Search 29 | # - LocalVllmProcessRewardModel 30 | # - Required for step-by-step reasoning with process reward models 31 | pip install its_hub[prm] 32 | 33 | # Development installation 34 | git clone https://github.com/Red-Hat-AI-Innovation-Team/its_hub.git 35 | cd its_hub 36 | pip install -e ".[dev]" 37 | ``` 38 | 39 | ## Quick Start 40 | 41 | ### Example 1: Best-of-N with LLM Judge 42 | 43 | **Installation required:** `pip install its_hub` (core) 44 | 45 | Use Best-of-N algorithm with an LLM judge for response selection - works with any OpenAI-compatible API: 46 | 47 | ```python 48 | from its_hub.lms import OpenAICompatibleLanguageModel 49 | from its_hub.algorithms import BestOfN 50 | from its_hub.integration.reward_hub import LLMJudgeRewardModel 51 | 52 | # Initialize language model 53 | lm = OpenAICompatibleLanguageModel( 54 | endpoint="https://api.openai.com/v1", 55 | api_key="your-api-key", 56 | model_name="gpt-4o-mini", 57 | ) 58 | 59 | # Set up LLM judge for scoring 60 | judge = LLMJudgeRewardModel( 61 | model="gpt-4o-mini", 62 | criterion="overall_quality", 63 | judge_type="groupwise", 64 | api_key="your-api-key", 65 | ) 66 | scaling_alg = BestOfN(judge) 67 | 68 | # Generate multiple responses and select the best 69 | result = scaling_alg.infer(lm, "Explain quantum entanglement in simple terms", budget=4) 70 | print(result) 71 | ``` 72 | 73 | ### Example 2: Particle Filtering with Process Reward Model 74 | 75 | **Installation required:** `pip install 
its_hub[prm]` 76 | 77 | Use Particle Filtering for step-by-step reasoning with process reward models: 78 | 79 | ```python 80 | from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT 81 | from its_hub.lms import OpenAICompatibleLanguageModel, StepGeneration 82 | from its_hub.algorithms import ParticleFiltering 83 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel 84 | 85 | # Initialize language model (requires vLLM server running) 86 | lm = OpenAICompatibleLanguageModel( 87 | endpoint="http://localhost:8100/v1", 88 | api_key="NO_API_KEY", 89 | model_name="Qwen/Qwen2.5-Math-1.5B-Instruct", 90 | system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT, 91 | ) 92 | 93 | # Set up step generation and process reward model 94 | sg = StepGeneration(step_token="\n\n", max_steps=32, stop_token=r"\boxed") 95 | prm = LocalVllmProcessRewardModel( 96 | model_name="Qwen/Qwen2.5-Math-PRM-7B", 97 | device="cuda:0", 98 | aggregation_method="prod" 99 | ) 100 | scaling_alg = ParticleFiltering(sg, prm) 101 | 102 | # Solve with step-by-step reasoning 103 | result = scaling_alg.infer(lm, "Solve x^2 + 5x + 6 = 0", budget=8) 104 | print(result) 105 | ``` 106 | 107 | ## Key Features 108 | 109 | - 🔬 **Multiple Algorithms**: Particle Filtering, Best-of-N, Beam Search, Self-Consistency 110 | - 🚀 **OpenAI-Compatible API**: Easy integration with existing applications 111 | - 🧮 **Math-Optimized**: Built for mathematical reasoning with specialized prompts 112 | - 📊 **Benchmarking Tools**: Compare algorithms on MATH500 and AIME-2024 datasets 113 | - ⚡ **Async Support**: Concurrent generation with limits and error handling 114 | 115 | 116 | For detailed documentation, visit: [https://ai-innovation.team/its_hub](https://ai-innovation.team/its_hub) 117 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Auto generated 2 | its_hub/_version.py 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints/ 83 | 84 | # Don't ignore the generated notebooks in notebooks/ directory 85 | # (we want both .py and .ipynb versions) 86 | !notebooks/**/*.ipynb 87 | 88 | # But ignore notebooks in other directories 89 | *.ipynb 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # UV 108 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | uv.lock 112 | 113 | # poetry 114 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 115 | # This is especially recommended for binary packages to ensure reproducibility, and is more 116 | # commonly ignored for libraries. 117 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 118 | #poetry.lock 119 | 120 | # pdm 121 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 122 | #pdm.lock 123 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 124 | # in version control. 125 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 126 | .pdm.toml 127 | .pdm-python 128 | .pdm-build/ 129 | 130 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 131 | __pypackages__/ 132 | 133 | # Celery stuff 134 | celerybeat-schedule 135 | celerybeat.pid 136 | 137 | # SageMath parsed files 138 | *.sage.py 139 | 140 | # Environments 141 | .env 142 | .venv 143 | env/ 144 | venv/ 145 | ENV/ 146 | env.bak/ 147 | venv.bak/ 148 | 149 | # Spyder project settings 150 | .spyderproject 151 | .spyproject 152 | 153 | # Rope project settings 154 | .ropeproject 155 | 156 | # mkdocs documentation 157 | /site 158 | 159 | # mypy 160 | .mypy_cache/ 161 | .dmypy.json 162 | dmypy.json 163 | 164 | # Pyre type checker 165 | .pyre/ 166 | 167 | # pytype static type analyzer 168 | .pytype/ 169 | 170 | # Cython debug symbols 171 | cython_debug/ 172 | 173 | # PyCharm 174 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 175 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 176 | # and can be added to the global gitignore or merged into this file. For a more nuclear 177 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 178 | #.idea/ 179 | 180 | # Ruff stuff: 181 | .ruff_cache/ 182 | 183 | # PyPI configuration file 184 | .pypirc 185 | 186 | # local dev 187 | notebooks/dev.ipynb 188 | results/ -------------------------------------------------------------------------------- /tests/mocks/reward_models.py: -------------------------------------------------------------------------------- 1 | """Mock reward models for testing.""" 2 | 3 | from its_hub.base import AbstractOutcomeRewardModel 4 | 5 | 6 | class MockOutcomeRewardModel(AbstractOutcomeRewardModel): 7 | """Mock outcome reward model with configurable scores.""" 8 | 9 | def __init__(self, scores: list[float] | float): 10 | if isinstance(scores, float): 11 | self.scores = [scores] 12 | else: 13 | self.scores = scores 14 | self.call_count = 0 # tracks total number of individual scores returned 15 | self.score_call_count = 0 # tracks number of times score() is called 16 | 17 | async def ascore(self, prompt: str, response: str | list[str]) -> float | list[float]: 18 | return self.score(prompt, response) 19 | 20 | def score(self, prompt: str, response: str | list[str]) -> float | list[float]: 21 | self.score_call_count += 1 22 | if isinstance(response, list): 23 | scores = [] 24 | for i in range(len(response)): 25 | score_idx = (self.call_count + i) % len(self.scores) 26 | scores.append(self.scores[score_idx]) 27 | self.call_count += len(response) 28 | return scores 29 | else: 30 | score = self.scores[self.call_count % len(self.scores)] 31 | self.call_count += 1 32 | return score 33 | 34 | 35 | class MockProcessRewardModel: 36 | """Mock process reward model with configurable scores.""" 37 | 38 | def __init__(self, scores: list[float] | list[list[float]]): 39 | if isinstance(scores[0], list): 40 | # Flatten nested lists 41 | self.scores = [score for sublist in scores for score in sublist] 42 | else: 43 | self.scores = scores 44 | self.call_count = 0 45 | 46 | async def ascore(self, prompt: str, response: str | list[str]) -> float | list[float]: 47 | return self.score(prompt, response) 48 | 49 | def score(self, prompt: str, response: str | list[str]) -> float | list[float]: 50 | if isinstance(response, list): 51 | scores = [] 52 | for i in range(len(response)): 53 | score_idx = (self.call_count + i) % len(self.scores) 54 | scores.append(self.scores[score_idx]) 55 | self.call_count += len(response) 56 | return scores 57 | else: 58 | score = 
self.scores[self.call_count % len(self.scores)] 59 | self.call_count += 1 60 | return score 61 | 62 | 63 | class HighVarianceRewardModel: 64 | """Mock reward model with high variance for testing edge cases.""" 65 | 66 | def __init__(self): 67 | self.scores = [0.0, 1.0, 0.5, 0.1, 0.9, 0.3, 0.7, 0.2, 0.8, 0.4] 68 | self.call_count = 0 69 | 70 | async def ascore(self, prompt: str, response: str | list[str]) -> float | list[float]: 71 | return self.score(prompt, response) 72 | 73 | def score(self, prompt: str, response: str | list[str]) -> float | list[float]: 74 | if isinstance(response, list): 75 | scores = [] 76 | for i in range(len(response)): 77 | score_idx = (self.call_count + i) % len(self.scores) 78 | scores.append(self.scores[score_idx]) 79 | self.call_count += len(response) 80 | return scores 81 | else: 82 | score = self.scores[self.call_count % len(self.scores)] 83 | self.call_count += 1 84 | return score 85 | 86 | 87 | class ErrorRewardModel: 88 | """Mock reward model that can simulate errors.""" 89 | 90 | def __init__(self, scores: list[float], error_on_calls: list[int] | None = None): 91 | self.scores = scores 92 | self.error_on_calls = error_on_calls or [] 93 | self.call_count = 0 94 | 95 | async def ascore(self, prompt: str, response: str | list[str]) -> float | list[float]: 96 | return self.score(prompt, response) 97 | 98 | def score(self, prompt: str, response: str | list[str]) -> float | list[float]: 99 | if self.call_count in self.error_on_calls: 100 | self.call_count += 1 101 | raise Exception("Simulated reward model error") 102 | 103 | if isinstance(response, list): 104 | scores = [] 105 | for i in range(len(response)): 106 | if (self.call_count + i) in self.error_on_calls: 107 | raise Exception("Simulated reward model error in batch") 108 | score_idx = (self.call_count + i) % len(self.scores) 109 | scores.append(self.scores[score_idx]) 110 | self.call_count += len(response) 111 | return scores 112 | else: 113 | score = self.scores[self.call_count % len(self.scores)] 114 | self.call_count += 1 115 | return score 116 | -------------------------------------------------------------------------------- /tests/mocks/test_data.py: -------------------------------------------------------------------------------- 1 | """Test data factories for consistent test data generation.""" 2 | 3 | from typing import Any 4 | 5 | from its_hub.types import ChatMessage, ChatMessages 6 | 7 | 8 | class TestDataFactory: 9 | """Factory for creating consistent test data.""" 10 | 11 | @staticmethod 12 | def create_chat_messages( 13 | user_content: str = "Hello", system_content: str | None = None 14 | ) -> ChatMessages: 15 | """Create ChatMessages object for testing.""" 16 | if system_content: 17 | messages = [ 18 | ChatMessage(role="system", content=system_content), 19 | ChatMessage(role="user", content=user_content), 20 | ] 21 | return ChatMessages(messages) 22 | else: 23 | return ChatMessages(user_content) # Simple string case 24 | 25 | @staticmethod 26 | def create_chat_completion_request( 27 | model: str = "test-model", 28 | user_content: str = "Hello", 29 | budget: int = 4, 30 | system_content: str | None = None, 31 | **kwargs, 32 | ) -> dict[str, Any]: 33 | """Create a standard chat completion request.""" 34 | chat_messages = TestDataFactory.create_chat_messages( 35 | user_content, system_content 36 | ).to_chat_messages() 37 | request = { 38 | "model": model, 39 | "messages": [ 40 | msg.__dict__ for msg in chat_messages 41 | ], # Convert to dict for JSON serialization 42 | "budget": budget, 43 | 
} 44 | request.update(kwargs) 45 | return request 46 | 47 | @staticmethod 48 | def create_config_request( 49 | endpoint: str = "http://localhost:8000", 50 | api_key: str = "test-key", 51 | model: str = "test-model", 52 | alg: str = "best-of-n", 53 | **kwargs, 54 | ) -> dict[str, Any]: 55 | """Create a standard configuration request.""" 56 | config = { 57 | "endpoint": endpoint, 58 | "api_key": api_key, 59 | "model": model, 60 | "alg": alg, 61 | "rm_name": "test-rm", 62 | "rm_device": "cpu", 63 | } 64 | config.update(kwargs) 65 | return config 66 | 67 | @staticmethod 68 | def create_error_trigger_request(trigger: str = "trigger_error") -> dict[str, Any]: 69 | """Create a request that triggers errors for testing.""" 70 | return TestDataFactory.create_chat_completion_request(user_content=trigger) 71 | 72 | @staticmethod 73 | def create_multiple_responses(base: str = "response", count: int = 3) -> list[str]: 74 | """Create multiple response strings for testing.""" 75 | return [f"{base}{i + 1}" for i in range(count)] 76 | 77 | @staticmethod 78 | def create_score_sequence( 79 | base_score: float = 0.5, count: int = 3, increment: float = 0.1 80 | ) -> list[float]: 81 | """Create a sequence of scores for testing.""" 82 | return [base_score + (i * increment) for i in range(count)] 83 | 84 | 85 | # Common test scenarios 86 | TEST_SCENARIOS = { 87 | "simple_chat": { 88 | "user_content": "Hello, world!", 89 | "expected_response": { 90 | "role": "assistant", 91 | "content": "Response to: Hello, world!", 92 | }, 93 | }, 94 | "math_problem": { 95 | "user_content": "Solve 2+2", 96 | "expected_response": {"role": "assistant", "content": "Response to: Solve 2+2"}, 97 | }, 98 | "error_trigger": {"user_content": "trigger_error", "should_error": True}, 99 | "vllm_error_trigger": {"user_content": "error", "should_error": True}, 100 | "with_system_prompt": { 101 | "system_content": "You are a helpful assistant", 102 | "user_content": "How can I help you?", 103 | "expected_response": { 104 | "role": "assistant", 105 | "content": "Response to: How can I help you?", 106 | }, 107 | }, 108 | } 109 | 110 | # Algorithm test configurations 111 | ALGORITHM_CONFIGS = { 112 | "best_of_n": { 113 | "alg": "best-of-n", 114 | "requires_outcome_rm": True, 115 | "supports_batching": True, 116 | }, 117 | "beam_search": { 118 | "alg": "beam-search", 119 | "requires_process_rm": True, 120 | "requires_step_generation": True, 121 | "budget_constraints": "divisible_by_beam_width", 122 | }, 123 | "particle_filtering": { 124 | "alg": "particle-filtering", 125 | "requires_process_rm": True, 126 | "requires_step_generation": True, 127 | "supports_selection_methods": ["argmax", "sample"], 128 | }, 129 | "particle_gibbs": { 130 | "alg": "particle-gibbs", 131 | "requires_process_rm": True, 132 | "requires_step_generation": True, 133 | "supports_selection_methods": ["argmax", "sample"], 134 | "budget_constraints": "divisible_by_iterations", 135 | }, 136 | } 137 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Prerequisites 4 | 5 | - Python 3.10+ (3.11+ recommended) 6 | - pip or uv package manager 7 | - GPU with CUDA 11.8+ (only for `[prm]` installation) 8 | 9 | ## Installation Options 10 | 11 | | Option | Command | Use Case | 12 | |--------|---------|----------| 13 | | **Core** | `pip install its_hub` | Best-of-N, Self-Consistency, cloud APIs | 14 | | **PRM** | `pip install 
its_hub[prm]` | Particle Filtering, Beam Search, local reward models | 15 | | **Cloud** | `pip install its_hub[cloud]` | AWS Bedrock, Google Vertex AI | 16 | | **Research** | `pip install its_hub[research]` | Benchmarks, evaluation tools | 17 | | **Dev** | `pip install -e ".[dev]"` | Contributing, testing | 18 | 19 | --- 20 | 21 | ## Core Installation 22 | 23 | ```bash 24 | pip install its_hub 25 | ``` 26 | 27 | ### What's Included 28 | 29 | **Algorithms**: Best-of-N, Self-Consistency, LLM Judge 30 | **Language Models**: OpenAI-compatible, LiteLLM (100+ providers) 31 | **Key Dependencies**: `openai`, `litellm`, `reward-hub`, `transformers`, `fastapi` 32 | 33 | ### When to Use 34 | 35 | **Use if**: Working with cloud APIs (OpenAI, Anthropic, etc.), no GPU needed 36 | **Skip if**: Need Particle Filtering/Beam Search or local process reward models 37 | 38 | ### Under the Hood 39 | 40 | - **Size**: ~50MB (no vLLM or CUDA dependencies) 41 | - **Installation time**: 1-2 minutes 42 | - **GPU required**: No 43 | - **What's excluded**: vLLM, local reward model inference 44 | 45 | ```python 46 | # Verify installation 47 | from its_hub.algorithms import BestOfN, SelfConsistency 48 | from its_hub.integration.reward_hub import LLMJudgeRewardModel 49 | ``` 50 | 51 | --- 52 | 53 | ## Process Reward Model (PRM) Installation 54 | 55 | ```bash 56 | pip install its_hub[prm] 57 | ``` 58 | 59 | ### What's Added 60 | 61 | **Algorithms**: Particle Filtering, Beam Search (+ all core algorithms) 62 | **Reward Models**: `LocalVllmProcessRewardModel` for step-by-step scoring 63 | **Additional Dependencies**: `reward-hub[prm]` (includes vLLM with pinned versions) 64 | 65 | ### When to Use 66 | 67 | **Use if**: Need step-by-step reasoning with local reward models, have GPU 68 | **Skip if**: Only using cloud APIs or outcome-based scoring 69 | 70 | ### Under the Hood 71 | 72 | - **Size**: ~2-3GB (includes vLLM + CUDA dependencies) 73 | - **Installation time**: 5-10 minutes 74 | - **GPU required**: Yes (10-20GB VRAM for typical 7B reward models) 75 | - **Version pinning**: `reward-hub[prm]` pins compatible vLLM + transformers + PyTorch versions 76 | 77 | ```python 78 | # Verify installation 79 | from its_hub.algorithms import ParticleFiltering, BeamSearch 80 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel 81 | 82 | # Check GPU 83 | import torch 84 | print(f'CUDA available: {torch.cuda.is_available()}') 85 | ``` 86 | 87 | --- 88 | 89 | ## Cloud Installation 90 | 91 | ```bash 92 | pip install its_hub[cloud] 93 | ``` 94 | 95 | **Adds**: AWS Bedrock (`boto3`) and Google Vertex AI (`google-cloud-aiplatform`) SDKs 96 | **Use if**: Need direct SDK access to Bedrock or Vertex AI (most cloud providers work with core via LiteLLM) 97 | 98 | --- 99 | 100 | ## Research Installation 101 | 102 | ```bash 103 | pip install its_hub[research] 104 | ``` 105 | 106 | **Adds**: `math-verify`, `datasets`, `matplotlib` 107 | **Use if**: Running benchmarks on MATH500/AIME or evaluating algorithm performance 108 | **Includes**: Benchmark scripts in `scripts/benchmark.py` 109 | 110 | --- 111 | 112 | ## Development Installation 113 | 114 | ```bash 115 | git clone https://github.com/Red-Hat-AI-Innovation-Team/its_hub.git 116 | cd its_hub 117 | 118 | # Recommended: uv 119 | uv sync --extra dev 120 | 121 | # Alternative: pip 122 | pip install -e ".[dev]" 123 | ``` 124 | 125 | **Includes**: All core + PRM + `pytest`, `ruff`, `jupyter`, notebooks 126 | **Use if**: Contributing, testing, or developing new features 127 | 128 | 
```bash 129 | # Run tests 130 | uv run pytest tests/ 131 | uv run pytest tests/ --cov=its_hub 132 | 133 | # Code quality 134 | uv run ruff check its_hub/ --fix 135 | uv run ruff format its_hub/ 136 | ``` 137 | 138 | --- 139 | 140 | ## Combining Extras 141 | 142 | ```bash 143 | pip install its_hub[prm,research] # PRM + benchmarking 144 | pip install its_hub[cloud,research] # Cloud + benchmarking 145 | pip install -e ".[dev,research,cloud]" # Everything 146 | ``` 147 | 148 | --- 149 | 150 | ## Verification 151 | 152 | ```bash 153 | # Core 154 | python -c "from its_hub.algorithms import BestOfN; print('✅ Core OK')" 155 | 156 | # PRM 157 | python -c "from its_hub.algorithms import ParticleFiltering; print('✅ PRM OK')" 158 | python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}')" 159 | ``` 160 | 161 | --- 162 | 163 | ## Next Steps 164 | 165 | - [Quick Start Guide](quick-start.md) - Best-of-N and Particle Filtering examples 166 | - [IaaS Service Guide](iaas-service.md) - Deploy as OpenAI-compatible API 167 | - [Development Guide](development.md) - Contributing guidelines 168 | 169 | For runtime issues (CUDA OOM, server errors, etc.), see the troubleshooting sections in the Quick Start or IaaS Service guides. -------------------------------------------------------------------------------- /tests/mocks/language_models.py: -------------------------------------------------------------------------------- 1 | """Mock language models for testing.""" 2 | 3 | from its_hub.base import AbstractLanguageModel 4 | 5 | 6 | class SimpleMockLanguageModel: 7 | """Simple mock language model for basic testing.""" 8 | 9 | def __init__(self, responses: list[str]): 10 | self.responses = responses 11 | self.call_count = 0 12 | 13 | async def agenerate(self, messages, **kwargs): 14 | return self.generate(messages, **kwargs) 15 | 16 | def generate(self, messages, **kwargs): 17 | if isinstance(messages[0], list): 18 | # Multiple message lists 19 | content_responses = self.responses[ 20 | self.call_count : self.call_count + len(messages) 21 | ] 22 | self.call_count += len(messages) 23 | return [ 24 | {"role": "assistant", "content": content} 25 | for content in content_responses 26 | ] 27 | else: 28 | # Single message list 29 | content = self.responses[self.call_count] 30 | self.call_count += 1 31 | return {"role": "assistant", "content": content} 32 | 33 | 34 | class StepMockLanguageModel(AbstractLanguageModel): 35 | """Mock language model for step-by-step generation testing.""" 36 | 37 | def __init__(self, step_responses: list[str]): 38 | self.step_responses = step_responses 39 | self.call_count = 0 40 | 41 | async def agenerate(self, messages, stop=None, max_tokens=None, temperature=None, include_stop_str_in_output=None, tools=None, tool_choice=None): 42 | return self.generate(messages, stop, max_tokens, temperature, include_stop_str_in_output, tools, tool_choice) 43 | 44 | def generate(self, messages, stop=None, max_tokens=None, temperature=None, include_stop_str_in_output=None, tools=None, tool_choice=None): 45 | if isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], list): 46 | # Batched generation 47 | num_requests = len(messages) 48 | responses = [] 49 | for i in range(num_requests): 50 | response_idx = (self.call_count + i) % len(self.step_responses) 51 | content = self.step_responses[response_idx] 52 | responses.append({"role": "assistant", "content": content}) 53 | self.call_count += num_requests 54 | return responses 55 | else: 56 | # Single generation 57 | content = 
self.step_responses[self.call_count % len(self.step_responses)] 58 | self.call_count += 1 59 | return {"role": "assistant", "content": content} 60 | 61 | async def aevaluate(self, prompt: str, generation: str) -> list[float]: 62 | return self.evaluate(prompt, generation) 63 | 64 | def evaluate(self, prompt: str, generation: str) -> list[float]: 65 | """Return mock evaluation scores.""" 66 | return [0.1] * len(generation.split()) 67 | 68 | 69 | class ErrorMockLanguageModel(AbstractLanguageModel): 70 | """Mock language model that can simulate errors.""" 71 | 72 | def __init__(self, responses: list[str], error_on_calls: list[int] | None = None): 73 | self.responses = responses 74 | self.error_on_calls = error_on_calls or [] 75 | self.call_count = 0 76 | 77 | async def agenerate(self, messages, stop=None, max_tokens=None, temperature=None, include_stop_str_in_output=None, tools=None, tool_choice=None): 78 | return self.generate(messages, stop, max_tokens, temperature, include_stop_str_in_output, tools, tool_choice) 79 | 80 | def generate(self, messages, stop=None, max_tokens=None, temperature=None, include_stop_str_in_output=None, tools=None, tool_choice=None): 81 | if self.call_count in self.error_on_calls: 82 | self.call_count += 1 83 | raise Exception("Simulated LM error") 84 | 85 | if ( 86 | isinstance(messages, list) 87 | and len(messages) > 0 88 | and isinstance(messages[0], list) 89 | ): 90 | # Batched generation 91 | num_requests = len(messages) 92 | responses = [] 93 | for i in range(num_requests): 94 | if (self.call_count + i) in self.error_on_calls: 95 | raise Exception("Simulated LM error in batch") 96 | response_idx = (self.call_count + i) % len(self.responses) 97 | content = self.responses[response_idx] 98 | responses.append({"role": "assistant", "content": content}) 99 | self.call_count += num_requests 100 | return responses 101 | else: 102 | # Single generation 103 | content = self.responses[self.call_count % len(self.responses)] 104 | self.call_count += 1 105 | return {"role": "assistant", "content": content} 106 | 107 | async def aevaluate(self, prompt: str, generation: str) -> list[float]: 108 | return self.evaluate(prompt, generation) 109 | 110 | def evaluate(self, prompt: str, generation: str) -> list[float]: 111 | return [0.1] * len(generation.split()) 112 | -------------------------------------------------------------------------------- /notebooks/self-consistency.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # cell_metadata_filter: -all 5 | # notebook_metadata_filter: all 6 | # text_representation: 7 | # extension: .py 8 | # format_name: percent 9 | # format_version: '1.3' 10 | # jupytext_version: 1.18.1 11 | # kernelspec: 12 | # display_name: inference_time_scaling-dev 13 | # language: python 14 | # name: python3 15 | # language_info: 16 | # codemirror_mode: 17 | # name: ipython 18 | # version: 3 19 | # file_extension: .py 20 | # mimetype: text/x-python 21 | # name: python 22 | # nbconvert_exporter: python 23 | # pygments_lexer: ipython3 24 | # version: 3.11.11 25 | # --- 26 | 27 | # %% [markdown] 28 | # # Self-Consistency Algorithm Demo 29 | # This notebook demonstrates the Self-Consistency algorithm for mathematical reasoning. 
30 | 31 | # %% 32 | # %load_ext autoreload 33 | # %autoreload 2 34 | 35 | # %% 36 | import os 37 | 38 | import nest_asyncio 39 | from dotenv import load_dotenv 40 | 41 | from its_hub.lms import OpenAICompatibleLanguageModel 42 | from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT 43 | 44 | nest_asyncio.apply() 45 | 46 | # Load environment variables from .env file 47 | load_dotenv() 48 | 49 | # Main example: OpenAI API endpoint with gpt-4o-mini 50 | lm = OpenAICompatibleLanguageModel( 51 | endpoint="https://api.openai.com/v1", 52 | api_key=os.getenv("OPENAI_API_KEY"), # Load API key from environment 53 | model_name="gpt-4o-mini", 54 | system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT, 55 | is_async=True, 56 | ) 57 | # %% 58 | # Alternative: vLLM local endpoint (commented out) 59 | # lm = OpenAICompatibleLanguageModel( 60 | # endpoint="http://localhost:8000/v1", 61 | # api_key="NO_API_KEY", 62 | # model_name="qwen2-math-1.5b-instruct", 63 | # system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT, 64 | # is_async=True, 65 | # ) 66 | 67 | # %% 68 | # Mathematical problem to solve 69 | prompt = r"Let $a$ be a positive real number such that all the roots of \[x^3 + ax^2 + ax + 1 = 0\]are real. Find the smallest possible value of $a.$" 70 | 71 | # Generate response using the proper format 72 | from its_hub.types import ChatMessages 73 | 74 | chat_messages = ChatMessages.from_prompt_or_messages(prompt) 75 | response = lm.generate(chat_messages.to_batch(1))[0] 76 | 77 | print(response) 78 | 79 | 80 | # %% 81 | def extract_boxed(s: str) -> str: 82 | import re 83 | # find all occurrences of \boxed{...} 84 | boxed_matches = re.findall(r'\\boxed\{([^{}]+(?:\{[^{}]*\}[^{}]*)*)\}', s) 85 | # return the last match if any were found 86 | return boxed_matches[-1] if boxed_matches else "" 87 | 88 | print(extract_boxed(response['content'])) 89 | 90 | # %% [markdown] 91 | # ## Self-Consistency Algorithm 92 | # Now we'll use the Self-Consistency algorithm to improve the answer quality. 93 | 94 | # %% 95 | from its_hub.algorithms import SelfConsistency 96 | 97 | # Set computational budget for scaling 98 | budget = 4 99 | 100 | scaling_alg = SelfConsistency(extract_boxed) 101 | 102 | scaling_result = scaling_alg.infer( 103 | lm, prompt, budget, return_response_only=False 104 | ) 105 | 106 | print("######## Self-Consistency Result ########") 107 | print(scaling_result.the_one) 108 | 109 | # %% 110 | print("######## Extracted Response Counts ########") 111 | print(scaling_result.response_counts) 112 | 113 | # %% 114 | 115 | 116 | # %% [markdown] 117 | # ## Self-Consistency Algorithm for Tool Calls 118 | # We have hierarchical tool-voting support in Self-Consistency algorithm 119 | # It first votes on tool names, and then on tool arguments. 120 | 121 | # %% 122 | from its_hub.types import ChatMessage, ChatMessages 123 | 124 | # Tool schema (OpenAI-style dicts) 125 | tools = [ 126 | { 127 | "type": "function", 128 | "function": { 129 | "name": "calculator", 130 | "description": "Perform arithmetic calculations", 131 | "parameters": { 132 | "type": "object", 133 | "properties": { 134 | "expression": { 135 | "type": "string", 136 | "description": "Mathematical expression to evaluate" 137 | } 138 | }, 139 | "required": ["expression"] 140 | } 141 | } 142 | } 143 | ] 144 | 145 | # ChatMessages instance with system + user 146 | tool_call_messages = ChatMessages([ 147 | ChatMessage( 148 | role="system", 149 | content="You are a precise calculator. 
Always use the calculator tool for arithmetic and format your final answer as \\boxed{result}." 150 | ), 151 | ChatMessage( 152 | role="user", 153 | content="What is 847 * 293 + 156?" 154 | ), 155 | ]) 156 | 157 | # %% 158 | # Use hierarchical tool voting 159 | scaling_alg_tool = SelfConsistency(tool_vote="tool_hierarchical") 160 | 161 | budget = 5 162 | scaling_result = scaling_alg_tool.infer( 163 | lm, tool_call_messages, budget, return_response_only=False, tools=tools, tool_choice="auto" 164 | ) 165 | 166 | # %% 167 | print("######## Self-Consistency Result ########") 168 | print(scaling_result.the_one) 169 | 170 | print("######## Tool Call Response Counts ########") 171 | print(scaling_result.response_counts) 172 | 173 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | name: Build, test, and upload PyPI package 4 | 5 | on: 6 | push: 7 | branches: 8 | - "main" 9 | - "release-**" 10 | tags: 11 | - "v*" 12 | pull_request: 13 | branches: 14 | - "main" 15 | - "release-**" 16 | release: 17 | types: 18 | - published 19 | 20 | env: 21 | LC_ALL: en_US.UTF-8 22 | 23 | defaults: 24 | run: 25 | shell: bash 26 | 27 | permissions: 28 | contents: read 29 | 30 | jobs: 31 | # Create and verify release artifacts 32 | # - build source dist (tar ball) and wheel 33 | # - validate artifacts with various tools 34 | # - upload artifacts to GHA 35 | build-package: 36 | name: Build and check packages 37 | runs-on: ubuntu-latest 38 | steps: 39 | - name: "Harden Runner" 40 | uses: step-security/harden-runner@c6295a65d1254861815972266d5933fd6e532bdf # v2.11.1 41 | with: 42 | egress-policy: audit # TODO: change to 'egress-policy: block' after couple of runs 43 | 44 | 45 | - name: "Checkout" 46 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 47 | with: 48 | # for setuptools-scm 49 | fetch-depth: 0 50 | 51 | - name: "Build and Inspect" 52 | uses: hynek/build-and-inspect-python-package@efb823f52190ad02594531168b7a2d5790e66516 # v2.14.0 53 | 54 | # push to Test PyPI on 55 | # - a new GitHub release is published 56 | # - a PR is merged into main branch 57 | publish-test-pypi: 58 | name: Publish packages to test.pypi.org 59 | # environment: publish-test-pypi 60 | if: ${{ (github.repository_owner == 'Red-Hat-AI-Innovation-Team') && ((github.event.action == 'published') || ((github.event_name == 'push') && (github.ref == 'refs/heads/main'))) }} 61 | permissions: 62 | contents: read 63 | # see https://docs.pypi.org/trusted-publishers/ 64 | id-token: write 65 | runs-on: ubuntu-latest 66 | needs: build-package 67 | 68 | environment: 69 | name: testpypi 70 | url: https://test.pypi.org/p/its-hub 71 | 72 | steps: 73 | - name: "Harden Runner" 74 | uses: step-security/harden-runner@c6295a65d1254861815972266d5933fd6e532bdf # v2.11.1 75 | with: 76 | egress-policy: audit # TODO: change to 'egress-policy: block' after couple of runs 77 | 78 | - name: "Download build artifacts" 79 | uses: actions/download-artifact@cc203385981b70ca67e1cc392babf9cc229d5806 # v4.1.9 80 | with: 81 | name: Packages 82 | path: dist 83 | 84 | - name: "Upload to Test PyPI" 85 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 86 | with: 87 | repository-url: https://test.pypi.org/legacy/ 88 | 89 | # push to Production PyPI on 90 | # - a new GitHub release is published 91 | publish-pypi: 92 | name: Publish 
release to pypi.org 93 | # environment: publish-pypi 94 | if: ${{ (github.repository_owner == 'Red-Hat-AI-Innovation-Team') && (github.event.action == 'published') }} 95 | permissions: 96 | # see https://docs.pypi.org/trusted-publishers/ 97 | id-token: write 98 | # allow gh release upload 99 | contents: write 100 | 101 | environment: 102 | name: pypi 103 | url: https://pypi.org/p/its-hub 104 | 105 | runs-on: ubuntu-latest 106 | needs: build-package 107 | 108 | steps: 109 | - name: "Harden Runner" 110 | uses: step-security/harden-runner@c6295a65d1254861815972266d5933fd6e532bdf # v2.11.1 111 | with: 112 | egress-policy: audit # TODO: change to 'egress-policy: block' after couple of runs 113 | 114 | - name: "Download build artifacts" 115 | uses: actions/download-artifact@cc203385981b70ca67e1cc392babf9cc229d5806 # v4.1.9 116 | with: 117 | name: Packages 118 | path: dist 119 | 120 | - name: "Sigstore sign package" 121 | uses: sigstore/gh-action-sigstore-python@f514d46b907ebcd5bedc05145c03b69c1edd8b46 # v3.0.0 122 | with: 123 | inputs: | 124 | ./dist/*.tar.gz 125 | ./dist/*.whl 126 | release-signing-artifacts: false 127 | 128 | - name: "Upload artifacts and signatures to GitHub release" 129 | run: | 130 | gh release upload '${{ github.ref_name }}' dist/* --repo '${{ github.repository }}' 131 | env: 132 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 133 | 134 | # PyPI does not accept .sigstore artifacts and 135 | # gh-action-pypi-publish has no option to ignore them. 136 | - name: "Remove sigstore signatures before uploading to PyPI" 137 | run: | 138 | rm ./dist/*.sigstore.json 139 | 140 | - name: "Upload to PyPI" 141 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 142 | -------------------------------------------------------------------------------- /docs/PLANNING_WRAPPER.md: -------------------------------------------------------------------------------- 1 | # Planning Wrapper for ITS Algorithms 2 | 3 | ## Overview 4 | 5 | The `PlanningWrapper` is a generic enhancement that adds a planning phase to any inference-time scaling (ITS) algorithm. It allows the model to generate multiple solution approaches before execution, potentially improving performance through diverse strategy exploration. 6 | 7 | ## Key Features 8 | 9 | - **Universal Compatibility**: Works with any ITS algorithm (Self-Consistency, Best-of-N, Particle Filtering, Beam Search) 10 | - **Unified Interface**: Maintains the same `infer()` method signature across all enhanced algorithms 11 | - **Smart Budget Allocation**: Automatically divides computational budget across planned approaches 12 | - **Robust Plan Parsing**: Handles various plan formats with intelligent fallbacks 13 | 14 | ## Architecture 15 | 16 | ### Core Components 17 | 18 | 1. **PlanningWrapper**: Main class that wraps any base ITS algorithm 19 | 2. **PlanningPromptTemplate**: Generates prompts that encourage diverse approach planning 20 | 3. **PlanParser**: Extracts structured approaches from natural language plans 21 | 4. **ApproachPromptTemplate**: Creates approach-specific prompts for execution 22 | 23 | ### Process Flow 24 | 25 | 1. **Planning Phase**: Generate plan with 3 distinct approaches (costs 1 from budget) 26 | 2. **Approach Parsing**: Extract approaches using regex patterns with fallbacks 27 | 3. **Budget Allocation**: Divide remaining budget equally across approaches 28 | 4. **Execution**: Run base algorithm for each approach with approach-specific prompts 29 | 5. 
**Selection**: Choose best result based on algorithm-specific scoring 30 | 31 | ## Usage 32 | 33 | ### Manual Wrapping 34 | ```python 35 | from its_hub.algorithms.planning_wrapper import PlanningWrapper 36 | from its_hub.algorithms import SelfConsistency 37 | 38 | base_algorithm = SelfConsistency(extract_fn) 39 | planning_algorithm = PlanningWrapper(base_algorithm) 40 | 41 | result = planning_algorithm.infer(lm, prompt, budget=16, return_response_only=False) 42 | ``` 43 | 44 | ### Convenience Functions 45 | ```python 46 | from its_hub.algorithms.planning_wrapper import ( 47 | create_planning_self_consistency, 48 | create_planning_particle_filtering, 49 | create_planning_best_of_n, 50 | create_planning_beam_search 51 | ) 52 | 53 | # Enhanced algorithms 54 | planning_sc = create_planning_self_consistency(extract_fn) 55 | planning_pf = create_planning_particle_filtering(sg, prm) 56 | planning_bon = create_planning_best_of_n(orm) 57 | planning_bs = create_planning_beam_search(sg, prm, beam_width=4) 58 | 59 | # Same interface for all 60 | result = planning_sc.infer(lm, prompt, budget=16, return_response_only=False) 61 | ``` 62 | 63 | ### Result Object 64 | ```python 65 | # Planning-enhanced results include additional information 66 | result = planning_algorithm.infer(lm, prompt, budget=16, return_response_only=False) 67 | 68 | print(f"Best answer: {result.the_one}") 69 | print(f"Generated plan: {result.plan}") 70 | print(f"Approaches used: {result.approaches}") 71 | print(f"Best approach: {result.best_approach}") 72 | print(f"Budget allocation: {result.approach_budgets}") 73 | ``` 74 | 75 | ## Supported Algorithms 76 | 77 | - ✅ **Self-Consistency**: Enhanced with planning via `create_planning_self_consistency()` 78 | - ✅ **Best-of-N**: Enhanced with planning via `create_planning_best_of_n()` 79 | - ✅ **Particle Filtering**: Enhanced with planning via `create_planning_particle_filtering()` 80 | - ✅ **Beam Search**: Enhanced with planning via `create_planning_beam_search()` 81 | 82 | ## Testing 83 | 84 | Run the comprehensive test suite: 85 | ```bash 86 | python test_planning_wrapper.py 87 | ``` 88 | 89 | This test validates: 90 | - Planning-enhanced versions of all supported algorithms 91 | - Proper budget allocation across approaches 92 | - Result aggregation and best approach selection 93 | - Fallback handling for plan parsing failures 94 | 95 | ## Implementation Details 96 | 97 | ### Plan Generation 98 | The wrapper generates plans using a structured prompt that encourages the model to think of 3 distinct mathematical approaches: 99 | 100 | ``` 101 | APPROACH 1: [Brief description of first method/strategy] 102 | APPROACH 2: [Brief description of second method/strategy] 103 | APPROACH 3: [Brief description of third method/strategy] 104 | ``` 105 | 106 | ### Budget Allocation 107 | - Planning phase uses 1 generation from the total budget 108 | - Remaining budget is divided equally across parsed approaches 109 | - Any remainder is distributed to the first few approaches 110 | 111 | ### Approach Selection 112 | The wrapper selects the best approach based on algorithm-specific scoring: 113 | - Tries various score attributes (`best_score`, `confidence`, `scores`, etc.) 
114 | - Falls back to response length as a proxy for quality
115 | - Returns the approach with the highest score
116 | 
117 | ### Error Handling
118 | - Robust plan parsing with regex patterns and fallbacks
119 | - Generic fallback approaches if parsing fails completely
120 | - Graceful handling of missing score attributes
121 | 
122 | ## Performance Considerations
123 | 
124 | - **Overhead**: 1 additional generation for planning
125 | - **Benefits**: Potentially better results through diverse approaches
126 | - **Trade-offs**: Lower budgets may suffer from planning overhead, higher budgets benefit more
127 | 
128 | ## Future Enhancements
129 | 
130 | - Adaptive planning based on problem complexity
131 | - Dynamic budget allocation based on approach confidence
132 | - Cross-approach result fusion techniques
133 | - Problem-specific approach templates
-------------------------------------------------------------------------------- /its_hub/types.py: --------------------------------------------------------------------------------
1 | """Type definitions for its_hub."""
2 | 
3 | from __future__ import annotations
4 | import logging  # needed for the image-content warning in ChatMessage.extract_text_content
5 | from typing import Literal
6 | 
7 | from pydantic.dataclasses import dataclass
8 | 
9 | 
10 | @dataclass
11 | class Function:
12 |     """Function definition for tool calls."""
13 | 
14 |     name: str
15 |     description: str | None = None
16 |     parameters: dict | None = None
17 | 
18 | 
19 | @dataclass
20 | class ToolCall:
21 |     """A tool call made by the assistant."""
22 | 
23 |     id: str
24 |     type: Literal["function"] = "function"
25 |     function: Function | None = None
26 | 
27 | 
28 | @dataclass
29 | class ChatMessage:
30 |     """A chat message with role and content.
31 |     Content can be:
32 |     - str: Simple text content
33 |     - list[dict]: Multi-modal content (text, images, etc.)
34 |     - None: No content (e.g., when using tool_calls)
35 |     """
36 | 
37 |     role: Literal["system", "user", "assistant", "tool"]
38 |     content: str | list[dict] | None
39 |     tool_calls: list[dict] | None = None  # Store as plain dicts, not Pydantic objects
40 |     tool_call_id: str | None = None
41 | 
42 |     def extract_text_content(self) -> str:
43 |         """Extract text content from message, handling both string and list formats.
44 |         For list content (multi-modal), extracts all text parts and warns about non-text content.
45 |         Returns empty string if no text content is found.
46 |         """
47 |         if self.content is None:
48 |             return ""
49 | 
50 |         if isinstance(self.content, str):
51 |             return self.content
52 | 
53 |         # Must be list[dict] at this point
54 |         text_parts = []
55 |         has_image = False
56 | 
57 |         for item in self.content:
58 |             content_type = item.get("type", "")
59 | 
60 |             if content_type == "text":
61 |                 text_parts.append(item.get("text", ""))
62 |             elif content_type == "image_url":
63 |                 has_image = True
64 |             elif content_type:
65 |                 raise ValueError(
66 |                     f"Unsupported content type '{content_type}' in messages content dict."
67 |                 )
68 | 
69 |         if has_image:
70 |             logging.warning(
71 |                 "Image content detected in message but is not supported. "
72 |                 "Image content will be ignored. Only text content is processed."
73 | ) 74 | 75 | return " ".join(text_parts) 76 | 77 | def to_dict(self) -> dict: 78 | """Convert ChatMessage to dictionary, excluding None values.""" 79 | result = {"role": self.role} 80 | if self.content is not None: 81 | result["content"] = self.content 82 | if self.tool_calls is not None: 83 | result["tool_calls"] = self.tool_calls 84 | if self.tool_call_id is not None: 85 | result["tool_call_id"] = self.tool_call_id 86 | return result 87 | 88 | 89 | class ChatMessages: 90 | """Unified wrapper for handling both string prompts and conversation history.""" 91 | 92 | def __init__(self, str_or_messages: str | list[ChatMessage]): 93 | self._str_or_messages = str_or_messages 94 | self._is_string = isinstance(str_or_messages, str) 95 | 96 | @classmethod 97 | def from_prompt_or_messages( 98 | cls, prompt_or_messages: str | list[ChatMessage] | ChatMessages 99 | ) -> ChatMessages: 100 | """Create ChatMessages from various input formats.""" 101 | if isinstance(prompt_or_messages, ChatMessages): 102 | return prompt_or_messages 103 | return cls(prompt_or_messages) 104 | 105 | def to_prompt(self) -> str: 106 | # TODO: chatMessage to string conversion will be deprecated in the future. 107 | """Convert to prompt string representation.""" 108 | if self._is_string: 109 | return self._str_or_messages 110 | 111 | lines = [] 112 | for msg in self._str_or_messages: 113 | text_content = msg.extract_text_content() 114 | 115 | if msg.role == "tool": 116 | # Tool messages: include tool_call_id context 117 | lines.append(f"tool[{msg.tool_call_id}]: {text_content}") 118 | elif msg.role == "assistant" and msg.tool_calls: 119 | # Assistant with tool calls: show tool calls + content if any 120 | tool_call_strs = [] 121 | for tc in msg.tool_calls: 122 | if tc.function: 123 | tool_call_strs.append(f"{tc.function.name}()") 124 | tool_calls_text = ", ".join(tool_call_strs) 125 | if text_content: 126 | lines.append( 127 | f"assistant: {text_content} [calls: {tool_calls_text}]" 128 | ) 129 | else: 130 | lines.append(f"assistant: [calls: {tool_calls_text}]") 131 | else: 132 | # Regular messages 133 | lines.append(f"{msg.role}: {text_content}") 134 | 135 | return "\n".join(lines) 136 | 137 | def to_chat_messages(self) -> list[ChatMessage]: 138 | """Convert to list of ChatMessage objects.""" 139 | if self._is_string: 140 | return [ChatMessage(role="user", content=self._str_or_messages)] 141 | return self._str_or_messages 142 | 143 | def to_batch(self, size: int) -> list[list[ChatMessage]]: 144 | """Create a batch of identical chat message lists for parallel generation.""" 145 | chat_messages = self.to_chat_messages() 146 | return [chat_messages for _ in range(size)] 147 | 148 | @property 149 | def is_string(self) -> bool: 150 | """Check if the original input was a string.""" 151 | return self._is_string 152 | -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 4 | 5 | ## Development Commands 6 | 7 | ### Installation and Setup 8 | ```bash 9 | # Development installation with uv (recommended) 10 | uv sync --extra dev 11 | 12 | # Alternative: pip installation 13 | pip install -e ".[dev]" 14 | 15 | # Production installation 16 | pip install its_hub 17 | ``` 18 | 19 | ### Contribution 20 | When commit or raising PR, never mention it is by ClaudeCode. 
21 | Never include "🤖 Generated with [Claude Code](https://claude.ai/code)" in the commit message, and do not mention Claude.
22 | 
23 | ### Testing
24 | ```bash
25 | # Run all tests
26 | uv run pytest tests/
27 | 
28 | # Run specific test file
29 | uv run pytest tests/test_algorithms.py
30 | 
31 | # Run tests with coverage
32 | uv run pytest tests/ --cov=its_hub
33 | 
34 | # Run tests with verbose output
35 | uv run pytest tests/ -v
36 | ```
37 | 
38 | ### Code Quality
39 | ```bash
40 | # Run linter checks
41 | uv run ruff check its_hub/
42 | 
43 | # Fix auto-fixable linting issues
44 | uv run ruff check its_hub/ --fix
45 | 
46 | # Format code with ruff
47 | uv run ruff format its_hub/
48 | ```
49 | 
50 | ### Git Workflow
51 | ```bash
52 | # Create commits with sign-off
53 | git commit -s -m "commit message"
54 | 
55 | # For any git commits, always use the sign-off flag (-s)
56 | ```
57 | 
58 | ### Running Examples
59 | ```bash
60 | # Test basic functionality
61 | python scripts/test_math_example.py
62 | 
63 | # Benchmark algorithms (see script help for full options)
64 | python scripts/benchmark.py --help
65 | ```
66 | 
67 | ### IaaS Service (Inference-as-a-Service)
68 | ```bash
69 | # Start IaaS service
70 | uv run its-iaas --host 0.0.0.0 --port 8108
71 | 
72 | # Or using justfile (if available)
73 | just iaas-start
74 | 
75 | # Check service health
76 | curl -s http://localhost:8108/v1/models | jq .
77 | 
78 | # Configure the service (example: self-consistency algorithm)
79 | curl -X POST http://localhost:8108/configure \
80 |   -H "Content-Type: application/json" \
81 |   -d '{"endpoint": "http://localhost:8100/v1", "api_key": "NO_API_KEY", "model": "your-model-name", "alg": "self-consistency"}'
82 | 
83 | # For comprehensive IaaS setup (multi-GPU, reward models, etc.), see docs/iaas-service.md
84 | ```
85 | 
86 | ## Additional Tips
87 | - Use `rg` in favor of `grep` whenever it's available
88 | - Use `uv` for Python environment management: always start with `uv sync --extra dev` to initialize the environment, and run commands with `uv run`
89 | - In case of dependency issues during testing, try commenting out `reward_hub` and `vllm` temporarily in @pyproject.toml and retry.
90 | 
91 | ## Architecture Overview
92 | 
93 | **its_hub** is a library for inference-time scaling of LLMs, focusing on mathematical reasoning tasks. The core architecture uses abstract base classes to define clean interfaces between components.
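As a concrete orientation before the class-by-class summary below, here is a minimal sketch of what implementing the language-model interface looks like. It is modeled on the test mocks in `tests/mocks/language_models.py`; the class name and the echo behavior are invented for illustration and this is not a production implementation.

```python
from its_hub.base import AbstractLanguageModel


class EchoLanguageModel(AbstractLanguageModel):
    """Illustration only: returns a canned reply in the expected message format."""

    async def agenerate(self, messages, **kwargs):
        # Async path delegates to the sync implementation in this sketch.
        return self.generate(messages, **kwargs)

    def generate(self, messages, **kwargs):
        # Batched input is a list of message lists -> one response dict per item.
        if isinstance(messages, list) and messages and isinstance(messages[0], list):
            return [{"role": "assistant", "content": "ok"} for _ in messages]
        # Single message list -> a single response dict.
        return {"role": "assistant", "content": "ok"}

    async def aevaluate(self, prompt: str, generation: str) -> list[float]:
        return self.evaluate(prompt, generation)

    def evaluate(self, prompt: str, generation: str) -> list[float]:
        # Per-token scores; a real model would return meaningful values here.
        return [0.0] * len(generation.split())
```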
94 | 95 | ### Key Base Classes (`its_hub/base.py`) 96 | - `AbstractLanguageModel`: Interface for LM generation and evaluation 97 | - `AbstractScalingAlgorithm`: Base for all scaling algorithms with unified `infer()` method 98 | - `AbstractScalingResult`: Base for algorithm results with `the_one` property 99 | - `AbstractOutcomeRewardModel`: Interface for outcome-based reward models 100 | - `AbstractProcessRewardModel`: Interface for process-based reward models (step-by-step scoring) 101 | 102 | ### Main Components 103 | 104 | #### Language Models (`its_hub/lms.py`) 105 | - `OpenAICompatibleLanguageModel`: Primary LM implementation supporting vLLM and OpenAI APIs 106 | - `StepGeneration`: Handles incremental generation with configurable step tokens and stop conditions 107 | - Supports async generation with concurrency limits and backoff strategies 108 | 109 | #### Algorithms (`its_hub/algorithms/`) 110 | All algorithms follow the same interface: `infer(lm, prompt, budget, return_response_only=True)` 111 | 112 | - **Self-Consistency**: Generate multiple responses, select most common answer 113 | - **Best-of-N**: Generate N responses, select highest scoring via outcome reward model 114 | - **Beam Search**: Step-by-step generation with beam width, uses process reward models 115 | - **Particle Filtering/Gibbs**: Probabilistic resampling with process reward models 116 | 117 | #### Integration (`its_hub/integration/`) 118 | - `LocalVllmProcessRewardModel`: Integrates with reward_hub library for process-based scoring 119 | - `iaas.py`: Inference-as-a-Service FastAPI server providing OpenAI-compatible chat completions API with budget parameter for inference-time scaling 120 | 121 | ### Budget Interpretation 122 | The budget parameter controls computational resources allocated to each algorithm. Different algorithms interpret budget as follows: 123 | - **Self-Consistency/Best-of-N**: Number of parallel generations to create 124 | - **Beam Search**: Total generations divided by beam width (controls search depth) 125 | - **Particle Filtering**: Number of particles maintained during sampling 126 | 127 | ### Step Generation Pattern 128 | The `StepGeneration` class enables incremental text generation: 129 | - Configure step tokens (e.g., "\n\n" for reasoning steps) 130 | - Set max steps and stop conditions 131 | - Post-processing for clean output formatting 132 | 133 | ### Typical Workflow 134 | 1. Start vLLM server with instruction model 135 | 2. Initialize `OpenAICompatibleLanguageModel` pointing to server 136 | 3. Create `StepGeneration` with step/stop tokens appropriate for the task 137 | 4. Initialize reward model (e.g., `LocalVllmProcessRewardModel`) 138 | 5. Create scaling algorithm with step generation and reward model 139 | 6. Call `infer()` with prompt and budget 140 | 141 | ### Mathematical Focus 142 | The library is optimized for mathematical reasoning: 143 | - Predefined system prompts in `its_hub/utils.py` (SAL_STEP_BY_STEP_SYSTEM_PROMPT, QWEN_SYSTEM_PROMPT) 144 | - Regex patterns for mathematical notation (e.g., `r"\boxed"` for final answers) 145 | - Integration with math_verify for evaluation 146 | - Benchmarking on MATH500 and AIME-2024 datasets 147 | 148 | ## Inference-as-a-Service (IaaS) 149 | 150 | The its_hub library includes an IaaS service that provides OpenAI-compatible API with inference-time scaling capabilities. For comprehensive setup instructions, usage examples, and troubleshooting, see [docs/iaas-service.md](./docs/iaas-service.md). 
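As a compact sketch of the typical workflow described above (mirroring the Particle Filtering quick-start example; the endpoint, model names, and budget are illustrative and assume the `[prm]` extra plus a running vLLM server):

```python
# Sketch of the typical workflow; assumes `pip install its_hub[prm]` and a vLLM
# server already serving Qwen/Qwen2.5-Math-1.5B-Instruct on localhost:8100.
from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT
from its_hub.lms import OpenAICompatibleLanguageModel, StepGeneration
from its_hub.algorithms import ParticleFiltering
from its_hub.integration.reward_hub import LocalVllmProcessRewardModel

lm = OpenAICompatibleLanguageModel(
    endpoint="http://localhost:8100/v1",
    api_key="NO_API_KEY",
    model_name="Qwen/Qwen2.5-Math-1.5B-Instruct",
    system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,
)
# Step boundaries and stop condition for incremental generation
sg = StepGeneration(step_token="\n\n", max_steps=32, stop_token=r"\boxed")
prm = LocalVllmProcessRewardModel(
    model_name="Qwen/Qwen2.5-Math-PRM-7B", device="cuda:0", aggregation_method="prod"
)
# Budget = number of particles maintained during sampling
result = ParticleFiltering(sg, prm).infer(lm, "Solve x^2 + 5x + 6 = 0", budget=8)
print(result)
```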
-------------------------------------------------------------------------------- /its_hub/algorithms/beam_search.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | from pydantic.dataclasses import dataclass 5 | 6 | from its_hub.base import ( 7 | AbstractLanguageModel, 8 | AbstractProcessRewardModel, 9 | AbstractScalingAlgorithm, 10 | AbstractScalingResult, 11 | ) 12 | from its_hub.lms import StepGeneration 13 | from its_hub.types import ChatMessage, ChatMessages 14 | 15 | 16 | @dataclass 17 | class BeamSearchResult(AbstractScalingResult): 18 | responses: list[dict] # Keep original message format with tool calls 19 | scores: list[float] 20 | selected_index: int 21 | steps_used: list[int] 22 | 23 | @property 24 | def the_one(self) -> dict: 25 | return self.responses[self.selected_index] 26 | 27 | 28 | @dataclass 29 | class Path: 30 | steps: list[str] 31 | is_stopped: bool 32 | score: float 33 | 34 | def deepcopy(self): 35 | # create a deep copy of the path object 36 | return Path( 37 | steps=copy.deepcopy(self.steps), 38 | is_stopped=self.is_stopped, 39 | score=self.score, 40 | ) 41 | 42 | 43 | class BeamSearch(AbstractScalingAlgorithm): 44 | def __init__( 45 | self, 46 | sg: StepGeneration, 47 | prm: AbstractProcessRewardModel, 48 | beam_width: int, 49 | ): 50 | self.sg = sg 51 | self.prm = prm 52 | self.beam_width = beam_width 53 | 54 | async def _asearch_one_level( 55 | self, 56 | lm: AbstractLanguageModel, 57 | candidates: list[Path], 58 | prompt: str, 59 | tools: list[dict] | None = None, 60 | tool_choice: str | dict | None = None, 61 | ) -> list[Path]: 62 | """search one level asynchronously""" 63 | is_stopped_in_the_beginning = [c.is_stopped for c in candidates] 64 | 65 | # collect batch inputs 66 | prompts, steps_so_far = [], [] 67 | for c, is_stopped in zip(candidates, is_stopped_in_the_beginning): 68 | if is_stopped: 69 | continue 70 | prompts.append(prompt) 71 | steps_so_far.append(c.steps) 72 | 73 | # collect batch outputs 74 | sg_forward_results = await self.sg.aforward( 75 | lm, prompts, steps_so_far, tools=tools, tool_choice=tool_choice 76 | ) 77 | 78 | # update candidates 79 | i = 0 80 | for c, is_stopped in zip(candidates, is_stopped_in_the_beginning): 81 | if is_stopped: 82 | continue 83 | next_step, is_stopped = sg_forward_results[i] 84 | c.steps.append(next_step) 85 | c.is_stopped = is_stopped 86 | i += 1 87 | 88 | # collect batch inputs for scoring 89 | steps_so_far = [] 90 | for c, is_stopped in zip(candidates, is_stopped_in_the_beginning): 91 | if is_stopped: 92 | continue 93 | steps_so_far.append(c.steps) 94 | 95 | # collect batch outputs for scoring 96 | scores = await self.prm.ascore( 97 | prompt, 98 | [ 99 | self.sg._post_process(steps_so_far_per_prompt, stopped=True) 100 | for steps_so_far_per_prompt in steps_so_far 101 | ], 102 | ) 103 | 104 | # update candidates 105 | i = 0 106 | for c, is_stopped in zip(candidates, is_stopped_in_the_beginning): 107 | if is_stopped: 108 | continue 109 | c.score = scores[i] 110 | i += 1 111 | 112 | return candidates 113 | 114 | def _search_one_level( 115 | self, 116 | lm: AbstractLanguageModel, 117 | candidates: list[Path], 118 | prompt: str, 119 | tools: list[dict] | None = None, 120 | tool_choice: str | dict | None = None, 121 | ) -> list[Path]: 122 | """search one level synchronously""" 123 | import asyncio 124 | 125 | return asyncio.run( 126 | self._asearch_one_level(lm, candidates, prompt, tools, tool_choice) 127 | ) 128 | 129 | async def 
ainfer( 130 | self, 131 | lm: AbstractLanguageModel, 132 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 133 | budget: int, 134 | return_response_only: bool = True, 135 | tools: list[dict] | None = None, 136 | tool_choice: str | dict | None = None, 137 | ) -> dict | BeamSearchResult: 138 | """run inference asynchronously with beam search""" 139 | chat_messages = ChatMessages.from_prompt_or_messages(prompt_or_messages) 140 | assert budget % self.beam_width == 0, "budget must be divisible by beam_width" 141 | assert budget >= self.beam_width, ( 142 | "budget must be greater than or equal to beam_width" 143 | ) 144 | 145 | num_beams = budget // self.beam_width 146 | 147 | candidates = [ 148 | Path(steps=[], is_stopped=False, score=0) for _ in range(num_beams) 149 | ] 150 | 151 | while not all(c.is_stopped for c in candidates): 152 | # TODO: Update _asearch_one_level to support native ChatMessages format instead of string conversion 153 | candidates = await self._asearch_one_level( 154 | lm, 155 | candidates, 156 | chat_messages.to_prompt(), 157 | tools=tools, 158 | tool_choice=tool_choice, 159 | ) 160 | 161 | # get the top beam_width candidates 162 | candidates.sort(key=lambda x: x.score, reverse=True) 163 | candidates = candidates[: self.beam_width] 164 | 165 | # duplicate the candidates with the highest score 166 | new_candidates = [] 167 | for _ in range(num_beams): 168 | for c in candidates: 169 | new_candidates.append(c.deepcopy()) 170 | candidates = new_candidates 171 | 172 | scores = [c.score for c in candidates] 173 | steps_used = [len(c.steps) for c in candidates] 174 | result = BeamSearchResult( 175 | responses=[ 176 | { 177 | "role": "assistant", 178 | "content": self.sg._post_process(c.steps, stopped=True), 179 | } 180 | for c in candidates 181 | ], 182 | scores=scores, 183 | selected_index=int(np.argmax(scores)), 184 | steps_used=steps_used, 185 | ) 186 | return result.the_one if return_response_only else result 187 | -------------------------------------------------------------------------------- /its_hub/error_handling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Error handling utilities for OpenAI-compatible API calls. 3 | 4 | This module provides specific exception classes and utilities to handle 5 | different types of API failures with appropriate retry logic and informative 6 | error messages. 
7 | """ 8 | 9 | import json 10 | from typing import Any 11 | 12 | 13 | class APIError(Exception): 14 | """Base class for API-related errors.""" 15 | 16 | def __init__( 17 | self, 18 | message: str, 19 | status_code: int | None = None, 20 | error_details: dict[str, Any] | None = None, 21 | ): 22 | self.message = message 23 | self.status_code = status_code 24 | self.error_details = error_details or {} 25 | super().__init__(message) 26 | 27 | 28 | class RateLimitError(APIError): 29 | """Rate limit exceeded - retryable.""" 30 | 31 | pass 32 | 33 | 34 | class ContextLengthError(APIError): 35 | """Context length exceeded - not retryable.""" 36 | 37 | pass 38 | 39 | 40 | class AuthenticationError(APIError): 41 | """Authentication failed - not retryable.""" 42 | 43 | pass 44 | 45 | 46 | class APIConnectionError(APIError): 47 | """Network/connection issues - retryable.""" 48 | 49 | pass 50 | 51 | 52 | class BadRequestError(APIError): 53 | """Bad request (invalid parameters) - not retryable.""" 54 | 55 | pass 56 | 57 | 58 | class InternalServerError(APIError): 59 | """Server error - retryable.""" 60 | 61 | pass 62 | 63 | 64 | # Retryable error types 65 | RETRYABLE_ERRORS = (RateLimitError, APIConnectionError, InternalServerError) 66 | 67 | 68 | def parse_api_error(status_code: int, error_text: str) -> APIError: 69 | """ 70 | Parse API error response and return appropriate exception. 71 | 72 | Args: 73 | status_code: HTTP status code 74 | error_text: Raw error response text 75 | 76 | Returns: 77 | Appropriate APIError subclass instance 78 | """ 79 | error_details = {} 80 | error_message = error_text 81 | 82 | # Try to parse JSON error response 83 | try: 84 | error_json = json.loads(error_text) 85 | if "error" in error_json: 86 | error_info = error_json["error"] 87 | if isinstance(error_info, dict): 88 | error_message = error_info.get("message", error_text) 89 | error_details = error_info 90 | else: 91 | error_message = str(error_info) 92 | except (json.JSONDecodeError, KeyError): 93 | # Use raw text if JSON parsing fails 94 | pass 95 | 96 | # Classify error based on status code and message content 97 | error_message_lower = error_message.lower() 98 | 99 | if status_code == 429 or "rate limit" in error_message_lower: 100 | return RateLimitError( 101 | f"Rate limit exceeded: {error_message}", 102 | status_code=status_code, 103 | error_details=error_details, 104 | ) 105 | 106 | if status_code == 400 and any( 107 | phrase in error_message_lower 108 | for phrase in [ 109 | "context length", 110 | "maximum context", 111 | "too long", 112 | "token limit", 113 | "context_length_exceeded", 114 | "max_tokens", 115 | ] 116 | ): 117 | return ContextLengthError( 118 | f"Context length exceeded: {error_message}", 119 | status_code=status_code, 120 | error_details=error_details, 121 | ) 122 | 123 | if ( 124 | status_code in [401, 403] 125 | or "authentication" in error_message_lower 126 | or "unauthorized" in error_message_lower 127 | ): 128 | return AuthenticationError( 129 | f"Authentication failed: {error_message}", 130 | status_code=status_code, 131 | error_details=error_details, 132 | ) 133 | 134 | if status_code == 400: 135 | return BadRequestError( 136 | f"Bad request: {error_message}", 137 | status_code=status_code, 138 | error_details=error_details, 139 | ) 140 | 141 | if status_code >= 500: 142 | return InternalServerError( 143 | f"Server error: {error_message}", 144 | status_code=status_code, 145 | error_details=error_details, 146 | ) 147 | 148 | # Default to APIConnectionError for other cases (network 
issues, etc.) 149 | return APIConnectionError( 150 | f"API request failed: {error_message}", 151 | status_code=status_code, 152 | error_details=error_details, 153 | ) 154 | 155 | 156 | def enhanced_on_backoff(details): 157 | """ 158 | Enhanced backoff callback that shows specific error information. 159 | 160 | Args: 161 | details: Backoff details dictionary containing exception info 162 | """ 163 | exception = details.get("exception") 164 | if isinstance(exception, APIError): 165 | error_type = type(exception).__name__ 166 | print( 167 | f"Retrying after {details['wait']:.1f}s (attempt {details['tries']}) - " 168 | f"{error_type}: {exception.message}" 169 | ) 170 | else: 171 | # Fallback for non-APIError exceptions 172 | print( 173 | f"Retrying after {details['wait']:.1f}s (attempt {details['tries']}) - " 174 | f"Error: {exception!s}" 175 | ) 176 | 177 | 178 | def should_retry(exception: Exception) -> bool: 179 | """ 180 | Determine if an exception should be retried. 181 | 182 | Args: 183 | exception: The exception to check 184 | 185 | Returns: 186 | True if the exception should be retried, False otherwise 187 | """ 188 | return isinstance(exception, RETRYABLE_ERRORS) 189 | 190 | 191 | def format_non_retryable_error(exception: APIError) -> str: 192 | """ 193 | Format a helpful message for non-retryable errors. 194 | 195 | Args: 196 | exception: The non-retryable APIError 197 | 198 | Returns: 199 | Formatted error message with suggestions 200 | """ 201 | if isinstance(exception, ContextLengthError): 202 | return ( 203 | f"❌ {exception.message}\n" 204 | f"💡 Suggestion: Reduce input length, increase max_tokens, or use a model with larger context window" 205 | ) 206 | 207 | if isinstance(exception, AuthenticationError): 208 | return ( 209 | f"❌ {exception.message}\n" 210 | f"💡 Suggestion: Check your API key and endpoint configuration" 211 | ) 212 | 213 | if isinstance(exception, BadRequestError): 214 | return ( 215 | f"❌ {exception.message}\n" 216 | f"💡 Suggestion: Check your request parameters (temperature, max_tokens, etc.)" 217 | ) 218 | 219 | return f"❌ {exception.message}" 220 | -------------------------------------------------------------------------------- /docs/quick-start.md: -------------------------------------------------------------------------------- 1 | # Quick Start Guide 2 | 3 | This guide shows examples of inference-time scaling. **Tool calling is the primary use case** for production applications. 4 | 5 | ## Example 1: Self-Consistency with Tool Calling (Recommended) 6 | 7 | **Installation required:** `pip install its_hub` 8 | 9 | This example shows how to use Self-Consistency for reliable tool calling in agent applications. 
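Conceptually, hierarchical tool voting first takes a majority vote over which tool to call and then over that tool's arguments; the toy sketch below illustrates just this idea (it is not the library's internal implementation). The full its_hub example follows.

```python
# Toy illustration of hierarchical tool voting (not its_hub's internal code).
import json
from collections import Counter


def vote_tool_call(calls: list[dict]) -> dict:
    # 1) Majority vote on which tool to call.
    name_votes = Counter(c["name"] for c in calls)
    top_name = name_votes.most_common(1)[0][0]
    # 2) Among calls to that tool, majority vote on the arguments.
    arg_votes = Counter(
        json.dumps(c["arguments"], sort_keys=True)
        for c in calls
        if c["name"] == top_name
    )
    top_args = json.loads(arg_votes.most_common(1)[0][0])
    return {"name": top_name, "arguments": top_args}


candidates = [
    {"name": "calculator", "arguments": {"expression": "847 * 293 + 156"}},
    {"name": "calculator", "arguments": {"expression": "847 * 293 + 156"}},
    {"name": "calculator", "arguments": {"expression": "847 * 293"}},
]
print(vote_tool_call(candidates))  # -> calculator with the majority expression
```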
10 | 11 | ```python 12 | from its_hub.lms import OpenAICompatibleLanguageModel 13 | from its_hub.algorithms import SelfConsistency 14 | from its_hub.types import ChatMessage, ChatMessages 15 | 16 | # Initialize language model 17 | lm = OpenAICompatibleLanguageModel( 18 | endpoint="https://api.openai.com/v1", 19 | api_key="your-api-key", 20 | model_name="gpt-4o-mini" 21 | ) 22 | 23 | # Define tools (OpenAI format) 24 | tools = [ 25 | { 26 | "type": "function", 27 | "function": { 28 | "name": "calculator", 29 | "description": "Perform arithmetic calculations", 30 | "parameters": { 31 | "type": "object", 32 | "properties": { 33 | "expression": { 34 | "type": "string", 35 | "description": "Mathematical expression to evaluate" 36 | } 37 | }, 38 | "required": ["expression"] 39 | } 40 | } 41 | } 42 | ] 43 | 44 | # Create messages 45 | messages = ChatMessages([ 46 | ChatMessage( 47 | role="system", 48 | content="You are a precise calculator. Always use the calculator tool for arithmetic." 49 | ), 50 | ChatMessage( 51 | role="user", 52 | content="What is 847 * 293 + 156?" 53 | ) 54 | ]) 55 | 56 | # Use hierarchical tool voting 57 | sc = SelfConsistency(tool_vote="tool_hierarchical") 58 | result = sc.infer( 59 | lm, 60 | messages, 61 | budget=5, 62 | tools=tools, 63 | tool_choice="auto" 64 | ) 65 | print(result) 66 | ``` 67 | 68 | **What happens:** 69 | 1. Generates 5 different responses with tool calls 70 | 2. Votes on tool names first (which tool to use) 71 | 3. Then votes on tool arguments (what parameters to pass) 72 | 4. Returns the most consistent tool call 73 | 74 | --- 75 | 76 | ## Example 2: Best-of-N with LLM Judge (Core Installation) 77 | 78 | **Installation required:** `pip install its_hub` 79 | 80 | This example uses Best-of-N algorithm with an LLM judge for response selection. Works with any OpenAI-compatible API and requires no GPU. 81 | 82 | ```python 83 | from its_hub.lms import OpenAICompatibleLanguageModel 84 | from its_hub.algorithms import BestOfN 85 | from its_hub.integration.reward_hub import LLMJudgeRewardModel 86 | 87 | # Initialize language model 88 | lm = OpenAICompatibleLanguageModel( 89 | endpoint="https://api.openai.com/v1", 90 | api_key="your-api-key", 91 | model_name="gpt-4o-mini", 92 | ) 93 | 94 | # Set up LLM judge for scoring 95 | judge = LLMJudgeRewardModel( 96 | model="gpt-4o-mini", 97 | criterion="overall_quality", 98 | judge_type="groupwise", 99 | api_key="your-api-key", 100 | ) 101 | scaling_alg = BestOfN(judge) 102 | 103 | # Generate multiple responses and select the best 104 | result = scaling_alg.infer( 105 | lm, 106 | "Explain quantum entanglement in simple terms", 107 | budget=4 108 | ) 109 | print(result) 110 | ``` 111 | 112 | **What happens:** 113 | 1. Generates 4 different responses to the prompt 114 | 2. LLM judge scores all responses 115 | 3. Returns the highest-scoring response 116 | 117 | --- 118 | 119 | ## Example 3: Particle Filtering with Process Reward Model 120 | 121 | **Installation required:** `pip install its_hub[prm]` 122 | 123 | This example uses Particle Filtering for step-by-step mathematical reasoning with a local process reward model. Requires GPU. 
124 | 125 | ### Prerequisites 126 | 127 | - GPU with CUDA 11.8+ 128 | - 20GB+ GPU memory recommended (for 7B reward model) 129 | 130 | ### Step 1: Start vLLM Server 131 | 132 | Start a vLLM server with your instruction model: 133 | 134 | ```bash 135 | CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-Math-1.5B-Instruct \ 136 | --dtype float16 \ 137 | --port 8100 \ 138 | --max-model-len 4096 \ 139 | --gpu-memory-utilization 0.7 140 | ``` 141 | 142 | ### Step 2: Run Particle Filtering 143 | 144 | ```python 145 | from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT 146 | from its_hub.lms import OpenAICompatibleLanguageModel, StepGeneration 147 | from its_hub.algorithms import ParticleFiltering 148 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel 149 | 150 | # Initialize language model (points to vLLM server) 151 | lm = OpenAICompatibleLanguageModel( 152 | endpoint="http://localhost:8100/v1", 153 | api_key="NO_API_KEY", 154 | model_name="Qwen/Qwen2.5-Math-1.5B-Instruct", 155 | system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT, 156 | ) 157 | 158 | # Set up step generation and process reward model 159 | sg = StepGeneration(step_token="\n\n", max_steps=32, stop_token=r"\boxed") 160 | prm = LocalVllmProcessRewardModel( 161 | model_name="Qwen/Qwen2.5-Math-PRM-7B", 162 | device="cuda:0", 163 | aggregation_method="prod" 164 | ) 165 | scaling_alg = ParticleFiltering(sg, prm) 166 | 167 | # Solve with step-by-step reasoning 168 | result = scaling_alg.infer(lm, "Solve x^2 + 5x + 6 = 0", budget=8) 169 | print(result) 170 | ``` 171 | 172 | **What happens:** 173 | 1. Generates reasoning steps incrementally (separated by `\n\n`) 174 | 2. Process reward model scores each step 175 | 3. Maintains 8 particles, resampling based on scores 176 | 4. Returns best solution when `\boxed` pattern is found 177 | 178 | ### Memory Requirements 179 | 180 | - **Instruction Model** (Qwen2.5-Math-1.5B): ~3GB GPU memory 181 | - **Reward Model** (Qwen2.5-Math-PRM-7B): ~14GB GPU memory 182 | - **Total**: 20GB+ recommended (H100, A100, or similar) 183 | 184 | --- 185 | 186 | ## Troubleshooting 187 | 188 | ### Examples 1 & 2 (Cloud APIs) 189 | 190 | **API errors**: Verify API key and endpoint are correct 191 | **Slow responses**: Reduce `budget` parameter (e.g., from 5 to 2) 192 | 193 | ### Example 3 (Particle Filtering) 194 | 195 | **CUDA Out of Memory**: 196 | - Reduce `--gpu-memory-utilization` to 0.6 or lower 197 | - Reduce `--max-num-seqs` to 64 198 | - Ensure no other processes are using GPU 199 | - Check memory: `nvidia-smi` 200 | 201 | **Server Connection Issues**: 202 | ```bash 203 | # Verify vLLM server is running 204 | curl http://localhost:8100/v1/models 205 | ``` 206 | 207 | **Model Loading Issues**: 208 | - Ensure sufficient disk space (~20GB for both models) 209 | - Check internet connection for model downloads 210 | - Verify model names are correct 211 | 212 | --- 213 | 214 | ## Next Steps 215 | 216 | - **Explore algorithms**: See [Algorithms](algorithms.md) for Beam Search, Self-Consistency, and other approaches 217 | - **Deploy as API**: See [IaaS Service Guide](iaas-service.md) to deploy as OpenAI-compatible service 218 | - **Contribute**: See [Development](development.md) for contribution guidelines -------------------------------------------------------------------------------- /docs/benchmarking.md: -------------------------------------------------------------------------------- 1 | # Benchmarking 2 | 3 | its-hub includes comprehensive benchmarking tools to evaluate inference-time scaling 
algorithms on standard mathematical reasoning datasets. 4 | 5 | ## Quick Start 6 | 7 | ```bash 8 | python scripts/benchmark.py --help 9 | ``` 10 | 11 | Example benchmark command: 12 | ```bash 13 | python scripts/benchmark.py \ 14 | --benchmark aime-2024 \ 15 | --model_name Qwen/Qwen2.5-Math-1.5B-Instruct \ 16 | --alg particle-filtering \ 17 | --rm_device cuda:1 \ 18 | --endpoint http://0.0.0.0:8000/v1 \ 19 | --shuffle_seed 1110 \ 20 | --does_eval \ 21 | --budgets 1,2,4,8,16,32,64 \ 22 | --rm_agg_method model 23 | ``` 24 | 25 | ## Supported Datasets 26 | 27 | ### MATH500 28 | A subset of 500 problems from the MATH dataset, covering various mathematical topics. 29 | 30 | ```bash 31 | python scripts/benchmark.py --benchmark math-500 --model_name Qwen/Qwen2.5-Math-1.5B-Instruct 32 | ``` 33 | 34 | ### AIME-2024 35 | American Invitational Mathematics Examination problems from 2024. 36 | 37 | ```bash 38 | python scripts/benchmark.py --benchmark aime-2024 --model_name Qwen/Qwen2.5-Math-1.5B-Instruct 39 | ``` 40 | 41 | ## Algorithm Comparison 42 | 43 | ### Benchmarking Multiple Algorithms 44 | 45 | ```bash 46 | # Compare all algorithms on MATH500 47 | for alg in self-consistency best-of-n beam-search particle-filtering; do 48 | python scripts/benchmark.py \ 49 | --benchmark math-500 \ 50 | --model_name Qwen/Qwen2.5-Math-1.5B-Instruct \ 51 | --alg $alg \ 52 | --budgets 1,2,4,8,16 \ 53 | --does_eval 54 | done 55 | ``` 56 | 57 | ### Budget Scaling Analysis 58 | 59 | ```bash 60 | # Analyze performance vs computational budget 61 | python scripts/benchmark.py \ 62 | --benchmark math-500 \ 63 | --model_name Qwen/Qwen2.5-Math-1.5B-Instruct \ 64 | --alg particle-filtering \ 65 | --budgets 1,2,4,8,16,32,64,128 \ 66 | --does_eval 67 | ``` 68 | 69 | ## Configuration Options 70 | 71 | ### Basic Parameters 72 | 73 | - `--benchmark`: Dataset to use (`math-500`, `aime-2024`) 74 | - `--model_name`: Model identifier (e.g., `Qwen/Qwen2.5-Math-1.5B-Instruct`) 75 | - `--alg`: Algorithm to benchmark (`self-consistency`, `best-of-n`, `beam-search`, `particle-filtering`) 76 | - `--budgets`: Comma-separated list of budget values 77 | - `--endpoint`: API endpoint for model inference 78 | - `--does_eval`: Enable automatic evaluation of results 79 | 80 | ### Advanced Parameters 81 | 82 | - `--shuffle_seed`: Seed for shuffling problems (reproducibility) 83 | - `--rm_device`: GPU device for reward model (e.g., `cuda:0`) 84 | - `--rm_agg_method`: Reward aggregation method (`prod`, `mean`, `model`) 85 | - `--beam_width`: Beam width for beam search (default: 4) 86 | - `--max_steps`: Maximum steps for step-by-step algorithms 87 | - `--step_token`: Token for step boundaries (default: `\\n\\n`) 88 | - `--stop_pattern`: Regex pattern for stopping generation 89 | 90 | ## Output Format 91 | 92 | ### Results Structure 93 | 94 | The benchmark script generates detailed results including: 95 | 96 | ```json 97 | { 98 | "algorithm": "particle-filtering", 99 | "dataset": "math-500", 100 | "model": "Qwen/Qwen2.5-Math-1.5B-Instruct", 101 | "budget": 8, 102 | "accuracy": 0.756, 103 | "total_problems": 500, 104 | "correct_answers": 378, 105 | "average_response_time": 12.34, 106 | "detailed_results": [ 107 | { 108 | "problem_id": "001", 109 | "problem": "Solve x^2 + 5x + 6 = 0", 110 | "correct_answer": "x = -2, -3", 111 | "model_response": "...", 112 | "is_correct": true, 113 | "response_time": 8.21 114 | } 115 | ] 116 | } 117 | ``` 118 | 119 | ### Evaluation Metrics 120 | 121 | - **Accuracy**: Percentage of correctly solved problems 122 | - 
**Response Time**: Average time per problem (seconds) 123 | - **Budget Efficiency**: Accuracy improvement per unit budget 124 | - **Error Analysis**: Breakdown of error types and frequencies 125 | 126 | ## Performance Analysis 127 | 128 | ### Plotting Results 129 | 130 | ```python 131 | import matplotlib.pyplot as plt 132 | import json 133 | 134 | # Load benchmark results 135 | with open('benchmark_results.json', 'r') as f: 136 | results = json.load(f) 137 | 138 | # Plot accuracy vs budget 139 | budgets = [r['budget'] for r in results] 140 | accuracies = [r['accuracy'] for r in results] 141 | 142 | plt.figure(figsize=(10, 6)) 143 | plt.plot(budgets, accuracies, 'o-') 144 | plt.xlabel('Budget') 145 | plt.ylabel('Accuracy') 146 | plt.title('Accuracy vs Computational Budget') 147 | plt.grid(True) 148 | plt.show() 149 | ``` 150 | 151 | ### Statistical Analysis 152 | 153 | ```python 154 | import numpy as np 155 | from scipy import stats 156 | 157 | # Compare two algorithms 158 | results_a = [r for r in results if r['algorithm'] == 'self-consistency'] 159 | results_b = [r for r in results if r['algorithm'] == 'particle-filtering'] 160 | 161 | accuracies_a = [r['accuracy'] for r in results_a] 162 | accuracies_b = [r['accuracy'] for r in results_b] 163 | 164 | # Perform t-test 165 | t_stat, p_value = stats.ttest_ind(accuracies_a, accuracies_b) 166 | print(f"T-test p-value: {p_value}") 167 | ``` 168 | 169 | ## Custom Benchmarks 170 | 171 | ### Adding New Datasets 172 | 173 | ```python 174 | # Create custom dataset 175 | custom_problems = [ 176 | { 177 | "id": "custom_001", 178 | "problem": "Your math problem here", 179 | "answer": "Expected answer", 180 | "category": "algebra" 181 | } 182 | ] 183 | 184 | # Save as JSON 185 | import json 186 | with open('custom_benchmark.json', 'w') as f: 187 | json.dump(custom_problems, f) 188 | ``` 189 | 190 | ### Custom Evaluation Metrics 191 | 192 | ```python 193 | def custom_evaluator(predicted_answer, correct_answer): 194 | """Custom evaluation function""" 195 | # Implement your evaluation logic 196 | return predicted_answer.strip().lower() == correct_answer.strip().lower() 197 | 198 | # Use in benchmark script 199 | python scripts/benchmark.py \ 200 | --benchmark custom_benchmark.json \ 201 | --custom_evaluator custom_evaluator 202 | ``` 203 | 204 | ## Best Practices 205 | 206 | ### Reproducibility 207 | 208 | 1. **Set Random Seeds**: Use `--shuffle_seed` for consistent problem ordering 209 | 2. **Fixed Hyperparameters**: Document all configuration options 210 | 3. **Environment Tracking**: Record GPU type, driver versions, and dependencies 211 | 212 | ### Performance Optimization 213 | 214 | 1. **GPU Memory Management**: Monitor memory usage during benchmarks 215 | 2. **Batch Processing**: Use appropriate batch sizes for your hardware 216 | 3. **Caching**: Enable model caching for faster repeated evaluations 217 | 218 | ### Result Validation 219 | 220 | 1. **Cross-Validation**: Run multiple seeds and average results 221 | 2. **Significance Testing**: Use statistical tests to validate improvements 222 | 3. 
**Human Evaluation**: Manually verify a sample of results 223 | 224 | ## Troubleshooting 225 | 226 | ### Common Issues 227 | 228 | **Out of Memory Errors:** 229 | ```bash 230 | # Reduce batch size or budget 231 | python scripts/benchmark.py --budgets 1,2,4 --rm_device cuda:0 232 | ``` 233 | 234 | **Slow Evaluation:** 235 | ```bash 236 | # Disable evaluation for faster benchmarking 237 | python scripts/benchmark.py --no_eval 238 | ``` 239 | 240 | **Model Loading Issues:** 241 | ```bash 242 | # Verify model availability 243 | curl http://localhost:8000/v1/models 244 | ``` 245 | 246 | ### Performance Monitoring 247 | 248 | ```bash 249 | # Monitor GPU usage during benchmarking 250 | watch -n 1 nvidia-smi 251 | 252 | # Monitor system resources 253 | htop 254 | ``` -------------------------------------------------------------------------------- /tests/test_particle_gibbs_resampling.py: -------------------------------------------------------------------------------- 1 | """Test for the particle Gibbs resampling weight calculation fix (issue #54).""" 2 | 3 | import random 4 | 5 | import numpy as np 6 | 7 | from its_hub.algorithms.particle_gibbs import ( 8 | Particle, 9 | ParticleGibbs, 10 | ParticleGibbsResult, 11 | SelectionMethod, 12 | ) 13 | from its_hub.base import AbstractLanguageModel, AbstractProcessRewardModel 14 | from its_hub.lms import StepGeneration 15 | 16 | 17 | class MockLanguageModelForResampling(AbstractLanguageModel): 18 | """Mock LM that generates predictable steps for testing resampling.""" 19 | 20 | def __init__(self): 21 | self.step_counter = 0 22 | 23 | async def agenerate(self, messages, **kwargs): 24 | return self.generate(messages, **kwargs) 25 | 26 | def generate(self, messages, max_tokens=100, **kwargs): 27 | # Handle both single and batch calls like OpenAICompatibleLanguageModel 28 | if ( 29 | isinstance(messages, list) 30 | and len(messages) > 0 31 | and isinstance(messages[0], list) 32 | ): 33 | # Batch generation 34 | results = [] 35 | for _ in messages: 36 | step = f"step{self.step_counter}" 37 | self.step_counter += 1 38 | results.append({"role": "assistant", "content": step}) 39 | return results 40 | else: 41 | # Single generation 42 | step = f"step{self.step_counter}" 43 | self.step_counter += 1 44 | return {"role": "assistant", "content": step} 45 | 46 | async def aevaluate(self, prompt, response): 47 | return self.evaluate(prompt, response) 48 | 49 | def evaluate(self, prompt, response): 50 | # Not used in these tests 51 | return 0.5 52 | 53 | 54 | class MockProcessRewardModelForResampling(AbstractProcessRewardModel): 55 | """Mock PRM that gives higher scores to longer sequences.""" 56 | 57 | async def ascore(self, prompt, response): 58 | return self.score(prompt, response) 59 | 60 | def score(self, prompt, response): 61 | if isinstance(response, list): 62 | # Batch scoring 63 | return [self._score_single(r) for r in response] 64 | else: 65 | # Single scoring 66 | return self._score_single(response) 67 | 68 | def _score_single(self, response): 69 | # Give higher scores to longer responses 70 | # This simulates a scenario where reference particles (being longer) 71 | # would have unfairly high scores if we don't use partial weights 72 | num_steps = response.count("step") 73 | # Return a score between 0.5 and 0.9 based on length 74 | return min(0.5 + 0.1 * num_steps, 0.9) 75 | 76 | 77 | class TestParticleGibbsResampling: 78 | """Test the fix for issue #54 - proper calculation of resampling weights.""" 79 | 80 | def 
test_reference_trajectory_partial_weights(self): 81 | """Test that reference trajectories use partial weights during resampling.""" 82 | # Create mock models 83 | mock_lm = MockLanguageModelForResampling() 84 | mock_prm = MockProcessRewardModelForResampling() 85 | 86 | # Create step generation with 3 max steps 87 | sg = StepGeneration(step_token="\n", max_steps=3) 88 | 89 | # Create ParticleGibbs with 2 iterations and 1 reference particle 90 | pg = ParticleGibbs( 91 | sg=sg, 92 | prm=mock_prm, 93 | num_iterations=2, 94 | selection_method=SelectionMethod.ARGMAX, 95 | num_ref_particles=1, 96 | ) 97 | 98 | # Run inference with budget=4 (2 particles per iteration) 99 | result = pg.infer(mock_lm, "Test prompt", budget=4, return_response_only=False) 100 | 101 | # Verify the result structure 102 | assert isinstance(result, ParticleGibbsResult) 103 | assert len(result.responses_lst) == 2 # 2 iterations 104 | assert len(result.log_weights_lst) == 2 105 | assert len(result.ref_indices_lst) == 2 106 | 107 | # In the second iteration, check that particles have consistent trajectory lengths 108 | # after resampling (this would fail with the old implementation) 109 | second_iter_steps = result.steps_used_lst[1] 110 | 111 | # All particles should have progressed similarly 112 | # With the bug, reference particle would keep its full trajectory 113 | # making it have more steps than others 114 | assert len(set(second_iter_steps)) <= 2, ( 115 | f"Particles have very different step counts: {second_iter_steps}. " 116 | "This suggests reference particles weren't truncated properly." 117 | ) 118 | 119 | def test_resampling_weights_consistency(self): 120 | """Test that resampling uses consistent weights across particles.""" 121 | # Create a custom mock that allows us to inspect the resampling process 122 | mock_lm = MockLanguageModelForResampling() 123 | mock_prm = MockProcessRewardModelForResampling() 124 | 125 | sg = StepGeneration(step_token="\n", max_steps=5) 126 | 127 | # Create a test scenario with manual particle manipulation 128 | # Initialize particles with different trajectories 129 | ref_particle = Particle( 130 | steps=["ref_step1", "ref_step2", "ref_step3"], 131 | is_stopped=True, 132 | partial_log_weights=[0.6, 0.7, 0.8], 133 | ) 134 | 135 | new_particle = Particle( 136 | steps=["new_step1"], is_stopped=False, partial_log_weights=[0.65] 137 | ) 138 | 139 | # Test the _propagate method 140 | pg = ParticleGibbs( 141 | sg=sg, 142 | prm=mock_prm, 143 | num_iterations=1, 144 | selection_method=SelectionMethod.ARGMAX, 145 | num_ref_particles=1, 146 | ) 147 | 148 | # Propagate one step 149 | particles = [new_particle, ref_particle] 150 | propagated = pg._propagate(mock_lm, particles, "Test prompt") 151 | 152 | # After propagation, only non-stopped particles get extended 153 | assert ( 154 | len(propagated[0].partial_log_weights) == 2 155 | ) # new particle now has 2 steps 156 | assert ( 157 | len(propagated[1].partial_log_weights) == 3 158 | ) # ref particle stays at 3 steps (stopped) 159 | 160 | # The key insight: during resampling in the main loop, 161 | # we should compare partial_log_weights[1] for both particles, 162 | # not the full log_weight of the reference particle 163 | 164 | def test_reference_particle_truncation(self): 165 | """Test that reference particles are properly truncated when resampled.""" 166 | mock_lm = MockLanguageModelForResampling() 167 | mock_prm = MockProcessRewardModelForResampling() 168 | 169 | sg = StepGeneration(step_token="\n", max_steps=4) 170 | 171 | # Use a controlled 
random seed for reproducibility 172 | random.seed(42) 173 | np.random.seed(42) 174 | 175 | pg = ParticleGibbs( 176 | sg=sg, 177 | prm=mock_prm, 178 | num_iterations=2, 179 | selection_method=SelectionMethod.ARGMAX, 180 | num_ref_particles=1, 181 | ) 182 | 183 | result = pg.infer(mock_lm, "Test prompt", budget=6, return_response_only=False) 184 | 185 | # Check that in the second iteration, particles don't have 186 | # drastically different numbers of steps 187 | second_iter_steps = result.steps_used_lst[1] 188 | 189 | # The maximum difference in steps should be reasonable 190 | # (not the full trajectory length difference) 191 | max_diff = max(second_iter_steps) - min(second_iter_steps) 192 | assert max_diff <= 2, ( 193 | f"Large step difference in second iteration: {second_iter_steps}. " 194 | "Reference particle may not have been truncated." 195 | ) 196 | -------------------------------------------------------------------------------- /notebooks/self-consistency.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "98e1f27d", 6 | "metadata": {}, 7 | "source": [ 8 | "# Self-Consistency Algorithm Demo\n", 9 | "This notebook demonstrates the Self-Consistency algorithm for mathematical reasoning." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "id": "be6eaec0", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%load_ext autoreload\n", 20 | "%autoreload 2" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "b6e4fabc", 27 | "metadata": { 28 | "lines_to_next_cell": 0 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "import os\n", 33 | "\n", 34 | "import nest_asyncio\n", 35 | "from dotenv import load_dotenv\n", 36 | "\n", 37 | "from its_hub.lms import OpenAICompatibleLanguageModel\n", 38 | "from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT\n", 39 | "\n", 40 | "nest_asyncio.apply()\n", 41 | "\n", 42 | "# Load environment variables from .env file\n", 43 | "load_dotenv()\n", 44 | "\n", 45 | "# Main example: OpenAI API endpoint with gpt-4o-mini\n", 46 | "lm = OpenAICompatibleLanguageModel(\n", 47 | " endpoint=\"https://api.openai.com/v1\",\n", 48 | " api_key=os.getenv(\"OPENAI_API_KEY\"), # Load API key from environment\n", 49 | " model_name=\"gpt-4o-mini\",\n", 50 | " system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,\n", 51 | " is_async=True,\n", 52 | ")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "ed10ced9", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# Alternative: vLLM local endpoint (commented out)\n", 63 | "# lm = OpenAICompatibleLanguageModel(\n", 64 | "# endpoint=\"http://localhost:8000/v1\",\n", 65 | "# api_key=\"NO_API_KEY\",\n", 66 | "# model_name=\"qwen2-math-1.5b-instruct\",\n", 67 | "# system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,\n", 68 | "# is_async=True,\n", 69 | "# )" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "35d3f4a5", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# Mathematical problem to solve\n", 80 | "prompt = r\"Let $a$ be a positive real number such that all the roots of \\[x^3 + ax^2 + ax + 1 = 0\\]are real. 
Find the smallest possible value of $a.$\"\n", 81 | "\n", 82 | "# Generate response using the proper format\n", 83 | "from its_hub.types import ChatMessages\n", 84 | "\n", 85 | "chat_messages = ChatMessages.from_prompt_or_messages(prompt)\n", 86 | "response = lm.generate(chat_messages.to_batch(1))[0]\n", 87 | "\n", 88 | "print(response)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "53386b1e", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "def extract_boxed(s: str) -> str:\n", 99 | " import re\n", 100 | " # find all occurrences of \\boxed{...}\n", 101 | " boxed_matches = re.findall(r'\\\\boxed\\{([^{}]+(?:\\{[^{}]*\\}[^{}]*)*)\\}', s)\n", 102 | " # return the last match if any were found\n", 103 | " return boxed_matches[-1] if boxed_matches else \"\"\n", 104 | "\n", 105 | "print(extract_boxed(response['content']))" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "id": "7bd4a72e", 111 | "metadata": {}, 112 | "source": [ 113 | "## Self-Consistency Algorithm\n", 114 | "Now we'll use the Self-Consistency algorithm to improve the answer quality." 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "id": "1a3ab056", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "from its_hub.algorithms import SelfConsistency\n", 125 | "\n", 126 | "# Set computational budget for scaling\n", 127 | "budget = 4\n", 128 | "\n", 129 | "scaling_alg = SelfConsistency(extract_boxed)\n", 130 | "\n", 131 | "scaling_result = scaling_alg.infer(\n", 132 | " lm, prompt, budget, return_response_only=False\n", 133 | ")\n", 134 | "\n", 135 | "print(\"######## Self-Consistency Result ########\")\n", 136 | "print(scaling_result.the_one)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "a12470e4", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "print(\"######## Extracted Response Counts ########\")\n", 147 | "print(scaling_result.response_counts)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "1b960ab8", 154 | "metadata": { 155 | "lines_to_next_cell": 2 156 | }, 157 | "outputs": [], 158 | "source": [] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "id": "97d69191", 163 | "metadata": {}, 164 | "source": [ 165 | "## Self-Consistency Algorithm for Tool Calls\n", 166 | "We have hierarchical tool-voting support in Self-Consistency algorithm\n", 167 | "It first votes on tool names, and then on tool arguments." 
168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "b02c7de2", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "from its_hub.types import ChatMessage, ChatMessages\n", 178 | "\n", 179 | "# Tool schema (OpenAI-style dicts)\n", 180 | "tools = [\n", 181 | " {\n", 182 | " \"type\": \"function\",\n", 183 | " \"function\": {\n", 184 | " \"name\": \"calculator\",\n", 185 | " \"description\": \"Perform arithmetic calculations\",\n", 186 | " \"parameters\": {\n", 187 | " \"type\": \"object\",\n", 188 | " \"properties\": {\n", 189 | " \"expression\": {\n", 190 | " \"type\": \"string\",\n", 191 | " \"description\": \"Mathematical expression to evaluate\"\n", 192 | " }\n", 193 | " },\n", 194 | " \"required\": [\"expression\"]\n", 195 | " }\n", 196 | " }\n", 197 | " }\n", 198 | "]\n", 199 | "\n", 200 | "# ChatMessages instance with system + user\n", 201 | "tool_call_messages = ChatMessages([\n", 202 | " ChatMessage(\n", 203 | " role=\"system\",\n", 204 | " content=\"You are a precise calculator. Always use the calculator tool for arithmetic and format your final answer as \\\\boxed{result}.\"\n", 205 | " ),\n", 206 | " ChatMessage(\n", 207 | " role=\"user\",\n", 208 | " content=\"What is 847 * 293 + 156?\"\n", 209 | " ),\n", 210 | "])" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "id": "5be77c52", 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "# Use hierarchical tool voting\n", 221 | "scaling_alg_tool = SelfConsistency(tool_vote=\"tool_hierarchical\")\n", 222 | "\n", 223 | "budget = 5\n", 224 | "scaling_result = scaling_alg_tool.infer(\n", 225 | " lm, tool_call_messages, budget, return_response_only=False, tools=tools, tool_choice=\"auto\"\n", 226 | ")" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "id": "ac526d23", 233 | "metadata": { 234 | "lines_to_next_cell": 2 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "print(\"######## Self-Consistency Result ########\")\n", 239 | "print(scaling_result.the_one)\n", 240 | "\n", 241 | "print(\"######## Tool Call Response Counts ########\")\n", 242 | "print(scaling_result.response_counts)" 243 | ] 244 | } 245 | ], 246 | "metadata": { 247 | "jupytext": { 248 | "cell_metadata_filter": "-all", 249 | "notebook_metadata_filter": "all" 250 | }, 251 | "kernelspec": { 252 | "display_name": "inference_time_scaling-dev", 253 | "language": "python", 254 | "name": "python3" 255 | }, 256 | "language_info": { 257 | "codemirror_mode": { 258 | "name": "ipython", 259 | "version": 3 260 | }, 261 | "file_extension": ".py", 262 | "mimetype": "text/x-python", 263 | "name": "python", 264 | "nbconvert_exporter": "python", 265 | "pygments_lexer": "ipython3", 266 | "version": "3.11.11" 267 | } 268 | }, 269 | "nbformat": 4, 270 | "nbformat_minor": 5 271 | } 272 | -------------------------------------------------------------------------------- /docs/algorithms.md: -------------------------------------------------------------------------------- 1 | # Algorithms 2 | 3 | its-hub provides several inference-time scaling algorithms, each optimized for different use cases and computational budgets. 
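Because every algorithm exposes the same `infer` entry point, switching between them is typically a one-line change at the call site. Below is a minimal sketch of this, assuming `lm`, `sg`, and `prm` have already been configured as in the per-algorithm sections that follow (the `extract_boxed` helper mirrors the one used in the Self-Consistency example):

```python
# Illustrative sketch only: the algorithm objects are interchangeable at the call site.
# Assumes `lm` (language model), `sg` (StepGeneration), and `prm` (process reward model)
# are configured as shown in the sections below.
import re

from its_hub.algorithms import BeamSearch, BestOfN, ParticleFiltering, SelfConsistency


def extract_boxed(text: str) -> str:
    """Pull the final \\boxed{...} answer out of a response."""
    matches = re.findall(r"\\boxed\{([^{}]+)\}", text)
    return matches[-1] if matches else ""


algorithms = {
    "self-consistency": SelfConsistency(extract_boxed),
    "best-of-n": BestOfN(prm),
    "beam-search": BeamSearch(sg, prm, beam_width=4),
    "particle-filtering": ParticleFiltering(sg, prm),
}

prompt = "Solve x^2 + 5x + 6 = 0"
for name, alg in algorithms.items():
    # Same call everywhere; only the budget semantics differ (see the table below).
    print(name, "->", alg.infer(lm, prompt, budget=8))
```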
4 | 5 | ## Overview 6 | 7 | All algorithms follow the same interface: `infer(lm, prompt, budget, return_response_only=True)` 8 | 9 | The `budget` parameter controls computational resources allocated to each algorithm, with different interpretations: 10 | 11 | | Algorithm | Budget Interpretation | Snippet | 12 | |-----------|----------------------|---------| 13 | | Self-Consistency | Number of parallel generations | `SelfConsistency()` | 14 | | Best-of-N | Number of candidate responses | `BestOfN(rm)` | 15 | | Beam Search | Total generations ÷ beam width | `BeamSearch(sg, prm, beam_width=4)` | 16 | | Particle Filtering | Number of particles | `ParticleFiltering(sg, prm)` | 17 | | Entropic Particle Filtering | Number of particles | `EntropicParticleFiltering(sg, prm)` | 18 | 19 | ## Self-Consistency 20 | 21 | Generates multiple responses and selects the most common answer through voting. **Especially powerful for tool-calling** where you want consistent tool usage patterns. 22 | 23 | ### Tool Calling Example (Recommended) 24 | 25 | ```python 26 | from its_hub.algorithms import SelfConsistency 27 | from its_hub.types import ChatMessage, ChatMessages 28 | from its_hub.lms import OpenAICompatibleLanguageModel 29 | 30 | # Initialize language model 31 | lm = OpenAICompatibleLanguageModel( 32 | endpoint="https://api.openai.com/v1", 33 | api_key="your-api-key", 34 | model_name="gpt-4o-mini" 35 | ) 36 | 37 | # Define tools (OpenAI format) 38 | tools = [ 39 | { 40 | "type": "function", 41 | "function": { 42 | "name": "calculator", 43 | "description": "Perform arithmetic calculations", 44 | "parameters": { 45 | "type": "object", 46 | "properties": { 47 | "expression": { 48 | "type": "string", 49 | "description": "Mathematical expression to evaluate" 50 | } 51 | }, 52 | "required": ["expression"] 53 | } 54 | } 55 | } 56 | ] 57 | 58 | # Create messages 59 | messages = ChatMessages([ 60 | ChatMessage( 61 | role="system", 62 | content="You are a precise calculator. Always use the calculator tool for arithmetic." 63 | ), 64 | ChatMessage( 65 | role="user", 66 | content="What is 847 * 293 + 156?" 67 | ) 68 | ]) 69 | 70 | # Use hierarchical tool voting 71 | sc = SelfConsistency(tool_vote="tool_hierarchical") 72 | result = sc.infer( 73 | lm, 74 | messages, 75 | budget=5, 76 | tools=tools, 77 | tool_choice="auto" 78 | ) 79 | print(result) 80 | ``` 81 | 82 | **Tool voting modes:** 83 | - `"tool_name"`: Vote on which tool to call 84 | - `"tool_args"`: Vote on tool arguments 85 | - `"tool_hierarchical"` (recommended): First vote on tool name, then on arguments 86 | - `exclude_args=["timestamp", "id"]`: Exclude non-semantic arguments from voting 87 | 88 | ### Text-Based Example 89 | 90 | ```python 91 | # For mathematical problems with regex extraction 92 | def extract_boxed(text): 93 | import re 94 | matches = re.findall(r'\\boxed\{([^{}]+)\}', text) 95 | return matches[-1] if matches else "" 96 | 97 | sc = SelfConsistency(projection_function=extract_boxed) 98 | result = sc.infer(lm, "Solve x^2 + 5x + 6 = 0", budget=4) 99 | ``` 100 | 101 | **When to use:** 102 | - Tool-calling applications (agents, function calling) 103 | - Mathematical problems with clear final answers 104 | - Tasks where multiple reasoning approaches are valid 105 | - When you need fast inference with improved accuracy 106 | 107 | ## Best-of-N 108 | 109 | Generates N candidate responses and selects the highest-scoring one using a reward model. 
**Works with both text and tool-calling responses.** 110 | 111 | ### With LLM Judge (Cloud APIs) 112 | 113 | ```python 114 | from its_hub.algorithms import BestOfN 115 | from its_hub.integration.reward_hub import LLMJudgeRewardModel 116 | from its_hub.lms import OpenAICompatibleLanguageModel 117 | 118 | # Initialize language model 119 | lm = OpenAICompatibleLanguageModel( 120 | endpoint="https://api.openai.com/v1", 121 | api_key="your-api-key", 122 | model_name="gpt-4o-mini" 123 | ) 124 | 125 | # Set up LLM judge for scoring 126 | judge = LLMJudgeRewardModel( 127 | model="gpt-4o-mini", 128 | criterion="multi_step_tool_judge", # For tool-calling tasks 129 | judge_type="groupwise", 130 | api_key="your-api-key" 131 | ) 132 | 133 | # Best-of-N with LLM judge 134 | bon = BestOfN(judge) 135 | 136 | # Works with tool calls 137 | tools = [{"type": "function", "function": {...}}] 138 | result = bon.infer( 139 | lm, 140 | messages, 141 | budget=4, 142 | tools=tools, 143 | tool_choice="auto" 144 | ) 145 | ``` 146 | 147 | ### With Local Process Reward Model 148 | 149 | ```python 150 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel 151 | 152 | # Initialize reward model (requires GPU) 153 | prm = LocalVllmProcessRewardModel( 154 | model_name="Qwen/Qwen2.5-Math-PRM-7B", 155 | device="cuda:0", 156 | aggregation_method="prod" 157 | ) 158 | 159 | bon = BestOfN(prm) 160 | result = bon.infer(lm, prompt, budget=16) 161 | ``` 162 | 163 | **When to use:** 164 | - Tool-calling applications where quality matters most 165 | - When you have a reliable reward model 166 | - Quality is more important than speed 167 | - Tasks where ranking responses is straightforward 168 | 169 | ## Beam Search 170 | 171 | Performs step-by-step generation with beam width control, using process reward models to guide the search. 172 | 173 | ```python 174 | from its_hub.algorithms import BeamSearch 175 | from its_hub.lms import StepGeneration 176 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel 177 | 178 | # Initialize components 179 | sg = StepGeneration("\n\n", max_steps=32, stop_pattern=r"\boxed") 180 | prm = LocalVllmProcessRewardModel( 181 | model_name="Qwen/Qwen2.5-Math-PRM-7B", 182 | device="cuda:0", 183 | aggregation_method="prod" 184 | ) 185 | 186 | # Beam search with beam width of 4 187 | beam_search = BeamSearch(sg, prm, beam_width=4) 188 | result = beam_search.infer(lm, prompt, budget=32) # 32 total generations 189 | ``` 190 | 191 | **Budget calculation:** `budget = beam_width × number_of_steps` 192 | 193 | **When to use:** 194 | - Step-by-step reasoning problems 195 | - When you can evaluate partial solutions 196 | - Mathematical proofs or derivations 197 | 198 | ## Particle Filtering 199 | 200 | Uses probabilistic resampling to maintain diverse reasoning paths while focusing on promising directions. 
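Each unit of budget corresponds to one particle: at every step the live particles are extended by one reasoning step, the partial trajectories are scored by the process reward model, and the particles are resampled in proportion to those scores, so unpromising paths get replaced by copies of promising ones. The toy sketch below illustrates only that propagate, score, resample loop; it is not the library's implementation (`extend_step` and `score_partial` are hypothetical stand-ins for the step generator and reward model), and the actual `ParticleFiltering` usage follows right after.

```python
# Toy sketch of the propagate -> score -> resample loop behind particle filtering.
# This is NOT its_hub's implementation; the real ParticleFiltering usage is shown
# in the next code block. `extend_step(prompt, steps)` and `score_partial(prompt, steps)`
# are hypothetical stand-ins for the step generator and the process reward model.
import numpy as np


def particle_filtering_sketch(extend_step, score_partial, prompt,
                              num_particles=8, max_steps=32):
    particles = [[] for _ in range(num_particles)]  # each particle = list of steps
    for _ in range(max_steps):
        # 1. Propagate: extend every particle by one reasoning step.
        particles = [steps + [extend_step(prompt, steps)] for steps in particles]
        # 2. Weight: score each partial trajectory (scores assumed positive).
        weights = np.array([score_partial(prompt, steps) for steps in particles])
        probs = weights / weights.sum()
        # 3. Resample: promising trajectories get duplicated, weak ones dropped.
        chosen = np.random.choice(num_particles, size=num_particles, p=probs)
        particles = [list(particles[i]) for i in chosen]
        # Stop once any particle has produced a final boxed answer.
        if any("\\boxed" in steps[-1] for steps in particles):
            break
    scores = [score_partial(prompt, steps) for steps in particles]
    return "\n\n".join(particles[int(np.argmax(scores))])
```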
201 | 202 | ```python 203 | from its_hub.algorithms import ParticleFiltering 204 | from its_hub.lms import StepGeneration 205 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel 206 | 207 | # Initialize components 208 | sg = StepGeneration("\n\n", max_steps=32, stop_pattern=r"\boxed") 209 | prm = LocalVllmProcessRewardModel( 210 | model_name="Qwen/Qwen2.5-Math-PRM-7B", 211 | device="cuda:0", 212 | aggregation_method="prod" 213 | ) 214 | 215 | # Particle filtering with 8 particles 216 | pf = ParticleFiltering(sg, prm) 217 | result = pf.infer(lm, prompt, budget=8) 218 | ``` 219 | 220 | **When to use:** 221 | - Complex reasoning tasks with multiple valid approaches 222 | - When exploration vs exploitation balance is important 223 | - Mathematical problem solving with uncertainty 224 | 225 | 226 | ## Entropic Particle Filtering 227 | 228 | Entropic Particle Filtering (ePF) is an advanced sampling algorithm that mitigates common failure modes in standard PF, like particle degeneracy and impoverishment. 229 | By leveraging Entropic Annealing (EA) to control the variance of the resampling distribution, ePF ensures a more robust and thorough exploration in the early phase of sampling, especially for complex long sequences and multi-step tasks. 230 | 231 | ```python 232 | from its_hub.algorithms import EntropicParticleFiltering 233 | from its_hub.lms import StepGeneration 234 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel 235 | 236 | # Initialize components 237 | sg = StepGeneration("\n\n", max_steps=32, stop_pattern=r"\boxed") 238 | prm = LocalVllmProcessRewardModel( 239 | model_name="Qwen/Qwen2.5-Math-PRM-7B", 240 | device="cuda:0", 241 | aggregation_method="prod" 242 | ) 243 | 244 | # Entropic particle filtering with 8 particles 245 | epf = EntropicParticleFiltering(sg, prm) 246 | result = epf.infer(lm, prompt, budget=8) 247 | ``` 248 | 249 | **When to use:** 250 | - When Reward Models Are Hard to Calibrate: 251 | - If your Process Reward Model (PRM) tends to be *overconfident* early in the sampling process, ePF helps by keeping a wider range of options open for longer. 252 | 253 | - For Complex, Long Multi-Step Tasks: 254 | - When a problem requires many sequential steps to solve (> 20 steps), standard particle filters can lose diversity and generate greedy-like solutions. ePF is designed to handle these long-horizon tasks more effectively. 255 | 256 | - To Avoid Early Convergence: 257 | - If you notice that a standard filter is producing short, incomplete responses or underperforming, it is likely converging prematurely. ePF directly counteracts this by promoting particle diversity. 
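One way to build intuition for Entropic Annealing is to look at what tempering does to the resampling distribution: raising the particle scores to an exponent β < 1 flattens the distribution (higher entropy, more exploration), and annealing β toward 1 gradually recovers the standard particle-filtering behaviour. The snippet below is only a toy illustration of that effect with made-up scores; it is not the schedule or formulation used by `EntropicParticleFiltering`.

```python
# Toy illustration of entropy-controlled resampling weights.
# NOT the actual EntropicParticleFiltering schedule; it only shows how tempering
# (beta < 1) flattens the resampling distribution early in sampling.
import numpy as np


def tempered_resampling_probs(scores, beta):
    """Raise scores to the power beta and renormalize (beta=1 recovers standard PF)."""
    w = np.asarray(scores, dtype=float) ** beta
    return w / w.sum()


scores = [0.9, 0.5, 0.4, 0.1]  # hypothetical PRM scores for 4 particles
for step, total_steps in [(1, 20), (10, 20), (20, 20)]:
    beta = step / total_steps  # a simple linear anneal from exploratory to standard
    probs = tempered_resampling_probs(scores, beta)
    entropy = -(probs * np.log(probs)).sum()
    print(f"step {step:2d}: beta={beta:.2f} probs={np.round(probs, 2)} entropy={entropy:.2f}")
```

Early on (small β) the probabilities stay close to uniform, so low-scoring but potentially useful particles survive resampling; by the end (β = 1) the distribution matches standard particle filtering.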
258 | 259 | 260 | ## Advanced Configuration 261 | 262 | ### Step Generation 263 | 264 | The `StepGeneration` class enables incremental text generation: 265 | 266 | ```python 267 | from its_hub.lms import StepGeneration 268 | 269 | # For math problems with boxed answers 270 | sg = StepGeneration( 271 | step_token="\n\n", # Split reasoning into steps 272 | max_steps=32, # Maximum number of steps 273 | stop_pattern=r"\boxed", # Stop when final answer is found 274 | post_process=True # Clean up output formatting 275 | ) 276 | ``` 277 | 278 | ### Reward Models 279 | 280 | #### Process Reward Models 281 | Evaluate reasoning steps incrementally: 282 | 283 | ```python 284 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel 285 | 286 | prm = LocalVllmProcessRewardModel( 287 | model_name="Qwen/Qwen2.5-Math-PRM-7B", 288 | device="cuda:0", 289 | aggregation_method="prod" # or "mean", "min", "max" 290 | ) 291 | ``` 292 | 293 | #### Outcome Reward Models 294 | Evaluate final answers only: 295 | 296 | ```python 297 | # Custom outcome reward model 298 | class MathOutcomeRewardModel: 299 | def score(self, prompt, response): 300 | # Extract answer and compute reward 301 | return score 302 | ``` 303 | 304 | ## Performance Tips 305 | 306 | 1. **Start with Self-Consistency** for quick improvements 307 | 2. **Use Best-of-N** when you have a good reward model 308 | 3. **Try Beam Search** for step-by-step reasoning 309 | 4. **Use Particle Filtering** for the most complex problems 310 | 5. **Use Entropic Particle Filtering** to mitigate early exploitation 311 | 6. **Adjust budget** based on problem complexity and time constraints 312 | 7. **Monitor GPU memory** when using large reward models -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /its_hub/integration/reward_hub.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | from reward_hub.base import AggregationMethod 5 | from reward_hub.llm_judge import create_groupwise_judge, create_pointwise_judge 6 | 7 | from its_hub.base import AbstractOutcomeRewardModel, AbstractProcessRewardModel 8 | from its_hub.types import ChatMessage, ChatMessages 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class LocalVllmProcessRewardModel(AbstractProcessRewardModel): 14 | def __init__( 15 | self, model_name: str, device: str, aggregation_method: AggregationMethod 16 | ): 17 | from reward_hub.vllm.reward import VllmProcessRewardModel 18 | 19 | self.model = VllmProcessRewardModel(model_name=model_name, device=device) 20 | self.aggregation_method = aggregation_method 21 | 22 | async def ascore( 23 | self, 24 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 25 | response_or_responses: str | list[str], 26 | ) -> float | list[float]: 27 | """score response(s) asynchronously""" 28 | import asyncio 29 | 30 | chat_messages = ChatMessages.from_prompt_or_messages(prompt_or_messages) 31 | 32 | is_single_response = isinstance(response_or_responses, str) 33 | responses = ( 34 | [response_or_responses] if is_single_response else response_or_responses 35 | ) 36 | 37 | # Build conversation messages with responses 38 | base_msgs = [ 39 | ChatMessage(role="user", content=f"System: {msg.extract_text_content()}") 40 | if msg.role == "system" 41 | else msg 42 | for msg in chat_messages.to_chat_messages() 43 | ] 44 | messages = [ 45 | [ 46 | *[ 47 | {"role": msg.role, "content": msg.extract_text_content()} 48 | for msg in base_msgs 49 | ], 50 | {"role": "assistant", "content": response}, 51 | ] 52 | for response in responses 53 | ] 54 | 55 | # Run in thread to avoid blocking event loop 56 | res = await asyncio.to_thread( 57 | self.model.score, 58 | messages=messages, 59 | aggregation_method=self.aggregation_method, 60 | return_full_prm_result=False, 61 | ) 62 | return res[0] if is_single_response else res 63 | 64 | def score( 65 | self, 66 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 67 | response_or_responses: str | list[str], 68 | ) -> float | list[float]: 69 | """score response(s) synchronously""" 70 | import asyncio 71 | 72 | return asyncio.run(self.ascore(prompt_or_messages, response_or_responses)) 73 | 74 | 75 | class LLMJudgeRewardModel(AbstractOutcomeRewardModel): 76 | """ 77 | Adapter for reward_hub's LLM Judge models to work with its_hub's AbstractOutcomeRewardModel interface. 78 | 79 | This class wraps reward_hub's PointwiseJudgeModel to make it compatible with its_hub's 80 | prompt/response format and can be used with algorithms like Best-of-N. 81 | """ 82 | 83 | def __init__( 84 | self, 85 | model: str, 86 | criterion: str, 87 | judge_type: str = "groupwise", 88 | api_key: str | None = None, 89 | base_url: str | None = None, 90 | temperature: float = 0.7, 91 | max_tokens: int = 4096, 92 | enable_judge_logging: bool = True, 93 | top_n: int = 1, 94 | **litellm_kwargs, 95 | ): 96 | """ 97 | Initialize LLM Judge reward model. 
98 | 99 | Args: 100 | model: LiteLLM model name (e.g., "gpt-4o-mini", "claude-3-sonnet-20240229") 101 | criterion: Evaluation criterion from CriterionRegistry (default: "overall_quality") 102 | Built-in options: overall_quality, writing_quality, technical_quality, 103 | relevance_quality, tool-judge 104 | judge_type: Type of judge - "pointwise" or "groupwise" (default: "groupwise") 105 | api_key: API key for the model provider 106 | base_url: Base URL for custom endpoints 107 | temperature: Temperature for judge generation (0.0 for deterministic) 108 | max_tokens: Maximum tokens for judge response 109 | enable_judge_logging: If True, log judge scores and reasoning (default: True) 110 | top_n: For groupwise judges, number of top responses to select (default: 1) 111 | **litellm_kwargs: Additional arguments passed to LiteLLM 112 | """ 113 | 114 | if judge_type == "pointwise": 115 | self.judge = create_pointwise_judge( 116 | model=model, 117 | criterion=criterion, 118 | api_key=api_key, 119 | base_url=base_url, 120 | temperature=temperature, 121 | max_tokens=max_tokens, 122 | **litellm_kwargs, 123 | ) 124 | elif judge_type == "groupwise": 125 | self.judge = create_groupwise_judge( 126 | model=model, 127 | criterion=criterion, 128 | api_key=api_key, 129 | base_url=base_url, 130 | temperature=temperature, 131 | max_tokens=max_tokens, 132 | **litellm_kwargs, 133 | ) 134 | else: 135 | raise ValueError( 136 | f"Invalid judge type: {judge_type}. Must be 'pointwise' or 'groupwise'." 137 | ) 138 | 139 | self.judge_type = judge_type 140 | self.criterion = criterion 141 | self.model = model 142 | self.top_n = top_n 143 | self.enable_judge_logging = enable_judge_logging 144 | 145 | def score( 146 | self, 147 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 148 | response: str | list[str], 149 | ) -> float | list[float]: 150 | """ 151 | Score response(s) using the LLM judge. 152 | 153 | Args: 154 | prompt_or_messages: The prompt or conversation context 155 | response: The response(s) to evaluate (single string or list of strings) 156 | 157 | Returns: 158 | - For single response: float score 159 | - For multiple responses: list[float] scores 160 | 161 | Note: 162 | - If enable_judge_logging=True, the judge's reasoning is logged internally 163 | along with the scores from the JudgeResult object. 164 | - For pointwise judges: logs individual scores and reasoning per response 165 | - For groupwise judges: logs ranking reasoning and top-N selection with binary scores 166 | """ 167 | # Use async version 168 | import asyncio 169 | 170 | return asyncio.run(self.ascore(prompt_or_messages, response)) 171 | 172 | async def ascore( 173 | self, 174 | prompt_or_messages: str | list[ChatMessage] | ChatMessages, 175 | response: str | list[str], 176 | ) -> float | list[float]: 177 | """ 178 | Score response(s) asynchronously using the LLM judge. 179 | 180 | Args: 181 | prompt_or_messages: The prompt or conversation context 182 | response: The response(s) to evaluate (single string or list of strings) 183 | 184 | Returns: 185 | - For single response: float score 186 | - For multiple responses: list[float] scores 187 | 188 | Note: 189 | - If enable_judge_logging=True, the judge's reasoning is logged internally 190 | along with the scores from the JudgeResult object. 
191 | - For pointwise judges: logs individual scores and reasoning per response 192 | - For groupwise judges: logs ranking reasoning and top-N selection with binary scores 193 | """ 194 | # Convert to ChatMessages format 195 | chat_messages = ChatMessages.from_prompt_or_messages(prompt_or_messages) 196 | 197 | # Build base conversation in OpenAI format 198 | base_messages = [ 199 | { 200 | "role": msg.role, 201 | "content": msg.extract_text_content(), 202 | } # Reward Hub expects content to be a string 203 | for msg in chat_messages.to_chat_messages() 204 | ] 205 | 206 | # Handle both single response and batch of responses 207 | is_single_response = isinstance(response, str) 208 | responses = [response] if is_single_response else response 209 | 210 | # Build complete conversations (base + each response) 211 | conversations = [ 212 | base_messages + [{"role": "assistant", "content": resp}] 213 | for resp in responses 214 | ] 215 | 216 | # Call judge with multiple conversations 217 | # Judge expects List[List[dict]] for multiple conversations 218 | 219 | if self.judge_type == "groupwise": 220 | judge_result = await self.judge.ascore( 221 | conversations, 222 | return_judge_reasoning=self.enable_judge_logging, 223 | top_n=self.top_n, 224 | ) 225 | else: 226 | judge_result = await self.judge.ascore( 227 | conversations, return_judge_reasoning=self.enable_judge_logging 228 | ) 229 | 230 | # Log judge results if enabled 231 | if self.enable_judge_logging and judge_result.reasonings: 232 | if self.judge_type == "pointwise": 233 | # Pointwise: log each response's individual score and reasoning 234 | for i, (score, reasoning, response) in enumerate( 235 | zip(judge_result.scores, judge_result.reasonings, responses) 236 | ): 237 | extra_data = { 238 | "judge_type": "pointwise", 239 | "response_index": i, 240 | "score": score, 241 | "reasoning": reasoning, 242 | "response_preview": response[:300] + "..." 243 | if len(response) > 300 244 | else response, 245 | "criterion": self.criterion, 246 | "model": self.model, 247 | } 248 | logger.info( 249 | f"Pointwise Judge Result for response {i}:\n{json.dumps(extra_data, indent=2)}" 250 | ) 251 | else: 252 | # Groupwise: log ranking reasoning and which responses were selected as top-N 253 | # Binary scores: 1.0 for top-N, 0.0 for others 254 | top_indices = [ 255 | i for i, score in enumerate(judge_result.scores) if score == 1.0 256 | ] 257 | response_previews = [ 258 | { 259 | "index": i, 260 | "score": score, 261 | "preview": resp[:300] + "..." 
if len(resp) > 300 else resp, 262 | } 263 | for i, (score, resp) in enumerate( 264 | zip(judge_result.scores, responses) 265 | ) 266 | ] 267 | extra_data = { 268 | "judge_type": "groupwise", 269 | "top_n": self.top_n, 270 | "top_indices": top_indices, 271 | "scores": judge_result.scores, 272 | "response_previews": response_previews, 273 | "ranking_reasoning": judge_result.reasonings[0] 274 | if judge_result.reasonings 275 | else None, 276 | "criterion": self.criterion, 277 | "model": self.model, 278 | } 279 | logger.info( 280 | f"Groupwise Judge Result: selected {len(top_indices)} of {len(responses)} responses\n{json.dumps(extra_data, indent=2, default=str)}" 281 | ) 282 | 283 | # Return only scores (single float if single response, list otherwise) 284 | if is_single_response: 285 | return judge_result.scores[0] 286 | return judge_result.scores 287 | -------------------------------------------------------------------------------- /docs/development.md: -------------------------------------------------------------------------------- 1 | # Development 2 | 3 | ## Getting Started 4 | 5 | ### Development Installation 6 | 7 | ```bash 8 | git clone https://github.com/Red-Hat-AI-Innovation-Team/its_hub.git 9 | cd its_hub 10 | pip install -e ".[dev]" 11 | ``` 12 | 13 | The development installation includes: 14 | - All core dependencies 15 | - Testing frameworks (pytest, coverage) 16 | - Code formatting and linting tool (ruff) 17 | - Development tools and scripts 18 | 19 | ### Running Tests 20 | 21 | ```bash 22 | # Run all tests 23 | pytest tests 24 | 25 | # Run with coverage 26 | pytest tests --cov=its_hub 27 | 28 | # Run specific test modules 29 | pytest tests/test_algorithms.py 30 | pytest tests/test_lms.py 31 | pytest tests/test_iaas.py 32 | ``` 33 | 34 | ### Code Quality 35 | 36 | ```bash 37 | # Run linter checks (Ruff configuration in pyproject.toml) 38 | ruff check its_hub/ 39 | 40 | # Fix auto-fixable linting issues 41 | ruff check its_hub/ --fix 42 | 43 | # Format code 44 | ruff format its_hub/ 45 | ``` 46 | 47 | ## Architecture 48 | 49 | ### Core Design Principles 50 | 51 | **its-hub** follows a clean architecture with abstract base classes defining interfaces between components: 52 | 53 | 1. **Separation of Concerns**: Language models, algorithms, and reward models are independent 54 | 2. **Extensibility**: Easy to add new algorithms and models via abstract interfaces 55 | 3. **Async-First**: Built for high-performance concurrent inference 56 | 4. **Mathematical Focus**: Optimized for reasoning tasks with specialized prompts and evaluation 57 | 58 | ### Key Base Classes 59 | 60 | Located in `its_hub/base.py`: 61 | 62 | ```python 63 | # Language model interface 64 | class AbstractLanguageModel: 65 | def generate(self, prompt: str) -> str: ... 66 | def generate_batch(self, prompts: list[str]) -> list[str]: ... 67 | 68 | # Algorithm interface 69 | class AbstractScalingAlgorithm: 70 | def infer(self, lm, prompt, budget, return_response_only=True): ... 71 | 72 | # Result interface 73 | class AbstractScalingResult: 74 | @property 75 | def the_one(self) -> str: ... # Best response 76 | 77 | # Reward model interfaces 78 | class AbstractOutcomeRewardModel: 79 | def score(self, prompt: str, response: str) -> float: ... 80 | 81 | class AbstractProcessRewardModel: 82 | def score_steps(self, prompt: str, steps: list[str]) -> list[float]: ... 
83 | ``` 84 | 85 | ### Component Overview 86 | 87 | ``` 88 | its_hub/ 89 | ├── base.py # Abstract interfaces 90 | ├── lms.py # Language model implementations 91 | ├── algorithms/ # Scaling algorithms 92 | │ ├── self_consistency.py 93 | │ ├── bon.py 94 | │ ├── beam_search.py 95 | │ └── particle_gibbs.py 96 | ├── integration/ # External integrations 97 | │ ├── reward_hub.py # Reward model integration 98 | │ └── iaas.py # API server 99 | └── utils.py # Utilities and prompts 100 | ``` 101 | 102 | ## Adding New Algorithms 103 | 104 | ### 1. Implement Abstract Interface 105 | 106 | ```python 107 | from its_hub.base import AbstractScalingAlgorithm, AbstractScalingResult 108 | 109 | class MyAlgorithmResult(AbstractScalingResult): 110 | def __init__(self, responses: list[str], scores: list[float]): 111 | self.responses = responses 112 | self.scores = scores 113 | 114 | @property 115 | def the_one(self) -> str: 116 | # Return best response based on your criteria 117 | best_idx = max(range(len(self.scores)), key=lambda i: self.scores[i]) 118 | return self.responses[best_idx] 119 | 120 | class MyAlgorithm(AbstractScalingAlgorithm): 121 | def __init__(self, custom_param: float = 1.0): 122 | self.custom_param = custom_param 123 | 124 | def infer(self, lm, prompt: str, budget: int, return_response_only: bool = True): 125 | # Implement your algorithm logic here 126 | responses = [] 127 | scores = [] 128 | 129 | for i in range(budget): 130 | response = lm.generate(prompt) 131 | score = self._score_response(response) 132 | responses.append(response) 133 | scores.append(score) 134 | 135 | result = MyAlgorithmResult(responses, scores) 136 | return result.the_one if return_response_only else result 137 | 138 | def _score_response(self, response: str) -> float: 139 | # Implement your scoring logic 140 | return len(response) # Example: prefer longer responses 141 | ``` 142 | 143 | ### 2. Add to Algorithms Module 144 | 145 | ```python 146 | # its_hub/algorithms/__init__.py 147 | from .my_algorithm import MyAlgorithm 148 | 149 | __all__ = ['SelfConsistency', 'BestOfN', 'BeamSearch', 'ParticleFiltering', 'MyAlgorithm'] 150 | ``` 151 | 152 | ### 3. Write Tests 153 | 154 | ```python 155 | # tests/test_my_algorithm.py 156 | import pytest 157 | from its_hub.algorithms import MyAlgorithm 158 | from its_hub.lms import OpenAICompatibleLanguageModel 159 | 160 | def test_my_algorithm(): 161 | # Mock language model for testing 162 | class MockLM: 163 | def generate(self, prompt): 164 | return f"Response to: {prompt}" 165 | 166 | lm = MockLM() 167 | algorithm = MyAlgorithm(custom_param=2.0) 168 | 169 | result = algorithm.infer(lm, "test prompt", budget=3) 170 | assert isinstance(result, str) 171 | assert "Response to: test prompt" in result 172 | ``` 173 | 174 | ## Adding New Language Models 175 | 176 | ### 1. Implement Abstract Interface 177 | 178 | ```python 179 | from its_hub.base import AbstractLanguageModel 180 | 181 | class MyLanguageModel(AbstractLanguageModel): 182 | def __init__(self, model_path: str): 183 | self.model_path = model_path 184 | # Initialize your model here 185 | 186 | def generate(self, prompt: str) -> str: 187 | # Implement single generation 188 | pass 189 | 190 | def generate_batch(self, prompts: list[str]) -> list[str]: 191 | # Implement batch generation 192 | return [self.generate(p) for p in prompts] 193 | 194 | def score(self, prompt: str, response: str) -> float: 195 | # Implement response scoring (optional) 196 | return 0.0 197 | ``` 198 | 199 | ### 2. 
Add Async Support 200 | 201 | ```python 202 | import asyncio 203 | from typing import Optional 204 | 205 | class MyAsyncLanguageModel(AbstractLanguageModel): 206 | async def generate_async(self, prompt: str, **kwargs) -> str: 207 | # Implement async generation 208 | pass 209 | 210 | async def generate_batch_async(self, prompts: list[str], **kwargs) -> list[str]: 211 | # Implement async batch generation 212 | tasks = [self.generate_async(p, **kwargs) for p in prompts] 213 | return await asyncio.gather(*tasks) 214 | ``` 215 | 216 | ## Adding New Reward Models 217 | 218 | ### Process Reward Model 219 | 220 | ```python 221 | from its_hub.base import AbstractProcessRewardModel 222 | 223 | class MyProcessRewardModel(AbstractProcessRewardModel): 224 | def __init__(self, model_path: str): 225 | self.model_path = model_path 226 | # Load your reward model 227 | 228 | def score_steps(self, prompt: str, steps: list[str]) -> list[float]: 229 | """Score each reasoning step""" 230 | scores = [] 231 | context = prompt 232 | 233 | for step in steps: 234 | score = self._score_step(context, step) 235 | scores.append(score) 236 | context += f"\\n{step}" 237 | 238 | return scores 239 | 240 | def _score_step(self, context: str, step: str) -> float: 241 | # Implement step scoring logic 242 | return 1.0 # Placeholder 243 | ``` 244 | 245 | ### Outcome Reward Model 246 | 247 | ```python 248 | from its_hub.base import AbstractOutcomeRewardModel 249 | 250 | class MyOutcomeRewardModel(AbstractOutcomeRewardModel): 251 | def score(self, prompt: str, response: str) -> float: 252 | """Score the final response""" 253 | # Implement outcome scoring logic 254 | return self._evaluate_correctness(prompt, response) 255 | 256 | def _evaluate_correctness(self, prompt: str, response: str) -> float: 257 | # Custom evaluation logic 258 | return 1.0 if "correct" in response.lower() else 0.0 259 | ``` 260 | 261 | ## Testing Guidelines 262 | 263 | ### Unit Tests 264 | 265 | ```python 266 | # Test individual components 267 | def test_algorithm_basic(): 268 | algorithm = MyAlgorithm() 269 | # Test basic functionality 270 | 271 | def test_algorithm_edge_cases(): 272 | algorithm = MyAlgorithm() 273 | # Test edge cases and error conditions 274 | 275 | def test_algorithm_with_mock(): 276 | # Use mocks to isolate component under test 277 | pass 278 | ``` 279 | 280 | ### Integration Tests 281 | 282 | ```python 283 | # Test component interactions 284 | def test_algorithm_with_real_lm(): 285 | lm = OpenAICompatibleLanguageModel(...) 286 | algorithm = MyAlgorithm() 287 | result = algorithm.infer(lm, "test", budget=2) 288 | # Verify end-to-end behavior 289 | ``` 290 | 291 | ### Performance Tests 292 | 293 | ```python 294 | import time 295 | 296 | def test_algorithm_performance(): 297 | start_time = time.time() 298 | # Run algorithm 299 | elapsed = time.time() - start_time 300 | assert elapsed < 10.0 # Performance requirement 301 | ``` 302 | 303 | ## Git Workflow 304 | 305 | ### Commits 306 | 307 | Always use the sign-off flag for commits: 308 | 309 | ```bash 310 | git commit -s -m "feat: add new algorithm implementation" 311 | ``` 312 | 313 | ### Branch Naming 314 | 315 | - `feat/algorithm-name` - New features 316 | - `fix/issue-description` - Bug fixes 317 | - `docs/section-name` - Documentation updates 318 | - `refactor/component-name` - Code refactoring 319 | 320 | ### Pull Request Process 321 | 322 | 1. Create feature branch from `main` 323 | 2. Make changes with signed commits 324 | 3. Add tests for new functionality 325 | 4. 
Update documentation as needed 326 | 5. Ensure all tests pass 327 | 6. Submit pull request with clear description 328 | 329 | ## Documentation 330 | 331 | ### Docstring Format 332 | 333 | Use Google-style docstrings: 334 | 335 | ```python 336 | def my_function(param1: str, param2: int = 10) -> bool: 337 | """Brief description of the function. 338 | 339 | Longer description if needed, explaining the purpose 340 | and any important details. 341 | 342 | Args: 343 | param1: Description of first parameter 344 | param2: Description of second parameter with default value 345 | 346 | Returns: 347 | Description of return value 348 | 349 | Raises: 350 | ValueError: Description of when this exception is raised 351 | 352 | Example: 353 | >>> result = my_function("hello", 5) 354 | >>> print(result) 355 | True 356 | """ 357 | return len(param1) > param2 358 | ``` 359 | 360 | ### Code Comments 361 | 362 | - Explain **why**, not **what** 363 | - Use comments for complex algorithms or non-obvious logic 364 | - Keep comments up-to-date with code changes 365 | 366 | ## Performance Optimization 367 | 368 | ### Profiling 369 | 370 | ```python 371 | import cProfile 372 | import pstats 373 | 374 | def profile_algorithm(): 375 | pr = cProfile.Profile() 376 | pr.enable() 377 | 378 | # Run your algorithm here 379 | algorithm.infer(lm, prompt, budget=10) 380 | 381 | pr.disable() 382 | stats = pstats.Stats(pr) 383 | stats.sort_stats('cumulative').print_stats(10) 384 | ``` 385 | 386 | ### Memory Optimization 387 | 388 | ```python 389 | import tracemalloc 390 | 391 | tracemalloc.start() 392 | 393 | # Run your code 394 | result = algorithm.infer(lm, prompt, budget=10) 395 | 396 | current, peak = tracemalloc.get_traced_memory() 397 | print(f"Current memory usage: {current / 1024 / 1024:.2f} MB") 398 | print(f"Peak memory usage: {peak / 1024 / 1024:.2f} MB") 399 | ``` 400 | 401 | ### GPU Memory Management 402 | 403 | ```python 404 | import torch 405 | 406 | def optimize_gpu_memory(): 407 | # Clear cache periodically 408 | torch.cuda.empty_cache() 409 | 410 | # Monitor memory usage 411 | allocated = torch.cuda.memory_allocated() 412 | cached = torch.cuda.memory_reserved() 413 | print(f"GPU Memory - Allocated: {allocated/1e9:.2f}GB, Cached: {cached/1e9:.2f}GB") 414 | ``` 415 | 416 | ## Release Process 417 | 418 | ### Version Bumping 419 | 420 | Update version in `pyproject.toml`: 421 | 422 | ```toml 423 | [project] 424 | version = "0.2.0" 425 | ``` 426 | 427 | ### Creating Releases 428 | 429 | 1. Update version number 430 | 2. Update CHANGELOG.md 431 | 3. Create git tag: `git tag -a v0.2.0 -m "Release v0.2.0"` 432 | 4. Push tag: `git push origin v0.2.0` 433 | 5. 
GitHub Actions will handle PyPI publishing 434 | 435 | ## Contributing 436 | 437 | ### Issues 438 | 439 | - Use issue templates when available 440 | - Provide minimal reproducible examples 441 | - Include environment details (OS, Python version, GPU type) 442 | 443 | ### Feature Requests 444 | 445 | - Explain the use case and motivation 446 | - Provide examples of desired API 447 | - Consider backwards compatibility 448 | 449 | ### Code Review 450 | 451 | - Review for correctness, performance, and maintainability 452 | - Suggest improvements constructively 453 | - Test the changes locally when possible -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Shared test configuration and fixtures.""" 2 | 3 | import json 4 | import socket 5 | import threading 6 | import time 7 | from http.server import BaseHTTPRequestHandler, HTTPServer 8 | 9 | import pytest 10 | from fastapi.testclient import TestClient 11 | 12 | from its_hub.base import AbstractLanguageModel, AbstractOutcomeRewardModel 13 | from its_hub.integration.iaas import app 14 | 15 | 16 | def find_free_port() -> int: 17 | """Find a free port to use for test servers.""" 18 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 19 | s.bind(("", 0)) 20 | return s.getsockname()[1] 21 | 22 | 23 | class DummyVLLMHandler(BaseHTTPRequestHandler): 24 | """A dummy HTTP handler that mimics a vLLM server.""" 25 | 26 | def do_POST(self): 27 | """Handle POST requests to the /v1/chat/completions endpoint.""" 28 | if self.path == "/v1/chat/completions": 29 | content_length = int(self.headers["Content-Length"]) 30 | post_data = self.rfile.read(content_length) 31 | request_data = json.loads(post_data.decode("utf-8")) 32 | 33 | # Simulate some processing time 34 | time.sleep(0.01) 35 | 36 | # Extract the user message 37 | messages = request_data.get("messages", []) 38 | user_content = messages[-1]["content"] if messages else "unknown" 39 | 40 | # Check for error triggers 41 | if "error" in user_content.lower(): 42 | self.send_response(500) 43 | self.send_header("Content-Type", "application/json") 44 | self.end_headers() 45 | error_response = { 46 | "error": { 47 | "message": "Simulated vLLM error", 48 | "type": "server_error", 49 | "code": 500, 50 | } 51 | } 52 | self.wfile.write(json.dumps(error_response).encode("utf-8")) 53 | return 54 | 55 | # Create a response that includes the request content for testing 56 | response_content = f"vLLM response to: {user_content}" 57 | 58 | # Check if we should include stop tokens 59 | stop = request_data.get("stop") 60 | include_stop = request_data.get("include_stop_str_in_output", False) 61 | 62 | if stop and include_stop: 63 | response_content += stop 64 | 65 | # Create vLLM-like response 66 | response = { 67 | "id": "vllm-test-id", 68 | "object": "chat.completion", 69 | "created": int(time.time()), 70 | "model": request_data.get("model", "test-model"), 71 | "choices": [ 72 | { 73 | "index": 0, 74 | "message": {"role": "assistant", "content": response_content}, 75 | "finish_reason": "stop", 76 | } 77 | ], 78 | "usage": { 79 | "prompt_tokens": 10, 80 | "completion_tokens": 15, 81 | "total_tokens": 25, 82 | }, 83 | } 84 | 85 | self.send_response(200) 86 | self.send_header("Content-Type", "application/json") 87 | self.end_headers() 88 | self.wfile.write(json.dumps(response).encode("utf-8")) 89 | else: 90 | self.send_response(404) 91 | self.end_headers() 92 | 
self.wfile.write(b"Not Found") 93 | 94 | def log_message(self, format, *args): 95 | """Suppress log messages to keep test output clean.""" 96 | pass 97 | 98 | 99 | class DummyOpenAIHandler(BaseHTTPRequestHandler): 100 | """A dummy HTTP handler that mimics the OpenAI API.""" 101 | 102 | # Class-level variables to track concurrent requests 103 | active_requests = 0 104 | max_concurrent_requests = 0 105 | request_lock = threading.Lock() 106 | 107 | @classmethod 108 | def reset_stats(cls): 109 | """Reset the request statistics.""" 110 | with cls.request_lock: 111 | cls.active_requests = 0 112 | cls.max_concurrent_requests = 0 113 | 114 | def do_POST(self): 115 | """Handle POST requests to the /chat/completions endpoint.""" 116 | if self.path == "/chat/completions": 117 | content_length = int(self.headers["Content-Length"]) 118 | post_data = self.rfile.read(content_length) 119 | request_data = json.loads(post_data.decode("utf-8")) 120 | 121 | # Track concurrent requests 122 | with self.__class__.request_lock: 123 | self.__class__.active_requests += 1 124 | self.__class__.max_concurrent_requests = max( 125 | self.__class__.max_concurrent_requests, 126 | self.__class__.active_requests, 127 | ) 128 | 129 | # Simulate some processing time 130 | time.sleep(0.1) 131 | 132 | # Check if we should simulate an error 133 | if "trigger_error" in request_data.get("messages", [{}])[-1].get( 134 | "content", "" 135 | ): 136 | self.send_response(500) 137 | self.send_header("Content-Type", "application/json") 138 | self.end_headers() 139 | error_response = { 140 | "error": { 141 | "message": "Simulated API error", 142 | "type": "server_error", 143 | "code": 500, 144 | } 145 | } 146 | self.wfile.write(json.dumps(error_response).encode("utf-8")) 147 | 148 | # Decrement active requests 149 | with self.__class__.request_lock: 150 | self.__class__.active_requests -= 1 151 | 152 | return 153 | 154 | # Extract the messages from the request 155 | messages = request_data.get("messages", []) 156 | 157 | # Prepare a response based on the messages 158 | response_content = f"Response to: {messages[-1]['content']}" 159 | 160 | # Check if there's a stop sequence and we should include it 161 | stop = request_data.get("stop") 162 | include_stop = request_data.get("include_stop_str_in_output", False) 163 | 164 | if stop and include_stop: 165 | response_content += stop 166 | 167 | # Create an OpenAI-like response 168 | response = { 169 | "id": "dummy-id", 170 | "object": "chat.completion", 171 | "created": 1234567890, 172 | "model": request_data.get("model", "dummy-model"), 173 | "choices": [ 174 | { 175 | "index": 0, 176 | "message": {"role": "assistant", "content": response_content}, 177 | "finish_reason": "stop", 178 | } 179 | ], 180 | } 181 | 182 | # Send the response 183 | self.send_response(200) 184 | self.send_header("Content-Type", "application/json") 185 | self.end_headers() 186 | self.wfile.write(json.dumps(response).encode("utf-8")) 187 | 188 | # Decrement active requests 189 | with self.__class__.request_lock: 190 | self.__class__.active_requests -= 1 191 | else: 192 | self.send_response(404) 193 | self.end_headers() 194 | self.wfile.write(b"Not Found") 195 | 196 | def log_message(self, format, *args): 197 | """Suppress log messages to keep test output clean.""" 198 | pass 199 | 200 | 201 | class MockLanguageModel(AbstractLanguageModel): 202 | """Mock language model for testing.""" 203 | 204 | def __init__(self, responses: list[str]): 205 | self.responses = responses 206 | self.call_count = 0 207 | 208 | def generate( 
209 | self, messages, stop=None, temperature=None, include_stop_str_in_output=None 210 | ): 211 | if ( 212 | isinstance(messages, list) 213 | and len(messages) > 0 214 | and isinstance(messages[0], list) 215 | ): 216 | # Batched generation - messages is List[List[ChatMessage]] 217 | num_requests = len(messages) 218 | if self.call_count + num_requests > len(self.responses): 219 | # Cycle through responses if we run out 220 | responses = [] 221 | for i in range(num_requests): 222 | responses.append( 223 | self.responses[(self.call_count + i) % len(self.responses)] 224 | ) 225 | else: 226 | responses = self.responses[ 227 | self.call_count : self.call_count + num_requests 228 | ] 229 | self.call_count += num_requests 230 | return responses 231 | else: 232 | # Single generation - messages is List[ChatMessage] 233 | if self.call_count >= len(self.responses): 234 | # Cycle through responses if we run out 235 | response = self.responses[self.call_count % len(self.responses)] 236 | else: 237 | response = self.responses[self.call_count] 238 | self.call_count += 1 239 | return response 240 | 241 | def evaluate(self, prompt: str, generation: str) -> list[float]: 242 | return [0.1] * len(generation.split()) 243 | 244 | 245 | class MockOutcomeRewardModel(AbstractOutcomeRewardModel): 246 | """Mock outcome reward model for testing.""" 247 | 248 | def __init__(self, scores: list[float]): 249 | if isinstance(scores, float): 250 | self.scores = [scores] 251 | else: 252 | self.scores = scores 253 | self.call_count = 0 254 | 255 | def score(self, prompt: str, response) -> float: 256 | if isinstance(response, list): 257 | scores = self.scores[self.call_count : self.call_count + len(response)] 258 | self.call_count += len(response) 259 | return scores 260 | else: 261 | score = self.scores[self.call_count % len(self.scores)] 262 | self.call_count += 1 263 | return score 264 | 265 | 266 | class MockProcessRewardModel: 267 | """Mock process reward model for testing.""" 268 | 269 | def __init__(self, scores: list[float]): 270 | if isinstance(scores[0], list): 271 | self.scores = [score for sublist in scores for score in sublist] 272 | else: 273 | self.scores = scores 274 | self.call_count = 0 275 | 276 | def score(self, prompt: str, response) -> float: 277 | if isinstance(response, list): 278 | scores = [] 279 | for i in range(len(response)): 280 | scores.append(self.scores[(self.call_count + i) % len(self.scores)]) 281 | self.call_count += len(response) 282 | return scores 283 | else: 284 | score = self.scores[self.call_count % len(self.scores)] 285 | self.call_count += 1 286 | return score 287 | 288 | 289 | # Pytest fixtures 290 | 291 | 292 | @pytest.fixture(scope="session") 293 | def vllm_server(): 294 | """Start a vLLM mock server for the test session.""" 295 | port = find_free_port() 296 | server = HTTPServer(("localhost", port), DummyVLLMHandler) 297 | server_thread = threading.Thread(target=server.serve_forever) 298 | server_thread.daemon = True 299 | server_thread.start() 300 | 301 | # Give the server a moment to start 302 | time.sleep(0.1) 303 | 304 | yield f"http://localhost:{port}" 305 | 306 | server.shutdown() 307 | server_thread.join() 308 | 309 | 310 | @pytest.fixture(scope="session") 311 | def openai_server(): 312 | """Start an OpenAI mock server for the test session.""" 313 | port = find_free_port() 314 | server = HTTPServer(("localhost", port), DummyOpenAIHandler) 315 | server_thread = threading.Thread(target=server.serve_forever) 316 | server_thread.daemon = True 317 | server_thread.start() 318 | 
319 | # Give the server a moment to start 320 | time.sleep(0.1) 321 | 322 | yield f"http://localhost:{port}" 323 | 324 | server.shutdown() 325 | server_thread.join() 326 | 327 | 328 | @pytest.fixture 329 | def iaas_client(): 330 | """Create a test client for the IaaS API.""" 331 | # Reset global state before each test 332 | import its_hub.integration.iaas as iaas_module 333 | 334 | iaas_module.LM_DICT.clear() 335 | iaas_module.SCALING_ALG = None 336 | 337 | return TestClient(app) 338 | 339 | 340 | @pytest.fixture 341 | def mock_language_model(): 342 | """Create a mock language model with default responses.""" 343 | return MockLanguageModel(["response1", "response2", "response3"]) 344 | 345 | 346 | @pytest.fixture 347 | def mock_outcome_reward_model(): 348 | """Create a mock outcome reward model with default scores.""" 349 | return MockOutcomeRewardModel([0.8, 0.6, 0.9]) 350 | 351 | 352 | @pytest.fixture 353 | def mock_process_reward_model(): 354 | """Create a mock process reward model with default scores.""" 355 | return MockProcessRewardModel([0.7, 0.6, 0.8, 0.5]) 356 | 357 | 358 | # Test constants 359 | TEST_CONSTANTS = { 360 | "DEFAULT_BUDGET": 4, 361 | "DEFAULT_TEMPERATURE": 0.7, 362 | "DEFAULT_MODEL_NAME": "test-model", 363 | "DEFAULT_API_KEY": "test-key", 364 | "DEFAULT_TIMEOUT": 0.1, 365 | "ERROR_TRIGGER": "trigger_error", 366 | "VLLM_ERROR_TRIGGER": "error", 367 | } 368 | -------------------------------------------------------------------------------- /tests/test_reward_hub_integration.py: -------------------------------------------------------------------------------- 1 | """Integration tests for reward_hub integration.""" 2 | 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | from reward_hub.base import AggregationMethod 7 | 8 | from its_hub.integration.reward_hub import LocalVllmProcessRewardModel 9 | 10 | 11 | class TestLocalVllmProcessRewardModelIntegration: 12 | """Test the integration between its_hub and reward_hub.""" 13 | 14 | @pytest.fixture 15 | def mock_vllm_model(self): 16 | """Create a mock VllmProcessRewardModel.""" 17 | with patch("reward_hub.vllm.reward.VllmProcessRewardModel") as mock_class: 18 | mock_instance = MagicMock() 19 | mock_class.return_value = mock_instance 20 | yield mock_instance 21 | 22 | def test_single_response_scoring(self, mock_vllm_model): 23 | """Test scoring a single response with proper message format.""" 24 | # Setup mock to return a single score 25 | mock_vllm_model.score.return_value = [0.85] 26 | 27 | # Create the reward model 28 | model = LocalVllmProcessRewardModel( 29 | model_name="test-model", 30 | device="cpu", 31 | aggregation_method=AggregationMethod.PRODUCT, 32 | ) 33 | 34 | # Score a single response 35 | prompt = "What is 2+2?" 
36 | response = "2+2 = 4" 37 | score = model.score(prompt, response) 38 | 39 | # Verify the score is returned correctly 40 | assert score == 0.85 41 | 42 | # Verify the mock was called with correct message format 43 | mock_vllm_model.score.assert_called_once() 44 | call_args = mock_vllm_model.score.call_args 45 | 46 | # Check that messages are in dict format, not ChatMessage objects 47 | messages = call_args[1]["messages"] 48 | assert len(messages) == 1 # Single conversation 49 | assert len(messages[0]) == 2 # User + assistant messages 50 | 51 | # Verify message format - should be dicts, not ChatMessage objects 52 | user_msg = messages[0][0] 53 | assistant_msg = messages[0][1] 54 | 55 | assert isinstance(user_msg, dict) 56 | assert isinstance(assistant_msg, dict) 57 | assert user_msg == {"role": "user", "content": prompt} 58 | assert assistant_msg == {"role": "assistant", "content": response} 59 | 60 | # Verify other parameters 61 | assert call_args[1]["aggregation_method"] == AggregationMethod.PRODUCT 62 | assert call_args[1]["return_full_prm_result"] is False 63 | 64 | def test_multiple_responses_scoring(self, mock_vllm_model): 65 | """Test scoring multiple responses with proper message format.""" 66 | # Setup mock to return multiple scores 67 | mock_vllm_model.score.return_value = [0.85, 0.72, 0.91] 68 | 69 | # Create the reward model 70 | model = LocalVllmProcessRewardModel( 71 | model_name="test-model", 72 | device="cuda:0", 73 | aggregation_method=AggregationMethod.MIN, 74 | ) 75 | 76 | # Score multiple responses 77 | prompt = "Solve this math problem: 3x + 5 = 14" 78 | responses = [ 79 | "3x + 5 = 14\n3x = 9\nx = 3", 80 | "Let me solve step by step:\n3x = 14 - 5 = 9\nx = 3", 81 | "x = (14-5)/3 = 3", 82 | ] 83 | scores = model.score(prompt, responses) 84 | 85 | # Verify scores are returned correctly 86 | assert scores == [0.85, 0.72, 0.91] 87 | 88 | # Verify the mock was called with correct message format 89 | mock_vllm_model.score.assert_called_once() 90 | call_args = mock_vllm_model.score.call_args 91 | 92 | # Check that messages are in dict format for all responses 93 | messages = call_args[1]["messages"] 94 | assert len(messages) == 3 # Three conversations 95 | 96 | for i, conversation in enumerate(messages): 97 | assert len(conversation) == 2 # User + assistant messages 98 | 99 | user_msg = conversation[0] 100 | assistant_msg = conversation[1] 101 | 102 | # Verify message format - should be dicts, not ChatMessage objects 103 | assert isinstance(user_msg, dict) 104 | assert isinstance(assistant_msg, dict) 105 | assert user_msg == {"role": "user", "content": prompt} 106 | assert assistant_msg == {"role": "assistant", "content": responses[i]} 107 | 108 | # Verify other parameters 109 | assert call_args[1]["aggregation_method"] == AggregationMethod.MIN 110 | assert call_args[1]["return_full_prm_result"] is False 111 | 112 | def test_different_aggregation_methods(self, mock_vllm_model): 113 | """Test that different aggregation methods are passed correctly.""" 114 | mock_vllm_model.score.return_value = [0.5] 115 | 116 | for agg_method in [ 117 | AggregationMethod.PRODUCT, 118 | AggregationMethod.MIN, 119 | AggregationMethod.LAST, 120 | ]: 121 | model = LocalVllmProcessRewardModel( 122 | model_name="test-model", device="cpu", aggregation_method=agg_method 123 | ) 124 | 125 | model.score("test prompt", "test response") 126 | 127 | # Check that the aggregation method was passed correctly 128 | call_args = mock_vllm_model.score.call_args 129 | assert call_args[1]["aggregation_method"] == 
agg_method 130 | 131 | def test_message_format_compatibility(self, mock_vllm_model): 132 | """Test that the message format is compatible with reward_hub expectations. 133 | 134 | This test specifically addresses the bug from issue #73 where ChatMessage 135 | objects were used instead of dict format. 136 | """ 137 | mock_vllm_model.score.return_value = [0.7] 138 | 139 | model = LocalVllmProcessRewardModel( 140 | model_name="test-model", 141 | device="cpu", 142 | aggregation_method=AggregationMethod.PRODUCT, 143 | ) 144 | 145 | # Score a response 146 | model.score("Test prompt", "Test response") 147 | 148 | # Get the messages that were passed to the reward_hub model 149 | call_args = mock_vllm_model.score.call_args 150 | messages = call_args[1]["messages"] 151 | 152 | # Verify that each message is a plain dict (not a ChatMessage object) 153 | for conversation in messages: 154 | for message in conversation: 155 | # Should be a dict with 'role' and 'content' keys 156 | assert isinstance(message, dict) 157 | assert "role" in message 158 | assert "content" in message 159 | assert len(message) == 2 # Only 'role' and 'content' 160 | 161 | # Should not have any class-specific attributes 162 | assert not hasattr(message, "__class__") or message.__class__ is dict 163 | 164 | # Role should be string 165 | assert isinstance(message["role"], str) 166 | assert message["role"] in ["user", "assistant"] 167 | 168 | # Content should be string 169 | assert isinstance(message["content"], str) 170 | 171 | def test_error_handling(self, mock_vllm_model): 172 | """Test that errors from reward_hub are properly propagated.""" 173 | # Setup mock to raise an exception 174 | mock_vllm_model.score.side_effect = Exception("reward_hub error") 175 | 176 | model = LocalVllmProcessRewardModel( 177 | model_name="test-model", 178 | device="cpu", 179 | aggregation_method=AggregationMethod.PRODUCT, 180 | ) 181 | 182 | # Verify that the exception is propagated 183 | with pytest.raises(Exception, match="reward_hub error"): 184 | model.score("test prompt", "test response") 185 | 186 | @pytest.mark.parametrize("device", ["cpu", "cuda:0", "cuda:1"]) 187 | def test_device_parameter_passing(self, mock_vllm_model, device): 188 | """Test that device parameter is passed correctly to VllmProcessRewardModel.""" 189 | with patch("reward_hub.vllm.reward.VllmProcessRewardModel") as mock_class: 190 | LocalVllmProcessRewardModel( 191 | model_name="test-model", 192 | device=device, 193 | aggregation_method=AggregationMethod.PRODUCT, 194 | ) 195 | 196 | # Verify VllmProcessRewardModel was initialized with correct device 197 | mock_class.assert_called_once_with(model_name="test-model", device=device) 198 | 199 | def test_model_name_parameter_passing(self, mock_vllm_model): 200 | """Test that model_name parameter is passed correctly to VllmProcessRewardModel.""" 201 | test_model_names = [ 202 | "microsoft/DialoGPT-medium", 203 | "meta-llama/Llama-2-7b-chat-hf", 204 | "custom-model-name", 205 | ] 206 | 207 | for model_name in test_model_names: 208 | with patch("reward_hub.vllm.reward.VllmProcessRewardModel") as mock_class: 209 | LocalVllmProcessRewardModel( 210 | model_name=model_name, 211 | device="cpu", 212 | aggregation_method=AggregationMethod.PRODUCT, 213 | ) 214 | 215 | # Verify VllmProcessRewardModel was initialized with correct model name 216 | mock_class.assert_called_once_with(model_name=model_name, device="cpu") 217 | 218 | def test_regression_chatmessage_format_bug(self, mock_vllm_model): 219 | """Regression test for issue #73: 
ChatMessage objects vs dict format. 220 | 221 | This test simulates what would happen if ChatMessage objects were used 222 | instead of dict format, which was the bug fixed in PR #73. 223 | """ 224 | 225 | # Setup mock to be strict about message format 226 | def strict_score_check(messages, **kwargs): 227 | # This simulates reward_hub expecting dict format 228 | for conversation in messages: 229 | for message in conversation: 230 | # If this were a ChatMessage object, it would have additional attributes 231 | # and methods that are not expected by reward_hub 232 | if not isinstance(message, dict): 233 | raise TypeError(f"Expected dict, got {type(message)}") 234 | 235 | # Check that it only has the expected keys 236 | expected_keys = {"role", "content"} 237 | if set(message.keys()) != expected_keys: 238 | raise ValueError( 239 | f"Message has unexpected keys: {set(message.keys())}" 240 | ) 241 | 242 | # Check that values are strings 243 | if not isinstance(message["role"], str): 244 | raise TypeError( 245 | f"Role should be string, got {type(message['role'])}" 246 | ) 247 | if not isinstance(message["content"], str): 248 | raise TypeError( 249 | f"Content should be string, got {type(message['content'])}" 250 | ) 251 | 252 | return [0.5] 253 | 254 | mock_vllm_model.score.side_effect = strict_score_check 255 | 256 | model = LocalVllmProcessRewardModel( 257 | model_name="test-model", 258 | device="cpu", 259 | aggregation_method=AggregationMethod.PRODUCT, 260 | ) 261 | 262 | # This should work fine with the current implementation 263 | score = model.score("Test prompt", "Test response") 264 | assert score == 0.5 265 | 266 | # Verify that the format check passed (no exception was raised) 267 | mock_vllm_model.score.assert_called_once() 268 | 269 | def test_demonstrates_chatmessage_compatibility_issue(self): 270 | """Test that demonstrates what the issue #73 bug would look like. 271 | 272 | This is a demonstration test showing how ChatMessage objects would fail 273 | with reward_hub's expected dict format. 
274 | """ 275 | from dataclasses import dataclass 276 | 277 | # Simulate a ChatMessage-like object (what was causing the bug) 278 | @dataclass 279 | class ChatMessage: 280 | role: str 281 | content: str 282 | 283 | def to_dict(self): 284 | return {"role": self.role, "content": self.content} 285 | 286 | # Create messages in the old (broken) format 287 | user_msg = ChatMessage(role="user", content="What is 2+2?") 288 | assistant_msg = ChatMessage(role="assistant", content="2+2 = 4") 289 | 290 | # This would be what the old code might have done 291 | broken_messages = [[user_msg, assistant_msg]] 292 | 293 | # Simulate reward_hub's expectation (dict format) 294 | def check_dict_format(messages): 295 | for conversation in messages: 296 | for message in conversation: 297 | if not isinstance(message, dict): 298 | raise TypeError(f"Expected dict, got {type(message)}") 299 | if "role" not in message or "content" not in message: 300 | raise ValueError("Message missing required keys") 301 | 302 | # This would fail with the old implementation 303 | with pytest.raises(TypeError, match="Expected dict, got"): 304 | check_dict_format(broken_messages) 305 | 306 | # But this works with the current implementation (dict format) 307 | correct_messages = [ 308 | [ 309 | {"role": "user", "content": "What is 2+2?"}, 310 | {"role": "assistant", "content": "2+2 = 4"}, 311 | ] 312 | ] 313 | 314 | # This should pass 315 | check_dict_format(correct_messages) # No exception should be raised 316 | --------------------------------------------------------------------------------